@@ -170,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection {
170170}
171171
172172// GetCollection returns the collection with the given name.
173- // The returned value is a reference to the original collection, so any methods
173+ // The embeddingFunc param is only used if the DB is persistent and was just loaded
174+ // from storage, in which case no embedding func is set yet (funcs are not (de-)serializable).
175+ // It can be nil, in which case the default one will be used.
176+ // The returned collection is a reference to the original collection, so any methods
174177// on the collection like Add() will be reflected on the DB's collection. Those
175178// operations are concurrency-safe.
176179// If the collection doesn't exist, this returns nil.
177- func (db * DB ) GetCollection (name string ) * Collection {
180+ func (db * DB ) GetCollection (name string , embeddingFunc EmbeddingFunc ) * Collection {
178181 db .collectionsLock .RLock ()
179182 defer db .collectionsLock .RUnlock ()
180- return db .collections [name ]
183+
184+ c , ok := db .collections [name ]
185+ if ! ok {
186+ return nil
187+ }
188+
189+ if c .embed == nil {
190+ if embeddingFunc == nil {
191+ c .embed = NewEmbeddingFuncDefault ()
192+ } else {
193+ c .embed = embeddingFunc
194+ }
195+ }
196+ return c
181197}
182198
183199// GetOrCreateCollection returns the collection with the given name if it exists
@@ -189,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection {
189205// Uses the default embedding function if not provided.
190206func (db * DB ) GetOrCreateCollection (name string , metadata map [string ]string , embeddingFunc EmbeddingFunc ) (* Collection , error ) {
191207 // No need to lock here, because the methods we call do that.
192- collection := db .GetCollection (name )
208+ collection := db .GetCollection (name , embeddingFunc )
193209 if collection == nil {
194210 var err error
195211 collection , err = db .CreateCollection (name , metadata , embeddingFunc )
0 commit comments