Skip to content

Commit bffa34a

Browse files
committed
Normalize vectors on embedding creation instead of querying
- Normalizes only once instead of each time
- Embedding creation takes time anyway, while the query should be as fast as possible
1 parent 05c4f76 commit bffa34a

13 files changed

+105
-27
lines changed

‎collection.go‎

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
213213
m[k] = v
214214
}
215215

216-
// Create embedding if they don't exist
216+
// Create embedding if they don't exist, otherwise normalize if necessary
217217
if len(doc.Embedding) == 0 {
218218
embedding, err := c.embed(ctx, doc.Content)
219219
if err != nil {
220220
return fmt.Errorf("couldn't create embedding of document: %w", err)
221221
}
222222
doc.Embedding = embedding
223+
} else {
224+
if !isNormalized(doc.Embedding) {
225+
doc.Embedding = normalizeVector(doc.Embedding)
226+
}
223227
}
224228

225229
c.documentsLock.Lock()

‎collection_test.go‎

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) {
1313
ctx := context.Background()
1414
name := "test"
1515
metadata := map[string]string{"foo": "bar"}
16-
vectors := []float32{-0.1, 0.1, 0.2}
16+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
1717
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
1818
return vectors, nil
1919
}
@@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) {
109109
ctx := context.Background()
110110
name := "test"
111111
metadata := map[string]string{"foo": "bar"}
112-
vectors := []float32{-0.1, 0.1, 0.2}
112+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
113113
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
114114
return vectors, nil
115115
}
@@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) {
160160
ctx := context.Background()
161161
name := "test"
162162
metadata := map[string]string{"foo": "bar"}
163-
vectors := []float32{-0.1, 0.1, 0.2}
163+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
164164
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
165165
return vectors, nil
166166
}
@@ -256,7 +256,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
256256
ctx := context.Background()
257257
name := "test"
258258
metadata := map[string]string{"foo": "bar"}
259-
vectors := []float32{-0.1, 0.1, 0.2}
259+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
260260
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
261261
return vectors, nil
262262
}
@@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) {
313313
db := NewDB()
314314
name := "test"
315315
metadata := map[string]string{"foo": "bar"}
316+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
316317
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
317-
return []float32{-0.1, 0.1, 0.2}, nil
318+
return vectors, nil
318319
}
319320
c, err := db.CreateCollection(name, metadata, embeddingFunc)
320321
if err != nil {

‎db.go‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ import (
1313
// EmbeddingFunc is a function that creates embeddings for a given text.
1414
// chromem-go will use OpenAI`s "text-embedding-3-small" model by default,
1515
// but you can provide your own function, using any model you like.
16+
// The function must return a *normalized* vector, i.e. the length of the vector
17+
// must be 1. OpenAI's and Mistral's embedding models do this by default. Some
18+
// others like Nomic's "nomic-embed-text-v1.5" don't.
1619
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
1720

1821
// DB is the chromem-go database. It holds collections, which hold documents.

‎db_test.go‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) {
1010
// Values in the collection
1111
name := "test"
1212
metadata := map[string]string{"foo": "bar"}
13-
vectors := []float32{-0.1, 0.1, 0.2}
13+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
1414
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
1515
return vectors, nil
1616
}
@@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) {
8181
// Values in the collection
8282
name := "test"
8383
metadata := map[string]string{"foo": "bar"}
84-
vectors := []float32{-0.1, 0.1, 0.2}
84+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
8585
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
8686
return vectors, nil
8787
}
@@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) {
147147
// Values in the collection
148148
name := "test"
149149
metadata := map[string]string{"foo": "bar"}
150-
vectors := []float32{-0.1, 0.1, 0.2}
150+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
151151
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
152152
return vectors, nil
153153
}
@@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
196196
// Values in the collection
197197
name := "test"
198198
metadata := map[string]string{"foo": "bar"}
199-
vectors := []float32{-0.1, 0.1, 0.2}
199+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
200200
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
201201
return vectors, nil
202202
}
@@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) {
299299
// Values in the collection
300300
name := "test"
301301
metadata := map[string]string{"foo": "bar"}
302-
vectors := []float32{-0.1, 0.1, 0.2}
302+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
303303
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
304304
return vectors, nil
305305
}
@@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) {
331331
// Values in the collection
332332
name := "test"
333333
metadata := map[string]string{"foo": "bar"}
334-
vectors := []float32{-0.1, 0.1, 0.2}
334+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
335335
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
336336
return vectors, nil
337337
}

