Skip to content

Commit ed5dca6

Browse files
authored
Merge pull request #34 from philippgille/add-go-idiomatic-methods
Export Document struct and add Go-idiomatic methods for adding them to a collection
2 parents a2df4ad + cb0fe2f commit ed5dca6

File tree

10 files changed

+248
-180
lines changed

10 files changed

+248
-180
lines changed

‎collection.go‎

Lines changed: 151 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ type Collection struct {
1919

2020
persistDirectory string
2121
metadata map[string]string
22-
documents map[string]*document
22+
documents map[string]*Document
2323
documentsLock sync.RWMutex
2424
embed EmbeddingFunc
2525
}
@@ -38,7 +38,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
3838
Name: name,
3939

4040
metadata: m,
41-
documents: make(map[string]*document),
41+
documents: make(map[string]*Document),
4242
embed: embed,
4343
}
4444

@@ -73,24 +73,166 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
7373
//
7474
// - ids: The ids of the embeddings you wish to add
7575
// - embeddings: The embeddings to add. If nil, embeddings will be computed based
76-
// on the documents using the embeddingFunc set for the Collection. Optional.
76+
// on the contents using the embeddingFunc set for the Collection. Optional.
7777
// - metadatas: The metadata to associate with the embeddings. When querying,
7878
// you can filter on this metadata. Optional.
79-
// - documents: The documents to associate with the embeddings.
79+
// - contents: The contents to associate with the embeddings.
8080
//
81-
// A row-based API will be added when Chroma adds it (they already plan to).
82-
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error {
83-
return c.add(ctx, ids, documents, embeddings, metadatas, 1)
81+
// This is a Chroma-like method. For a more Go-idiomatic one, see [Collection.AddDocuments].
82+
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
83+
return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
8484
}
8585

