Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 122 additions & 10 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,67 @@ type Collection struct {
// versions in [DB.Export] and [DB.Import] as well!
}

// NegativeMode represents the mode to use for the negative text.
// See QueryOptions for more information.
type NegativeMode string

const (
// NEGATIVE_MODE_FILTER filters out results based on the similarity between the
// negative embedding and the document embeddings.
// NegativeFilterThreshold controls the threshold for filtering. Documents with
// similarity above the threshold will be removed from the results.
NEGATIVE_MODE_FILTER NegativeMode = "filter"

// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
// This is the default behavior.
NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

// The default threshold for the negative filter.
DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
)

// QueryOptions represents the options for a query.
type QueryOptions struct {
// The text to search for.
QueryText string

// The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
QueryEmbedding []float32

// The number of results to return.
NResults int

// Conditional filtering on metadata.
Where map[string]string

// Conditional filtering on documents.
WhereDocument map[string]string

// Negative is the negative query options.
// They can be used to exclude certain results from the query.
Negative NegativeQueryOptions
}

type NegativeQueryOptions struct {
// Mode is the mode to use for the negative text.
Mode NegativeMode

// Text is the text to exclude from the results.
Text string

// Embedding is the embedding of the negative text. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both Text and Embedding are set, Embedding will be used.
Embedding []float32

// FilterThreshold is the threshold for the negative filter. Used when Mode is NEGATIVE_MODE_FILTER.
FilterThreshold float32
}

// We don't export this yet to keep the API surface to the bare minimum.
// Users create collections via [Client.CreateCollection].
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
Expand Down Expand Up @@ -336,12 +397,63 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
return nil, errors.New("queryText is empty")
}

queryVectors, err := c.embed(ctx, queryText)
queryVector, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument)
}

// QueryWithOptions performs an exhaustive nearest neighbor search on the collection.
//
// - options: The options for the query. See QueryOptions for more information.
func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) {
if options.QueryText == "" && len(options.QueryEmbedding) == 0 {
return nil, errors.New("QueryText and QueryEmbedding options are empty")
}

var err error
queryVector := options.QueryEmbedding
if len(queryVector) == 0 {
queryVector, err = c.embed(ctx, options.QueryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}
}

negativeFilterThreshold := options.Negative.FilterThreshold
negativeVector := options.Negative.Embedding
if len(negativeVector) == 0 && options.Negative.Text != "" {
negativeVector, err = c.embed(ctx, options.Negative.Text)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
}
}

if len(negativeVector) != 0 {
if !isNormalized(negativeVector) {
negativeVector = normalizeVector(negativeVector)
}

if options.Negative.Mode == NEGATIVE_MODE_SUBTRACT {
queryVector = subtractVector(queryVector, negativeVector)
queryVector = normalizeVector(queryVector)
} else if options.Negative.Mode == NEGATIVE_MODE_FILTER {
if negativeFilterThreshold == 0 {
negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
}
} else {
return nil, fmt.Errorf("unsupported negative mode: %q", options.Negative.Mode)
}
}

result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
if err != nil {
return nil, err
}

return result, nil
}

// QueryEmbedding performs an exhaustive nearest neighbor search on the collection.
Expand All @@ -354,6 +466,11 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
}

// queryEmbedding performs an exhaustive nearest neighbor search on the collection.
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding is empty")
}
Expand Down Expand Up @@ -399,18 +516,13 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, negativeEmbeddings, negativeFilterThreshold, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
if resLen != len(nMaxDocs) {
return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
}

res := make([]Result, 0, resLen)
for i := 0; i < resLen; i++ {
res := make([]Result, 0, len(nMaxDocs))
for i := 0; i < len(nMaxDocs); i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand Down
Loading