Skip to content

Commit ff28a38

Browse files
committed
Add "normalized" parameter to skip check if normalization is known
1 parent 503c3ce commit ff28a38

File tree

9 files changed

+62
-34
lines changed

9 files changed

+62
-34
lines changed

‎README.md‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func main() {
9898
db := chromem.NewDB()
9999

100100
// Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available!
101-
collection, _ := db.CreateCollection("all-my-documents", nil, nil)
101+
collection, _ := db.CreateCollection("all-my-documents", nil, nil, nil)
102102

103103
// Add docs to the collection. Update and delete will be added in the future.
104104
// Can be multi-threaded with AddConcurrently()!

‎collection.go‎

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ type Collection struct {
2222
documents map[string]*Document
2323
documentsLock sync.RWMutex
2424
embed EmbeddingFunc
25+
normalized *bool
2526
}
2627

2728
// We don't export this yet to keep the API surface to the bare minimum.
2829
// Users create collections via [Client.CreateCollection].
29-
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) {
30+
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, normalized *bool, dir string) (*Collection, error) {
3031
// We copy the metadata to avoid data races in case the caller modifies the
3132
// map after creating the collection while we range over it.
3233
m := make(map[string]string, len(metadata))
@@ -37,9 +38,10 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
3738
c := &Collection{
3839
Name: name,
3940

40-
metadata: m,
41-
documents: make(map[string]*Document),
42-
embed: embed,
41+
metadata: m,
42+
documents: make(map[string]*Document),
43+
embed: embed,
44+
normalized: normalized,
4345
}
4446

4547
// Persistence
@@ -301,7 +303,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
301303
}
302304

303305
// For the remaining documents, calculate cosine similarity.
304-
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
306+
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs, c.normalized)
305307
if err != nil {
306308
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
307309
}

