Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Add unit test for query errors
  • Loading branch information
philippgille committed Mar 17, 2024
commit 98516b1a6818edd05aa8967ee14d2b4ff3d2bfc4
79 changes: 79 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,85 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
}
}

func TestCollection_QueryError(t *testing.T) {
// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}
// Add a document
err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
if err != nil {
t.Fatal("expected nil, got", err)
}

tt := []struct {
name string
query func() error
expErr string
}{
{
name: "Empty query",
query: func() error {
_, err := c.Query(context.Background(), "", 1, nil, nil)
return err
},
expErr: "queryText is empty",
},
{
name: "Negative limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", -1, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Zero limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", 0, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Limit greater than number of documents",
query: func() error {
_, err := c.Query(context.Background(), "foo", 2, nil, nil)
return err
},
expErr: "nResults must be <= the number of documents in the collection",
},
{
name: "Bad content filter",
query: func() error {
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
return err
},
expErr: "unsupported operator",
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := tc.query()
if err.Error() != tc.expErr {
t.Fatal("expected", tc.expErr, "got", err)
}
})
}
}

func TestCollection_Count(t *testing.T) {
// Create collection
db := NewDB()
Expand Down