Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ func (c *DB) GetCollection(name string) *Collection {
return c.collections[name]
}

// GetOrCreateCollection returns the collection with the given name if it exists
// in the DB, or otherwise creates it. When creating:
//
// - name: The name of the collection to create.
// - metadata: Optional metadata to associate with the collection.
// - embeddingFunc: Optional function to use to embed documents.
// Uses the default embedding function if not provided.
func (c *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) *Collection {
// No need to lock here, because the methods we call do that.
collection := c.GetCollection(name)
if collection == nil {
collection = c.CreateCollection(name, metadata, embeddingFunc)
}
return collection
}

// DeleteCollection deletes the collection with the given name.
// If the collection doesn't exist, this is a no-op.
func (c *DB) DeleteCollection(name string) {
Expand Down
49 changes: 49 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,55 @@ func TestDB_GetCollection(t *testing.T) {
// TODO: Check documents map being a copy as soon as we have access to it
}

func TestDB_GetOrCreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
}

t.Run("Get", func(t *testing.T) {
// Create initial collection
db := chromem.NewDB()
// Create collection so that the GetOrCreateCollection() call below only
// gets it.
// We ignore the return value. CreateCollection is tested elsewhere.
_ = db.CreateCollection(name, metadata, embeddingFunc)

// Call GetOrCreateCollection() with the same name to only get it. We pass
// nil for the metadata and embeddingFunc so we can check that the returned
// collection is the original one, and not a new one.
c := db.GetOrCreateCollection(name, nil, nil)
if c == nil {
t.Error("expected collection, got nil")
}

// Check expectations
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata when it's accessible (e.g. with GetMetadata())
})

t.Run("Create", func(t *testing.T) {
// Create initial collection
db := chromem.NewDB()

// Call GetOrCreateCollection()
c := db.GetOrCreateCollection(name, metadata, embeddingFunc)
if c == nil {
t.Error("expected collection, got nil")
}

// Check like we check CreateCollection()
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata when it's accessible (e.g. with GetMetadata())
})
}

func TestDB_DeleteCollection(t *testing.T) {
// Values in the collection
name := "test"
Expand Down