Skip to content

Commit bb1407b

Browse files
Rewrote to QueryWithOptions
1 parent 9ca7474 commit bb1407b

File tree

3 files changed

+320
-20
lines changed

3 files changed

+320
-20
lines changed

‎collection.go‎

Lines changed: 171 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"path/filepath"
88
"slices"
9+
"sort"
910
"sync"
1011
)
1112

@@ -27,6 +28,73 @@ type Collection struct {
2728
// versions in [DB.Export] and [DB.Import] as well!
2829
}
2930

31+
// NegativeMode represents the mode to use for the negative text, i.e. how a
// negative text or embedding influences the results of a query.
// See QueryOptions for more information.
type NegativeMode string

const (
	// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
	// This is the default behavior: QueryWithOptions falls back to this mode
	// when QueryOptions.NegativeMode is left empty.
	NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

	// NEGATIVE_MODE_REORDER reorders the results based on the similarity between the
	// negative embedding and the document embeddings.
	// NegativeReorderStrength controls the strength of the reordering. Lower values
	// will reorder the results less aggressively.
	NEGATIVE_MODE_REORDER NegativeMode = "reorder"

	// NEGATIVE_MODE_FILTER filters out results based on the similarity between the
	// negative embedding and the document embeddings.
	// NegativeFilterThreshold controls the threshold for filtering. Documents with
	// a similarity at or above the threshold will be removed from the results.
	NEGATIVE_MODE_FILTER NegativeMode = "filter"

	// DEFAULT_NEGATIVE_REORDER_STRENGTH is the reorder strength used when
	// QueryOptions.NegativeReorderStrength is left at its zero value.
	DEFAULT_NEGATIVE_REORDER_STRENGTH = 1

	// DEFAULT_NEGATIVE_FILTER_THRESHOLD is the filter threshold used when
	// QueryOptions.NegativeFilterThreshold is left at its zero value.
	DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
)
58+
59+
// QueryOptions represents the options for a query performed with
// QueryWithOptions. At least one of QueryText and QueryEmbedding must be set;
// all negative-related fields are optional.
type QueryOptions struct {
	// The text to search for. Its embedding is created with the collection's
	// embedding function when QueryEmbedding is not set.
	QueryText string

	// The embedding of the query to search for. It must be created
	// with the same embedding model as the document embeddings in the collection.
	// The embedding will be normalized if it's not the case yet.
	// If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
	QueryEmbedding []float32

	// The text to exclude from the results. Its embedding is created with the
	// collection's embedding function when NegativeEmbedding is not set.
	NegativeText string

	// The embedding of the negative text. It must be created
	// with the same embedding model as the document embeddings in the collection.
	// The embedding will be normalized if it's not the case yet.
	// If both NegativeText and NegativeEmbedding are set, NegativeEmbedding will be used.
	NegativeEmbedding []float32

	// The mode to use for the negative text.
	// An empty value is treated as NEGATIVE_MODE_SUBTRACT.
	NegativeMode NegativeMode

	// The strength of the negative reordering. Used when NegativeMode is NEGATIVE_MODE_REORDER.
	// A zero value is replaced by DEFAULT_NEGATIVE_REORDER_STRENGTH.
	NegativeReorderStrength float32

	// The threshold for the negative filter. Used when NegativeMode is NEGATIVE_MODE_FILTER.
	// A zero value is replaced by DEFAULT_NEGATIVE_FILTER_THRESHOLD.
	NegativeFilterThreshold float32

	// The number of results to return.
	// With NEGATIVE_MODE_FILTER the filter is applied after the search, so
	// fewer results than requested may be returned.
	NResults int

	// Conditional filtering on metadata.
	Where map[string]string

	// Conditional filtering on documents.
	WhereDocument map[string]string
}
97+
3098
// We don't export this yet to keep the API surface to the bare minimum.
3199
// Users create collections via [Client.CreateCollection].
32100
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
@@ -336,44 +404,85 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
336404
return nil, errors.New("queryText is empty")
337405
}
338406

339-
queryVectors, err := c.embed(ctx, queryText)
407+
queryVector, err := c.embed(ctx, queryText)
340408
if err != nil {
341409
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
342410
}
343411

344-
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
412+
return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument)
345413
}
346414

