Skip to content

Commit 9b47246

Browse files
authored
Merge pull request #51 from philippgille/fix-param-validation
Fix param validation
2 parents efb6890 + 98516b1 commit 9b47246

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

‎collection.go‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
279279
}
280280
c.documentsLock.RLock()
281281
defer c.documentsLock.RUnlock()
282-
if nResults < len(c.documents) {
283-
return nil, errors.New("nResults must be greater than the number of documents in the collection")
282+
if nResults > len(c.documents) {
283+
return nil, errors.New("nResults must be <= the number of documents in the collection")
284284
}
285285

286286
if len(c.documents) == 0 {

‎collection_test.go‎

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,85 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
308308
}
309309
}
310310

311+
func TestCollection_QueryError(t *testing.T) {
312+
// Create collection
313+
db := NewDB()
314+
name := "test"
315+
metadata := map[string]string{"foo": "bar"}
316+
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
317+
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
318+
return vectors, nil
319+
}
320+
c, err := db.CreateCollection(name, metadata, embeddingFunc)
321+
if err != nil {
322+
t.Fatal("expected no error, got", err)
323+
}
324+
if c == nil {
325+
t.Fatal("expected collection, got nil")
326+
}
327+
// Add a document
328+
err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
329+
if err != nil {
330+
t.Fatal("expected nil, got", err)
331+
}
332+
333+
tt := []struct {
334+
name string
335+
query func() error
336+
expErr string
337+
}{
338+
{
339+
name: "Empty query",
340+
query: func() error {
341+
_, err := c.Query(context.Background(), "", 1, nil, nil)
342+
return err
343+
},
344+
expErr: "queryText is empty",
345+
},
346+
{
347+
name: "Negative limit",
348+
query: func() error {
349+
_, err := c.Query(context.Background(), "foo", -1, nil, nil)
350+
return err
351+
},
352+
expErr: "nResults must be > 0",
353+
},
354+
{
355+
name: "Zero limit",
356+
query: func() error {
357+
_, err := c.Query(context.Background(), "foo", 0, nil, nil)
358+
return err
359+
},
360+
expErr: "nResults must be > 0",
361+
},
362+
{
363+
name: "Limit greater than number of documents",
364+
query: func() error {
365+
_, err := c.Query(context.Background(), "foo", 2, nil, nil)
366+
return err
367+
},
368+
expErr: "nResults must be <= the number of documents in the collection",
369+
},
370+
{
371+
name: "Bad content filter",
372+
query: func() error {
373+
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
374+
return err
375+
},
376+
expErr: "unsupported operator",
377+
},
378+
}
379+
380+
for _, tc := range tt {
381+
t.Run(tc.name, func(t *testing.T) {
382+
err := tc.query()
383+
if err.Error() != tc.expErr {
384+
t.Fatal("expected", tc.expErr, "got", err)
385+
}
386+
})
387+
}
388+
}
389+
311390
func TestCollection_Count(t *testing.T) {
312391
// Create collection
313392
db := NewDB()

0 commit comments

Comments
 (0)