8686
// AddConcurrently is like Add, but adds embeddings concurrently.
8787
// This is mostly useful when you don't pass any embeddings so they have to be created.
8888
// Upon error, concurrently running operations are canceled and the error is returned.
89-
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error {
89+
//
90+
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
91+
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
92+
if len(ids) == 0 {
93+
return errors.New("ids are empty")
94+
}
95+
if len(embeddings) == 0 && len(contents) == 0 {
96+
return errors.New("either embeddings or contents must be filled")
97+
}
98+
if len(embeddings) != 0 {
99+
if len(embeddings) != len(ids) {
100+
return errors.New("ids and embeddings must have the same length")
101+
}
102+
} else {
103+
// Assign empty slice so we can simply access via index later
104+
embeddings = make([][]float32, len(ids))
105+
}
106+
if len(metadatas) != 0 && len(ids) != len(metadatas) {
107+
return errors.New("ids, metadatas and contents must have the same length")
108+
}
109+
if len(contents) != 0 {
110+
if len(contents) != len(ids) {
111+
return errors.New("ids and contents must have the same length")
112+
}
113+
} else {
114+
// Assign empty slice so we can simply access via index later
115+
contents = make([]string, len(ids))
116+
}
90117
if concurrency < 1 {
91118
return errors.New("concurrency must be at least 1")
92119
}
93-
return c.add(ctx, ids, documents, embeddings, metadatas, concurrency)
120+
121+
// Convert Chroma-style parameters into a slice of documents.
122+
docs := make([]Document, 0, len(ids))
123+
for i, id := range ids {
124+
docs = append(docs, Document{
125+
ID: id,
126+
Metadata: metadatas[i],
127+
Embedding: embeddings[i],
128+
Content: contents[i],
129+
})
130+
}
131+
132+
return c.AddDocuments(ctx, docs, concurrency)
133+
}
134+
135+
// AddDocuments adds documents to the collection with the specified concurrency.
136+
// If the documents don't have embeddings, they will be created using the collection's
137+
// embedding function.
138+
// Upon error, concurrently running operations are canceled and the error is returned.
139+
func (c *Collection) AddDocuments(ctx context.Context, documents []Document, concurrency int) error {
140+
if len(documents) == 0 {
141+
// TODO: Should this be a no-op instead?
142+
return errors.New("documents slice is nil or empty")
143+
}
144+
if concurrency < 1 {
145+
return errors.New("concurrency must be at least 1")
146+
}
147+
// For other validations we rely on AddDocument.
148+
149+
var globalErr error
150+
globalErrLock := sync.Mutex{}
151+
ctx, cancel := context.WithCancelCause(ctx)
152+
defer cancel(nil)
153+
setGlobalErr := func(err error) {
154+
globalErrLock.Lock()
155+
defer globalErrLock.Unlock()
156+
// Another goroutine might have already set the error.
157+
if globalErr == nil {
158+
globalErr = err
159+
// Cancel the operation for all other goroutines.
160+
cancel(globalErr)
161+
}
162+
}
163+
164+
var wg sync.WaitGroup
165+
semaphore := make(chan struct{}, concurrency)
166+
for _, doc := range documents {
167+
wg.Add(1)
168+
go func(doc Document) {
169+
defer wg.Done()
170+
171+
// Don't even start if another goroutine already failed.
172+
if ctx.Err() != nil {
173+
return
174+
}
175+
176+
// Wait here while $concurrency other goroutines are creating documents.
177+
semaphore <- struct{}{}
178+
defer func() { <-semaphore }()
179+
180+
err := c.AddDocument(ctx, doc)
181+
if err != nil {
182+
setGlobalErr(fmt.Errorf("couldn't add document '%s': %w", doc.ID, err))
183+
return
184+
}
185+
}(doc)
186+
}
187+
188+
wg.Wait()
189+
190+
return globalErr
191+
}
192+
193+
// AddDocument adds a document to the collection.
194+
// If the document doesn't have an embedding, it will be created using the collection's
195+
// embedding function.
196+
func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
197+
if doc.ID == "" {
198+
return errors.New("document ID is empty")
199+
}
200+
if len(doc.Embedding) == 0 && doc.Content == "" {
201+
return errors.New("either document embedding or content must be filled")
202+
}
203+
204+
// We copy the metadata to avoid data races in case the caller modifies the
205+
// map after creating the document while we range over it.
206+
m := make(map[string]string, len(doc.Metadata))
207+
for k, v := range doc.Metadata {
208+
m[k] = v
209+
}
210+
211+
// Create embedding if they don't exist
212+
if len(doc.Embedding) == 0 {
213+
embedding, err := c.embed(ctx, doc.Content)
214+
if err != nil {
215+
return fmt.Errorf("couldn't create embedding of document: %w", err)
216+
}
217+
doc.Embedding = embedding
218+
}
219+
220+
c.documentsLock.Lock()
221+
// We don't defer the unlock because we want to do it earlier.
222+
c.documents[doc.ID] = &doc
223+
c.documentsLock.Unlock()
224+
225+
// Persist the document
226+
if c.persistDirectory != "" {
227+
safeID := hash2hex(doc.ID)
228+
filePath := path.Join(c.persistDirectory, safeID)
229+
err := persist(filePath, doc)
230+
if err != nil {
231+
return fmt.Errorf("couldn't persist document: %w", err)
232+
}
233+
}
234+
235+
return nil
94236
}
95237

