Skip to content

Commit 727a5f1

Browse files
committed
Add EmbeddingFunc param to DB.GetCollection()
Required for when the DB was just loaded from persistant storage, as funcs can't be (de-)serialized.
1 parent 407db96 commit 727a5f1

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

‎db.go‎

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection {
170170
}
171171

172172
// GetCollection returns the collection with the given name.
173-
// The returned value is a reference to the original collection, so any methods
173+
// The embeddingFunc param is only used if the DB is persistent and was just loaded
174+
// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable).
175+
// It can be nil, in which case the default one will be used.
176+
// The returned collection is a reference to the original collection, so any methods
174177
// on the collection like Add() will be reflected on the DB's collection. Those
175178
// operations are concurrency-safe.
176179
// If the collection doesn't exist, this returns nil.
177-
func (db *DB) GetCollection(name string) *Collection {
180+
func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
178181
db.collectionsLock.RLock()
179182
defer db.collectionsLock.RUnlock()
180-
return db.collections[name]
183+
184+
c, ok := db.collections[name]
185+
if !ok {
186+
return nil
187+
}
188+
189+
if c.embed == nil {
190+
if embeddingFunc == nil {
191+
c.embed = NewEmbeddingFuncDefault()
192+
} else {
193+
c.embed = embeddingFunc
194+
}
195+
}
196+
return c
181197
}
182198

183199
// GetOrCreateCollection returns the collection with the given name if it exists
@@ -189,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection {
189205
// Uses the default embedding function if not provided.
190206
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
191207
// No need to lock here, because the methods we call do that.
192-
collection := db.GetCollection(name)
208+
collection := db.GetCollection(name, embeddingFunc)
193209
if collection == nil {
194210
var err error
195211
collection, err = db.CreateCollection(name, metadata, embeddingFunc)

‎db_test.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) {
9696
}
9797

9898
// Get collection
99-
c := db.GetCollection(name)
99+
c := db.GetCollection(name, nil)
100100

101101
// Check expectations
102102
if c.Name != name {

0 commit comments

Comments
 (0)