‎document_test.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) {
1010
ctx := context.Background()
1111
id := "test"
1212
metadata := map[string]string{"foo": "bar"}
13-
vectors := []float32{-0.1, 0.1, 0.2}
13+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
1414
content := "hello world"
1515
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
1616
return vectors, nil

‎embed_compat.go‎

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ const (
99
// NewEmbeddingFuncMistral returns a function that creates embeddings for a text
1010
// using the Mistral API.
1111
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
12+
// Mistral embeddings are normalized, see section "Distance Measures" on
13+
// https://docs.mistral.ai/guides/embeddings/.
14+
normalized := true
15+
1216
// The Mistral API docs don't mention the `encoding_format` as optional,
1317
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
14-
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
18+
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
1519
}
1620

1721
const baseURLJina = "https://api.jina.ai/v1"
@@ -28,7 +32,7 @@ const (
2832
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
2933
// using the Jina API.
3034
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
31-
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
35+
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
3236
}
3337

3438
const baseURLMixedbread = "https://api.mixedbread.ai"
@@ -49,7 +53,7 @@ const (
4953
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
5054
// using the mixedbread.ai API.
5155
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
52-
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
56+
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
5357
}
5458

5559
const baseURLLocalAI = "http://localhost:8080/v1"
@@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1"
6468
// But other embedding models are supported as well. See the LocalAI documentation
6569
// for details.
6670
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
67-
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
71+
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
6872
}

‎embed_ollama.go‎

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"net/http"
11+
"sync"
1112
)
1213

1314
// TODO: Turn into const and use as default, but allow user to pass custom URL
@@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
2829
// and it might have to be a long timeout, depending on the text length.
2930
client := &http.Client{}
3031

32+
var checkedNormalized bool
33+
checkNormalized := sync.Once{}
34+
3135
return func(ctx context.Context, text string) ([]float32, error) {
3236
// Prepare the request body.
3337
reqBody, err := json.Marshal(map[string]string{
@@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
7478
return nil, errors.New("no embeddings found in the response")
7579
}
7680

77-
return embeddingResponse.Embedding, nil
81+
v := embeddingResponse.Embedding
82+
checkNormalized.Do(func() {
83+
if isNormalized(v) {
84+
checkedNormalized = true
85+
} else {
86+
checkedNormalized = false
87+
}
88+
})
89+
if !checkedNormalized {
90+
v = normalizeVector(v)
91+
}
92+
93+
return v, nil
7894
}
7995
}

‎embed_ollama_test.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
2525
if err != nil {
2626
t.Fatal("unexpected error:", err)
2727
}
28-
wantRes := []float32{-0.1, 0.1, 0.2}
28+
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
2929

3030
// Mock server
3131
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

‎embed_openai.go‎

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"net/http"
1111
"os"
12+
"sync"
1213
)
1314

1415
const BaseURLOpenAI = "https://api.openai.com/v1"
@@ -39,7 +40,9 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
3940
// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text
4041
// using the OpenAI API.
4142
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
42-
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model))
43+
// OpenAI embeddings are normalized
44+
normalized := true
45+
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
4346
}
4447

4548
// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
@@ -48,12 +51,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
4851
// - LitLLM: https://github.com/BerriAI/litellm
4952
// - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
5053
// - etc.
51-
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
54+
//
55+
// The `normalized` parameter indicates whether the vectors returned by the embedding
56+
// model are already normalized, as is the case for OpenAI's and Mistral's models.
57+
// The flag is optional. If it's nil, it will be autodetected on the first request
58+
// (which bears a small risk that the vector just happens to have a length of 1).
59+
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
5260
// We don't set a default timeout here, although it's usually a good idea.
5361
// In our case though, the library user can set the timeout on the context,
5462
// and it might have to be a long timeout, depending on the text length.
5563
client := &http.Client{}
5664

65+
var checkedNormalized bool
66+
checkNormalized := sync.Once{}
67+
5768
return func(ctx context.Context, text string) ([]float32, error) {
5869
// Prepare the request body.
5970
reqBody, err := json.Marshal(map[string]string{
@@ -101,6 +112,24 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
101112
return nil, errors.New("no embeddings found in the response")
102113
}
103114

104-
return embeddingResponse.Data[0].Embedding, nil
115+
v := embeddingResponse.Data[0].Embedding
116+
if normalized != nil {
117+
if *normalized {
118+
return v, nil
119+
}
120+
return normalizeVector(v), nil
121+
}
122+
checkNormalized.Do(func() {
123+
if isNormalized(v) {
124+
checkedNormalized = true
125+
} else {
126+
checkedNormalized = false
127+
}
128+
})
129+
if !checkedNormalized {
130+
v = normalizeVector(v)
131+
}
132+
133+
return v, nil
105134
}
106135
}

‎embed_openai_test.go‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
3333
if err != nil {
3434
t.Fatal("unexpected error:", err)
3535
}
36-
wantRes := []float32{-0.1, 0.1, 0.2}
36+
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
3737

3838
// Mock server
3939
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
7575
defer ts.Close()
7676
baseURL := ts.URL + baseURLSuffix
7777

78-
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
78+
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
7979
res, err := f(context.Background(), input)
8080
if err != nil {
8181
t.Fatal("expected nil, got", err)

0 commit comments

Comments
 (0)