96238
// Count returns the number of documents in the collection.
@@ -155,91 +297,3 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
155297
// Return the top nResults
156298
return res[:nResults], nil
157299
}
158-
159-
func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
160-
if len(ids) == 0 || len(documents) == 0 {
161-
return errors.New("ids and documents must not be empty")
162-
}
163-
if len(ids) != len(documents) {
164-
return errors.New("ids and documents must have the same length")
165-
}
166-
if len(embeddings) != 0 && len(ids) != len(embeddings) {
167-
return errors.New("ids, embeddings and documents must have the same length")
168-
}
169-
if len(metadatas) != 0 && len(ids) != len(metadatas) {
170-
return errors.New("ids, metadatas and documents must have the same length")
171-
}
172-
173-
ctx, cancel := context.WithCancelCause(ctx)
174-
defer cancel(nil)
175-
176-
var wg sync.WaitGroup
177-
var globalErr error
178-
var globalErrLock sync.Mutex
179-
semaphore := make(chan struct{}, concurrency)
180-
for i, document := range documents {
181-
var embedding []float32
182-
var metadata map[string]string
183-
if len(embeddings) != 0 {
184-
embedding = embeddings[i]
185-
}
186-
if len(metadatas) != 0 {
187-
metadata = metadatas[i]
188-
}
189-
190-
wg.Add(1)
191-
go func(id string, embedding []float32, metadata map[string]string, document string) {
192-
defer wg.Done()
193-
194-
// Don't even start if we already have an error
195-
if ctx.Err() != nil {
196-
return
197-
}
198-
199-
// Wait here while $concurrency other goroutines are creating documents.
200-
semaphore <- struct{}{}
201-
defer func() { <-semaphore }()
202-
203-
err := c.addRow(ctx, id, document, embedding, metadata)
204-
if err != nil {
205-
globalErrLock.Lock()
206-
defer globalErrLock.Unlock()
207-
// Another goroutine might have already set the error.
208-
if globalErr == nil {
209-
globalErr = err
210-
// Cancel the operation for all other goroutines.
211-
cancel(globalErr)
212-
}
213-
return
214-
}
215-
}(ids[i], embedding, metadata, document)
216-
}
217-
218-
wg.Wait()
219-
220-
return globalErr
221-
}
222-
223-
func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error {
224-
doc, err := newDocument(ctx, id, embedding, metadata, document, c.embed)
225-
if err != nil {
226-
return fmt.Errorf("couldn't create document '%s': %w", id, err)
227-
}
228-
229-
c.documentsLock.Lock()
230-
// We don't defer the unlock because we want to do it earlier.
231-
c.documents[id] = doc
232-
c.documentsLock.Unlock()
233-
234-
// Persist the document
235-
if c.persistDirectory != "" {
236-
safeID := hash2hex(id)
237-
filePath := path.Join(c.persistDirectory, safeID)
238-
err := persist(filePath, doc)
239-
if err != nil {
240-
return fmt.Errorf("couldn't persist document: %w", err)
241-
}
242-
}
243-
244-
return nil
245-
}

‎collection_test.go‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ func TestCollection_Add(t *testing.T) {
2626
// Add document
2727
ids := []string{"1", "2"}
2828
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
29-
documents := []string{"hello world", "hallo welt"}
30-
err = c.Add(context.Background(), ids, nil, metadatas, documents)
29+
contents := []string{"hello world", "hallo welt"}
30+
err = c.Add(context.Background(), ids, nil, metadatas, contents)
3131
if err != nil {
3232
t.Error("expected nil, got", err)
3333
}
@@ -54,8 +54,8 @@ func TestCollection_Count(t *testing.T) {
5454
// Add documents
5555
ids := []string{"1", "2"}
5656
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
57-
documents := []string{"hello world", "hallo welt"}
58-
err = c.Add(context.Background(), ids, nil, metadatas, documents)
57+
contents := []string{"hello world", "hallo welt"}
58+
err = c.Add(context.Background(), ids, nil, metadatas, contents)
5959
if err != nil {
6060
t.Error("expected nil, got", err)
6161
}

‎db.go‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ import (
1010
"sync"
1111
)
1212

13-
// EmbeddingFunc is a function that creates embeddings for a given document.
13+
// EmbeddingFunc is a function that creates embeddings for a given text.
1414
// chromem-go will use OpenAI's "text-embedding-3-small" model by default,
1515
// but you can provide your own function, using any model you like.
16-
type EmbeddingFunc func(ctx context.Context, document string) ([]float32, error)
16+
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)
1717

1818
// DB is the chromem-go database. It holds collections, which hold documents.
1919
//
@@ -91,7 +91,7 @@ func NewPersistentDB(path string) (*DB, error) {
9191
c := &Collection{
9292
// We can fill Name, persistDirectory and metadata only after reading
9393
// the metadata.
94-
documents: make(map[string]*document),
94+
documents: make(map[string]*Document),
9595
// We can fill embed only when the user calls DB.GetCollection() or
9696
// DB.GetOrCreateCollection().
9797
}
@@ -119,7 +119,7 @@ func NewPersistentDB(path string) (*DB, error) {
119119
c.metadata = pc.Metadata
120120
} else if filepath.Ext(collectionDirEntry.Name()) == ".gob" {
121121
// Read document
122-
d := &document{}
122+
d := &Document{}
123123
err := read(fPath, d)
124124
if err != nil {
125125
return nil, fmt.Errorf("couldn't read document: %w", err)

0 commit comments

Comments
 (0)