@@ -19,7 +19,7 @@ type Collection struct {
1919
2020 persistDirectory string
2121 metadata map [string ]string
22- documents map [string ]* document
22+ documents map [string ]* Document
2323 documentsLock sync.RWMutex
2424 embed EmbeddingFunc
2525}
@@ -38,7 +38,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
3838 Name : name ,
3939
4040 metadata : m ,
41- documents : make (map [string ]* document ),
41+ documents : make (map [string ]* Document ),
4242 embed : embed ,
4343 }
4444
@@ -73,24 +73,166 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
7373//
7474// - ids: The ids of the embeddings you wish to add
7575// - embeddings: The embeddings to add. If nil, embeddings will be computed based
76- // on the documents using the embeddingFunc set for the Collection. Optional.
76+ // on the contents using the embeddingFunc set for the Collection. Optional.
7777// - metadatas: The metadata to associate with the embeddings. When querying,
7878// you can filter on this metadata. Optional.
79- // - documents : The documents to associate with the embeddings.
79+ // - contents : The contents to associate with the embeddings.
8080//
81- // A row-based API will be added when Chroma adds it (they already plan to) .
82- func (c * Collection ) Add (ctx context.Context , ids []string , embeddings [][]float32 , metadatas []map [string ]string , documents []string ) error {
83- return c .add (ctx , ids , documents , embeddings , metadatas , 1 )
81+ // This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments] .
82+ func (c * Collection ) Add (ctx context.Context , ids []string , embeddings [][]float32 , metadatas []map [string ]string , contents []string ) error {
83+ return c .AddConcurrently (ctx , ids , embeddings , metadatas , contents , 1 )
8484}
8585
8686// AddConcurrently is like Add, but adds embeddings concurrently.
8787// This is mostly useful when you don't pass any embeddings so they have to be created.
8888// Upon error, concurrently running operations are canceled and the error is returned.
89- func (c * Collection ) AddConcurrently (ctx context.Context , ids []string , embeddings [][]float32 , metadatas []map [string ]string , documents []string , concurrency int ) error {
89+ //
90+ // This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
91+ func (c * Collection ) AddConcurrently (ctx context.Context , ids []string , embeddings [][]float32 , metadatas []map [string ]string , contents []string , concurrency int ) error {
92+ if len (ids ) == 0 {
93+ return errors .New ("ids are empty" )
94+ }
95+ if len (embeddings ) == 0 && len (contents ) == 0 {
96+ return errors .New ("either embeddings or contents must be filled" )
97+ }
98+ if len (embeddings ) != 0 {
99+ if len (embeddings ) != len (ids ) {
100+ return errors .New ("ids and embeddings must have the same length" )
101+ }
102+ } else {
103+ // Assign empty slice so we can simply access via index later
104+ embeddings = make ([][]float32 , len (ids ))
105+ }
106+ if len (metadatas ) != 0 && len (ids ) != len (metadatas ) {
107+ return errors .New ("ids, metadatas and contents must have the same length" )
108+ }
109+ if len (contents ) != 0 {
110+ if len (contents ) != len (ids ) {
111+ return errors .New ("ids and contents must have the same length" )
112+ }
113+ } else {
114+ // Assign empty slice so we can simply access via index later
115+ contents = make ([]string , len (ids ))
116+ }
90117 if concurrency < 1 {
91118 return errors .New ("concurrency must be at least 1" )
92119 }
93- return c .add (ctx , ids , documents , embeddings , metadatas , concurrency )
120+
121+ // Convert Chroma-style parameters into a slice of documents.
122+ docs := make ([]Document , 0 , len (ids ))
123+ for i , id := range ids {
124+ docs = append (docs , Document {
125+ ID : id ,
126+ Metadata : metadatas [i ],
127+ Embedding : embeddings [i ],
128+ Content : contents [i ],
129+ })
130+ }
131+
132+ return c .AddDocuments (ctx , docs , concurrency )
133+ }
134+
135+ // AddDocuments adds documents to the collection with the specified concurrency.
136+ // If the documents don't have embeddings, they will be created using the collection's
137+ // embedding function.
138+ // Upon error, concurrently running operations are canceled and the error is returned.
139+ func (c * Collection ) AddDocuments (ctx context.Context , documents []Document , concurrency int ) error {
140+ if len (documents ) == 0 {
141+ // TODO: Should this be a no-op instead?
142+ return errors .New ("documents slice is nil or empty" )
143+ }
144+ if concurrency < 1 {
145+ return errors .New ("concurrency must be at least 1" )
146+ }
147+ // For other validations we rely on AddDocument.
148+
149+ var globalErr error
150+ globalErrLock := sync.Mutex {}
151+ ctx , cancel := context .WithCancelCause (ctx )
152+ defer cancel (nil )
153+ setGlobalErr := func (err error ) {
154+ globalErrLock .Lock ()
155+ defer globalErrLock .Unlock ()
156+ // Another goroutine might have already set the error.
157+ if globalErr == nil {
158+ globalErr = err
159+ // Cancel the operation for all other goroutines.
160+ cancel (globalErr )
161+ }
162+ }
163+
164+ var wg sync.WaitGroup
165+ semaphore := make (chan struct {}, concurrency )
166+ for _ , doc := range documents {
167+ wg .Add (1 )
168+ go func (doc Document ) {
169+ defer wg .Done ()
170+
171+ // Don't even start if another goroutine already failed.
172+ if ctx .Err () != nil {
173+ return
174+ }
175+
176+ // Wait here while $concurrency other goroutines are creating documents.
177+ semaphore <- struct {}{}
178+ defer func () { <- semaphore }()
179+
180+ err := c .AddDocument (ctx , doc )
181+ if err != nil {
182+ setGlobalErr (fmt .Errorf ("couldn't add document '%s': %w" , doc .ID , err ))
183+ return
184+ }
185+ }(doc )
186+ }
187+
188+ wg .Wait ()
189+
190+ return globalErr
191+ }
192+
193+ // AddDocument adds a document to the collection.
194+ // If the document doesn't have an embedding, it will be created using the collection's
195+ // embedding function.
196+ func (c * Collection ) AddDocument (ctx context.Context , doc Document ) error {
197+ if doc .ID == "" {
198+ return errors .New ("document ID is empty" )
199+ }
200+ if len (doc .Embedding ) == 0 && doc .Content == "" {
201+ return errors .New ("either document embedding or content must be filled" )
202+ }
203+
204+ // We copy the metadata to avoid data races in case the caller modifies the
205+ // map after creating the document while we range over it.
206+ m := make (map [string ]string , len (doc .Metadata ))
207+ for k , v := range doc .Metadata {
208+ m [k ] = v
209+ }
210+
211+ // Create embedding if they don't exist
212+ if len (doc .Embedding ) == 0 {
213+ embedding , err := c .embed (ctx , doc .Content )
214+ if err != nil {
215+ return fmt .Errorf ("couldn't create embedding of document: %w" , err )
216+ }
217+ doc .Embedding = embedding
218+ }
219+
220+ c .documentsLock .Lock ()
221+ // We don't defer the unlock because we want to do it earlier.
222+ c .documents [doc .ID ] = & doc
223+ c .documentsLock .Unlock ()
224+
225+ // Persist the document
226+ if c .persistDirectory != "" {
227+ safeID := hash2hex (doc .ID )
228+ filePath := path .Join (c .persistDirectory , safeID )
229+ err := persist (filePath , doc )
230+ if err != nil {
231+ return fmt .Errorf ("couldn't persist document: %w" , err )
232+ }
233+ }
234+
235+ return nil
94236}
95237
96238// Count returns the number of documents in the collection.
@@ -155,91 +297,3 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
155297 // Return the top nResults
156298 return res [:nResults ], nil
157299}
158-
159- func (c * Collection ) add (ctx context.Context , ids []string , documents []string , embeddings [][]float32 , metadatas []map [string ]string , concurrency int ) error {
160- if len (ids ) == 0 || len (documents ) == 0 {
161- return errors .New ("ids and documents must not be empty" )
162- }
163- if len (ids ) != len (documents ) {
164- return errors .New ("ids and documents must have the same length" )
165- }
166- if len (embeddings ) != 0 && len (ids ) != len (embeddings ) {
167- return errors .New ("ids, embeddings and documents must have the same length" )
168- }
169- if len (metadatas ) != 0 && len (ids ) != len (metadatas ) {
170- return errors .New ("ids, metadatas and documents must have the same length" )
171- }
172-
173- ctx , cancel := context .WithCancelCause (ctx )
174- defer cancel (nil )
175-
176- var wg sync.WaitGroup
177- var globalErr error
178- var globalErrLock sync.Mutex
179- semaphore := make (chan struct {}, concurrency )
180- for i , document := range documents {
181- var embedding []float32
182- var metadata map [string ]string
183- if len (embeddings ) != 0 {
184- embedding = embeddings [i ]
185- }
186- if len (metadatas ) != 0 {
187- metadata = metadatas [i ]
188- }
189-
190- wg .Add (1 )
191- go func (id string , embedding []float32 , metadata map [string ]string , document string ) {
192- defer wg .Done ()
193-
194- // Don't even start if we already have an error
195- if ctx .Err () != nil {
196- return
197- }
198-
199- // Wait here while $concurrency other goroutines are creating documents.
200- semaphore <- struct {}{}
201- defer func () { <- semaphore }()
202-
203- err := c .addRow (ctx , id , document , embedding , metadata )
204- if err != nil {
205- globalErrLock .Lock ()
206- defer globalErrLock .Unlock ()
207- // Another goroutine might have already set the error.
208- if globalErr == nil {
209- globalErr = err
210- // Cancel the operation for all other goroutines.
211- cancel (globalErr )
212- }
213- return
214- }
215- }(ids [i ], embedding , metadata , document )
216- }
217-
218- wg .Wait ()
219-
220- return globalErr
221- }
222-
223- func (c * Collection ) addRow (ctx context.Context , id string , document string , embedding []float32 , metadata map [string ]string ) error {
224- doc , err := newDocument (ctx , id , embedding , metadata , document , c .embed )
225- if err != nil {
226- return fmt .Errorf ("couldn't create document '%s': %w" , id , err )
227- }
228-
229- c .documentsLock .Lock ()
230- // We don't defer the unlock because we want to do it earlier.
231- c .documents [id ] = doc
232- c .documentsLock .Unlock ()
233-
234- // Persist the document
235- if c .persistDirectory != "" {
236- safeID := hash2hex (id )
237- filePath := path .Join (c .persistDirectory , safeID )
238- err := persist (filePath , doc )
239- if err != nil {
240- return fmt .Errorf ("couldn't persist document: %w" , err )
241- }
242- }
243-
244- return nil
245- }
0 commit comments