347415
// Performs an exhaustive nearest neighbor search on the collection.
348416
//
349-
// - queryText: The text to search for. Its embedding will be created using the
350-
// collection's embedding function.
351-
// - negativeText: The text to subtract from the query embedding. Its embedding
352-
// will be created using the collection's embedding function.
353-
// - nResults: The number of results to return. Must be > 0.
354-
// - where: Conditional filtering on metadata. Optional.
355-
// - whereDocument: Conditional filtering on documents. Optional.
356-
func (c *Collection) QueryWithNegative(ctx context.Context, queryText string, negativeText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
357-
if queryText == "" {
358-
return nil, errors.New("queryText is empty")
417+
// - options: The options for the query. See QueryOptions for more information.
418+
func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) {
419+
if options.QueryText == "" && len(options.QueryEmbedding) == 0 {
420+
return nil, errors.New("QueryText and QueryEmbedding options are empty")
359421
}
360422

361-
queryVectors, err := c.embed(ctx, queryText)
362-
if err != nil {
363-
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
423+
var err error
424+
queryVector := options.QueryEmbedding
425+
if len(queryVector) == 0 {
426+
queryVector, err = c.embed(ctx, options.QueryText)
427+
if err != nil {
428+
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
429+
}
430+
}
431+
432+
negativeMode := options.NegativeMode
433+
if negativeMode == "" {
434+
negativeMode = NEGATIVE_MODE_SUBTRACT
364435
}
365436

366-
if negativeText != "" {
367-
negativeVectors, err := c.embed(ctx, negativeText)
437+
negativeVector := options.NegativeEmbedding
438+
if len(negativeVector) == 0 && options.NegativeText != "" {
439+
negativeVector, err = c.embed(ctx, options.NegativeText)
368440
if err != nil {
369441
return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
370442
}
443+
}
444+
445+
if len(negativeVector) != 0 {
446+
if !isNormalized(negativeVector) {
447+
negativeVector = normalizeVector(negativeVector)
448+
}
371449

372-
queryVectors = subtractVector(queryVectors, negativeVectors)
373-
queryVectors = normalizeVector(queryVectors)
450+
if negativeMode == NEGATIVE_MODE_SUBTRACT {
451+
queryVector = subtractVector(queryVector, negativeVector)
452+
queryVector = normalizeVector(queryVector)
453+
}
374454
}
375455

376-
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
456+
result, err := c.QueryEmbedding(ctx, queryVector, options.NResults, options.Where, options.WhereDocument)
457+
if err != nil {
458+
return nil, err
459+
}
460+
461+
if len(negativeVector) != 0 {
462+
if negativeMode == NEGATIVE_MODE_REORDER {
463+
negativeReorderStrength := options.NegativeReorderStrength
464+
if negativeReorderStrength == 0 {
465+
negativeReorderStrength = DEFAULT_NEGATIVE_REORDER_STRENGTH
466+
}
467+
468+
result, err = reorderResults(result, negativeVector, negativeReorderStrength)
469+
if err != nil {
470+
return nil, fmt.Errorf("couldn't reorder results: %w", err)
471+
}
472+
} else if negativeMode == NEGATIVE_MODE_FILTER {
473+
negativeFilterThreshold := options.NegativeFilterThreshold
474+
if negativeFilterThreshold == 0 {
475+
negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
476+
}
477+
478+
result, err = filterResults(result, negativeVector, negativeFilterThreshold)
479+
if err != nil {
480+
return nil, fmt.Errorf("couldn't filter results: %w", err)
481+
}
482+
}
483+
}
484+
485+
return result, nil
377486
}
378487

379488
// Performs an exhaustive nearest neighbor search on the collection.
@@ -465,3 +574,45 @@ func (c *Collection) getDocPath(docID string) string {
465574
}
466575
return docPath
467576
}
577+
578+
func reorderResults(results []Result, negativeVector []float32, negativeReorderStrength float32) ([]Result, error) {
579+
if len(results) == 0 {
580+
return results, nil
581+
}
582+
583+
// Calculate cosine similarity between negative vector and each result
584+
for i := range results {
585+
sim, err := dotProduct(negativeVector, results[i].Embedding)
586+
if err != nil {
587+
return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
588+
}
589+
results[i].Similarity -= sim * negativeReorderStrength
590+
}
591+
592+
// Sort results by similarity
593+
sort.Slice(results, func(i, j int) bool {
594+
return results[i].Similarity > results[j].Similarity
595+
})
596+
597+
return results, nil
598+
}
599+
600+
func filterResults(results []Result, negativeVector []float32, negativeFilterThreshold float32) ([]Result, error) {
601+
if len(results) == 0 {
602+
return results, nil
603+
}
604+
605+
// Filter out results with similarity above the threshold
606+
filteredResults := make([]Result, 0, len(results))
607+
for _, res := range results {
608+
sim, err := dotProduct(negativeVector, res.Embedding)
609+
if err != nil {
610+
return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
611+
}
612+
if sim < negativeFilterThreshold {
613+
filteredResults = append(filteredResults, res)
614+
}
615+
}
616+
617+
return filteredResults, nil
618+
}

0 commit comments

Comments
 (0)