66 "fmt"
77 "path/filepath"
88 "slices"
9+ "sort"
910 "sync"
1011)
1112
@@ -27,6 +28,73 @@ type Collection struct {
2728 // versions in [DB.Export] and [DB.Import] as well!
2829}
2930
31+ // NegativeMode represents the mode to use for the negative text.
32+ // See QueryOptions for more information.
33+ type NegativeMode string
34+
35+ const (
36+ // NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
37+ // This is the default behavior.
38+ NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"
39+
40+ // NEGATIVE_MODE_REORDER reorders the results based on the similarity between the
41+ // negative embedding and the document embeddings.
42+ // NegativeReorderStrength controls the strength of the reordering. Lower values
43+ // will reorder the results less aggressively.
44+ NEGATIVE_MODE_REORDER NegativeMode = "reorder"
45+
46+ // NEGATIVE_MODE_FILTER filters out results based on the similarity between the
47+ // negative embedding and the document embeddings.
48+ // NegativeFilterThreshold controls the threshold for filtering. Documents with
49+ // similarity above the threshold will be removed from the results.
50+ NEGATIVE_MODE_FILTER NegativeMode = "filter"
51+
52+ // Default values for negative reordering and filtering.
53+ DEFAULT_NEGATIVE_REORDER_STRENGTH = 1
54+
55+ // The default threshold for the negative filter.
56+ DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
57+ )
58+
59+ // QueryOptions represents the options for a query.
60+ type QueryOptions struct {
61+ // The text to search for.
62+ QueryText string
63+
64+ // The embedding of the query to search for. It must be created
65+ // with the same embedding model as the document embeddings in the collection.
66+ // The embedding will be normalized if it's not the case yet.
67+ // If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
68+ QueryEmbedding []float32
69+
70+ // The text to exclude from the results.
71+ NegativeText string
72+
73+ // The embedding of the negative text. It must be created
74+ // with the same embedding model as the document embeddings in the collection.
75+ // The embedding will be normalized if it's not the case yet.
76+ // If both NegativeText and NegativeEmbedding are set, NegativeEmbedding will be used.
77+ NegativeEmbedding []float32
78+
79+ // The mode to use for the negative text.
80+ NegativeMode NegativeMode
81+
82+ // The strength of the negative reordering. Used when NegativeMode is NEGATIVE_MODE_REORDER.
83+ NegativeReorderStrength float32
84+
85+ // The threshold for the negative filter. Used when NegativeMode is NEGATIVE_MODE_FILTER.
86+ NegativeFilterThreshold float32
87+
88+ // The number of results to return.
89+ NResults int
90+
91+ // Conditional filtering on metadata.
92+ Where map [string ]string
93+
94+ // Conditional filtering on documents.
95+ WhereDocument map [string ]string
96+ }
97+
3098// We don't export this yet to keep the API surface to the bare minimum.
3199// Users create collections via [Client.CreateCollection].
32100func newCollection (name string , metadata map [string ]string , embed EmbeddingFunc , dbDir string , compress bool ) (* Collection , error ) {
@@ -336,44 +404,85 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
336404 return nil , errors .New ("queryText is empty" )
337405 }
338406
339- queryVectors , err := c .embed (ctx , queryText )
407+ queryVector , err := c .embed (ctx , queryText )
340408 if err != nil {
341409 return nil , fmt .Errorf ("couldn't create embedding of query: %w" , err )
342410 }
343411
344- return c .QueryEmbedding (ctx , queryVectors , nResults , where , whereDocument )
412+ return c .QueryEmbedding (ctx , queryVector , nResults , where , whereDocument )
345413}
346414
347415// Performs an exhaustive nearest neighbor search on the collection.
348416//
349- // - queryText: The text to search for. Its embedding will be created using the
350- // collection's embedding function.
351- // - negativeText: The text to subtract from the query embedding. Its embedding
352- // will be created using the collection's embedding function.
353- // - nResults: The number of results to return. Must be > 0.
354- // - where: Conditional filtering on metadata. Optional.
355- // - whereDocument: Conditional filtering on documents. Optional.
356- func (c * Collection ) QueryWithNegative (ctx context.Context , queryText string , negativeText string , nResults int , where , whereDocument map [string ]string ) ([]Result , error ) {
357- if queryText == "" {
358- return nil , errors .New ("queryText is empty" )
417+ // - options: The options for the query. See QueryOptions for more information.
418+ func (c * Collection ) QueryWithOptions (ctx context.Context , options QueryOptions ) ([]Result , error ) {
419+ if options .QueryText == "" && len (options .QueryEmbedding ) == 0 {
420+ return nil , errors .New ("QueryText and QueryEmbedding options are empty" )
359421 }
360422
361- queryVectors , err := c .embed (ctx , queryText )
362- if err != nil {
363- return nil , fmt .Errorf ("couldn't create embedding of query: %w" , err )
423+ var err error
424+ queryVector := options .QueryEmbedding
425+ if len (queryVector ) == 0 {
426+ queryVector , err = c .embed (ctx , options .QueryText )
427+ if err != nil {
428+ return nil , fmt .Errorf ("couldn't create embedding of query: %w" , err )
429+ }
430+ }
431+
432+ negativeMode := options .NegativeMode
433+ if negativeMode == "" {
434+ negativeMode = NEGATIVE_MODE_SUBTRACT
364435 }
365436
366- if negativeText != "" {
367- negativeVectors , err := c .embed (ctx , negativeText )
437+ negativeVector := options .NegativeEmbedding
438+ if len (negativeVector ) == 0 && options .NegativeText != "" {
439+ negativeVector , err = c .embed (ctx , options .NegativeText )
368440 if err != nil {
369441 return nil , fmt .Errorf ("couldn't create embedding of negative: %w" , err )
370442 }
443+ }
444+
445+ if len (negativeVector ) != 0 {
446+ if ! isNormalized (negativeVector ) {
447+ negativeVector = normalizeVector (negativeVector )
448+ }
371449
372- queryVectors = subtractVector (queryVectors , negativeVectors )
373- queryVectors = normalizeVector (queryVectors )
450+ if negativeMode == NEGATIVE_MODE_SUBTRACT {
451+ queryVector = subtractVector (queryVector , negativeVector )
452+ queryVector = normalizeVector (queryVector )
453+ }
374454 }
375455
376- return c .QueryEmbedding (ctx , queryVectors , nResults , where , whereDocument )
456+ result , err := c .QueryEmbedding (ctx , queryVector , options .NResults , options .Where , options .WhereDocument )
457+ if err != nil {
458+ return nil , err
459+ }
460+
461+ if len (negativeVector ) != 0 {
462+ if negativeMode == NEGATIVE_MODE_REORDER {
463+ negativeReorderStrength := options .NegativeReorderStrength
464+ if negativeReorderStrength == 0 {
465+ negativeReorderStrength = DEFAULT_NEGATIVE_REORDER_STRENGTH
466+ }
467+
468+ result , err = reorderResults (result , negativeVector , negativeReorderStrength )
469+ if err != nil {
470+ return nil , fmt .Errorf ("couldn't reorder results: %w" , err )
471+ }
472+ } else if negativeMode == NEGATIVE_MODE_FILTER {
473+ negativeFilterThreshold := options .NegativeFilterThreshold
474+ if negativeFilterThreshold == 0 {
475+ negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
476+ }
477+
478+ result , err = filterResults (result , negativeVector , negativeFilterThreshold )
479+ if err != nil {
480+ return nil , fmt .Errorf ("couldn't filter results: %w" , err )
481+ }
482+ }
483+ }
484+
485+ return result , nil
377486}
378487
379488// Performs an exhaustive nearest neighbor search on the collection.
@@ -465,3 +574,45 @@ func (c *Collection) getDocPath(docID string) string {
465574 }
466575 return docPath
467576}
577+
578+ func reorderResults (results []Result , negativeVector []float32 , negativeReorderStrength float32 ) ([]Result , error ) {
579+ if len (results ) == 0 {
580+ return results , nil
581+ }
582+
583+ // Calculate cosine similarity between negative vector and each result
584+ for i := range results {
585+ sim , err := dotProduct (negativeVector , results [i ].Embedding )
586+ if err != nil {
587+ return nil , fmt .Errorf ("couldn't calculate dot product: %w" , err )
588+ }
589+ results [i ].Similarity -= sim * negativeReorderStrength
590+ }
591+
592+ // Sort results by similarity
593+ sort .Slice (results , func (i , j int ) bool {
594+ return results [i ].Similarity > results [j ].Similarity
595+ })
596+
597+ return results , nil
598+ }
599+
600+ func filterResults (results []Result , negativeVector []float32 , negativeFilterThreshold float32 ) ([]Result , error ) {
601+ if len (results ) == 0 {
602+ return results , nil
603+ }
604+
605+ // Filter out results with similarity above the threshold
606+ filteredResults := make ([]Result , 0 , len (results ))
607+ for _ , res := range results {
608+ sim , err := dotProduct (negativeVector , res .Embedding )
609+ if err != nil {
610+ return nil , fmt .Errorf ("couldn't calculate dot product: %w" , err )
611+ }
612+ if sim < negativeFilterThreshold {
613+ filteredResults = append (filteredResults , res )
614+ }
615+ }
616+
617+ return filteredResults , nil
618+ }
0 commit comments