‎collection_test.go‎

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func TestCollection_Add(t *testing.T) {
2020

2121
// Create collection
2222
db := NewDB()
23-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
23+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
2424
if err != nil {
2525
t.Fatal("expected no error, got", err)
2626
}
@@ -116,7 +116,7 @@ func TestCollection_Add_Error(t *testing.T) {
116116

117117
// Create collection
118118
db := NewDB()
119-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
119+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
120120
if err != nil {
121121
t.Fatal("expected no error, got", err)
122122
}
@@ -167,7 +167,7 @@ func TestCollection_AddConcurrently(t *testing.T) {
167167

168168
// Create collection
169169
db := NewDB()
170-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
170+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
171171
if err != nil {
172172
t.Fatal("expected no error, got", err)
173173
}
@@ -263,7 +263,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
263263

264264
// Create collection
265265
db := NewDB()
266-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
266+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
267267
if err != nil {
268268
t.Fatal("expected no error, got", err)
269269
}
@@ -316,7 +316,7 @@ func TestCollection_Count(t *testing.T) {
316316
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
317317
return []float32{-0.1, 0.1, 0.2}, nil
318318
}
319-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
319+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
320320
if err != nil {
321321
t.Fatal("expected no error, got", err)
322322
}
@@ -407,7 +407,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
407407
// Create collection
408408
db := NewDB()
409409
name := "test"
410-
c, err := db.CreateCollection(name, nil, embeddingFunc)
410+
c, err := db.CreateCollection(name, nil, embeddingFunc, &trueVal)
411411
if err != nil {
412412
b.Fatal("expected no error, got", err)
413413
}

‎db.go‎

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,17 @@ func NewPersistentDB(path string) (*DB, error) {
142142
// - metadata: Optional metadata to associate with the collection.
143143
// - embeddingFunc: Optional function to use to embed documents.
144144
// Uses the default embedding function if not provided.
145-
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
145+
// - normalized: Optional flag to indicate whether the embeddings of the
146+
// collection are already normalized (whether you add embeddings yourself or
147+
// they are created by the embeddingFunc). If nil, it will be autodetected.
148+
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) {
146149
if name == "" {
147150
return nil, errors.New("collection name is empty")
148151
}
149152
if embeddingFunc == nil {
150153
embeddingFunc = NewEmbeddingFuncDefault()
151154
}
152-
collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory)
155+
collection, err := newCollection(name, metadata, embeddingFunc, normalized, db.persistDirectory)
153156
if err != nil {
154157
return nil, fmt.Errorf("couldn't create collection: %w", err)
155158
}
@@ -213,12 +216,15 @@ func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collectio
213216
// - metadata: Optional metadata to associate with the collection.
214217
// - embeddingFunc: Optional function to use to embed documents.
215218
// Uses the default embedding function if not provided.
216-
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
219+
// - normalized: Optional flag to indicate whether the embeddings of the
220+
// collection are already normalized (whether you add embeddings yourself or
221+
// they are created by the embeddingFunc). If nil, it will be autodetected.
222+
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) {
217223
// No need to lock here, because the methods we call do that.
218224
collection := db.GetCollection(name, embeddingFunc)
219225
if collection == nil {
220226
var err error
221-
collection, err = db.CreateCollection(name, metadata, embeddingFunc)
227+
collection, err = db.CreateCollection(name, metadata, embeddingFunc, normalized)
222228
if err != nil {
223229
return nil, fmt.Errorf("couldn't create collection: %w", err)
224230
}

‎db_test.go‎

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func TestDB_CreateCollection(t *testing.T) {
1818
db := NewDB()
1919

2020
t.Run("OK", func(t *testing.T) {
21-
c, err := db.CreateCollection(name, metadata, embeddingFunc)
21+
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
2222
if err != nil {
2323
t.Fatal("expected no error, got", err)
2424
}
@@ -70,7 +70,7 @@ func TestDB_CreateCollection(t *testing.T) {
7070
})
7171

7272
t.Run("NOK - Empty name", func(t *testing.T) {
73-
_, err := db.CreateCollection("", metadata, embeddingFunc)
73+
_, err := db.CreateCollection("", metadata, embeddingFunc, nil)
7474
if err == nil {
7575
t.Fatal("expected error, got nil")
7676
}
@@ -89,7 +89,7 @@ func TestDB_ListCollections(t *testing.T) {
8989
// Create initial collection
9090
db := NewDB()
9191
// We ignore the return value. CreateCollection is tested elsewhere.
92-
_, err := db.CreateCollection(name, metadata, embeddingFunc)
92+
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
9393
if err != nil {
9494
t.Fatal("expected no error, got", err)
9595
}
@@ -155,7 +155,7 @@ func TestDB_GetCollection(t *testing.T) {
155155
// Create initial collection
156156
db := NewDB()
157157
// We ignore the return value. CreateCollection is tested elsewhere.
158-
_, err := db.CreateCollection(name, metadata, embeddingFunc)
158+
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
159159
if err != nil {
160160
t.Fatal("expected no error, got", err)
161161
}
@@ -207,15 +207,15 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
207207
// Create collection so that the GetOrCreateCollection() call below only
208208
// gets it.
209209
// We ignore the return value. CreateCollection is tested elsewhere.
210-
_, err := db.CreateCollection(name, metadata, embeddingFunc)
210+
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
211211
if err != nil {
212212
t.Fatal("expected no error, got", err)
213213
}
214214

215215
// Call GetOrCreateCollection() with the same name to only get it. We pass
216216
// nil for the metadata (but the same embeddingFunc) so we can check that the returned
217217
// collection is the original one, and not a new one.
218-
c, err := db.GetOrCreateCollection(name, nil, nil)
218+
c, err := db.GetOrCreateCollection(name, nil, embeddingFunc, nil)
219219
if err != nil {
220220
t.Fatal("expected no error, got", err)
221221
}
@@ -257,7 +257,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
257257
db := NewDB()
258258

259259
// Call GetOrCreateCollection()
260-
c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc)
260+
c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc, nil)
261261
if err != nil {
262262
t.Fatal("expected no error, got", err)
263263
}
@@ -307,7 +307,7 @@ func TestDB_DeleteCollection(t *testing.T) {
307307
// Create initial collection
308308
db := NewDB()
309309
// We ignore the return value. CreateCollection is tested elsewhere.
310-
_, err := db.CreateCollection(name, metadata, embeddingFunc)
310+
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
311311
if err != nil {
312312
t.Fatal("expected no error, got", err)
313313
}
@@ -339,7 +339,7 @@ func TestDB_Reset(t *testing.T) {
339339
// Create initial collection
340340
db := NewDB()
341341
// We ignore the return value. CreateCollection is tested elsewhere.
342-
_, err := db.CreateCollection(name, metadata, embeddingFunc)
342+
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
343343
if err != nil {
344344
t.Fatal("expected no error, got", err)
345345
}

‎examples/rag-wikipedia-ollama/main.go‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ const (
1919
embeddingModel = "nomic-embed-text"
2020
)
2121

22+
// The nomic-embed-text-v1.5 model doesn't return normalized embeddings.
23+
var normalized = false
24+
2225
func main() {
2326
ctx := context.Background()
2427

@@ -49,7 +52,7 @@ func main() {
4952
// variable to be set.
5053
// For this example we choose to use a locally running embedding model though.
5154
// It requires Ollama to serve its API at "http://localhost:11434/api".
52-
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel))
55+
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel), &normalized)
5356
if err != nil {
5457
panic(err)
5558
}

‎examples/semantic-search-arxiv-openai/main.go‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ import (
1616

1717
const searchTerm = "semantic search with vector databases"
1818

19+
// OpenAI embeddings are already normalized.
20+
var normalized = true
21+
1922
func main() {
2023
ctx := context.Background()
2124

@@ -30,7 +33,7 @@ func main() {
3033
// We pass nil as embedding function to use the default (OpenAI text-embedding-3-small),
3134
// which is very good and cheap. It requires the OPENAI_API_KEY environment
3235
// variable to be set.
33-
collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil)
36+
collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil, &normalized)
3437
if err != nil {
3538
panic(err)
3639
}

‎query.go‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
9595
return true
9696
}
9797

98-
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) {
98+
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document, isNormalized *bool) ([]docSim, error) {
9999
similarities := make([]docSim, 0, len(docs))
100100
similaritiesLock := sync.Mutex{}
101101

@@ -145,7 +145,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
145145
return
146146
}
147147

148-
sim, err := cosineSimilarity(queryVectors, doc.Embedding)
148+
sim, err := cosineSimilarity(queryVectors, doc.Embedding, isNormalized)
149149
if err != nil {
150150
setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
151151
return

‎vector.go‎

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,33 @@ import (
77

88
const isNormalizedPrecisionTolerance = 1e-6
99

10+
var (
11+
falseVal = false
12+
trueVal = true
13+
)
14+
1015
// cosineSimilarity calculates the cosine similarity between two vectors.
11-
// Vectors are normalized first.
16+
// Pass isNormalized=true if the vectors are already normalized, false
17+
// to normalize them, and nil to autodetect.
1218
// The resulting value represents the similarity, so a higher value means the
1319
// vectors are more similar.
14-
func cosineSimilarity(a, b []float32) (float32, error) {
20+
func cosineSimilarity(a, b []float32, isNormalized *bool) (float32, error) {
1521
// The vectors must have the same length
1622
if len(a) != len(b) {
1723
return 0, errors.New("vectors must have the same length")
1824
}
1925

20-
if !isNormalized(a) || !isNormalized(b) {
26+
if isNormalized == nil {
27+
if !checkNormalized(a) || !checkNormalized(b) {
28+
isNormalized = &falseVal
29+
} else {
30+
isNormalized = &trueVal
31+
}
32+
}
33+
if !*isNormalized {
2134
a, b = normalizeVector(a), normalizeVector(b)
2235
}
36+
2337
var dotProduct float32
2438
for i := range a {
2539
dotProduct += a[i] * b[i]
@@ -44,8 +58,8 @@ func normalizeVector(v []float32) []float32 {
4458
return res
4559
}
4660

47-
// isNormalized checks if the vector is normalized.
48-
func isNormalized(v []float32) bool {
61+
// checkNormalized checks if the vector is normalized.
62+
func checkNormalized(v []float32) bool {
4963
var sqSum float64
5064
for _, val := range v {
5165
sqSum += float64(val) * float64(val)

0 commit comments

Comments
 (0)