Rename some "document" to "content" or "text"
To differentiate between our now exported Document struct
and its contents.
philippgille committed Mar 4, 2024
commit cb0fe2f8e48c7513a550867d11513963bc9958c6
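For context, the rename aligns parameter names with the exported struct: a Document holds its Content alongside its ID, metadata, and embedding, so the raw text is no longer itself called a "document". A minimal sketch of that struct, with the field set inferred from the composite literal in collection.go below (the actual definition may differ):

```go
// Sketch only — fields inferred from the diff in this commit.
type Document struct {
	ID        string
	Metadata  map[string]string
	Embedding []float32
	Content   string // the text formerly referred to as "document"
}
```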
26 changes: 13 additions & 13 deletions collection.go
@@ -73,27 +73,27 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
//
// - ids: The ids of the embeddings you wish to add
// - embeddings: The embeddings to add. If nil, embeddings will be computed based
-// on the documents using the embeddingFunc set for the Collection. Optional.
+// on the contents using the embeddingFunc set for the Collection. Optional.
// - metadatas: The metadata to associate with the embeddings. When querying,
// you can filter on this metadata. Optional.
-// - documents: The documents to associate with the embeddings.
+// - contents: The contents to associate with the embeddings.
//
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
-func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error {
-return c.AddConcurrently(ctx, ids, embeddings, metadatas, documents, 1)
+func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string) error {
+return c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 1)
}

// AddConcurrently is like Add, but adds embeddings concurrently.
// This is mostly useful when you don't pass any embeddings so they have to be created.
// Upon error, concurrently running operations are canceled and the error is returned.
//
// This is a Chroma-like method. For a more Go-idiomatic one, see [AddDocuments].
-func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error {
+func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, contents []string, concurrency int) error {
if len(ids) == 0 {
return errors.New("ids are empty")
}
-if len(embeddings) == 0 && len(documents) == 0 {
-return errors.New("either embeddings or documents must be filled")
+if len(embeddings) == 0 && len(contents) == 0 {
+return errors.New("either embeddings or contents must be filled")
}
if len(embeddings) != 0 {
if len(embeddings) != len(ids) {
@@ -104,15 +104,15 @@ func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddin
embeddings = make([][]float32, len(ids))
}
if len(metadatas) != 0 && len(ids) != len(metadatas) {
return errors.New("ids, metadatas and documents must have the same length")
return errors.New("ids, metadatas and contents must have the same length")
}
-if len(documents) != 0 {
-if len(documents) != len(ids) {
-return errors.New("ids and documents must have the same length")
+if len(contents) != 0 {
+if len(contents) != len(ids) {
+return errors.New("ids and contents must have the same length")
}
} else {
// Assign empty slice so we can simply access via index later
-documents = make([]string, len(ids))
+contents = make([]string, len(ids))
}
if concurrency < 1 {
return errors.New("concurrency must be at least 1")
@@ -125,7 +125,7 @@ func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddin
ID: id,
Metadata: metadatas[i],
Embedding: embeddings[i],
-Content: documents[i],
+Content: contents[i],
})
}

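After the rename, adding data to a collection reads like this — a minimal usage sketch, assuming the usual chromem-go setup (a NewDB constructor and a CreateCollection method; exact signatures may differ from this commit's state of the code):

```go
package main

import (
	"context"

	chromem "github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()
	db := chromem.NewDB() // assumed constructor
	// A nil embedding func is assumed to fall back to the OpenAI default.
	c, err := db.CreateCollection("knowledge-base", nil, nil)
	if err != nil {
		panic(err)
	}

	ids := []string{"1", "2"}
	metadatas := []map[string]string{{"lang": "en"}, {"lang": "de"}}
	contents := []string{"hello world", "hallo welt"}

	// Embeddings are nil, so they are computed from contents by the
	// collection's embeddingFunc, with up to 2 requests in flight.
	if err := c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2); err != nil {
		panic(err)
	}
}
```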
8 changes: 4 additions & 4 deletions collection_test.go
@@ -26,8 +26,8 @@ func TestCollection_Add(t *testing.T) {
// Add document
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
-documents := []string{"hello world", "hallo welt"}
-err = c.Add(context.Background(), ids, nil, metadatas, documents)
+contents := []string{"hello world", "hallo welt"}
+err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
@@ -54,8 +54,8 @@ func TestCollection_Count(t *testing.T) {
// Add documents
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
-documents := []string{"hello world", "hallo welt"}
-err = c.Add(context.Background(), ids, nil, metadatas, documents)
+contents := []string{"hello world", "hallo welt"}
+err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Error("expected nil, got", err)
}
4 changes: 2 additions & 2 deletions db.go
@@ -10,10 +10,10 @@ import (
"sync"
)

-// EmbeddingFunc is a function that creates embeddings for a given document.
+// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI`s "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
-type EmbeddingFunc func(ctx context.Context, document string) ([]float32, error)
+type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
//
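Since EmbeddingFunc is a plain function type, callers can plug in their own implementation instead of the OpenAI default. A deterministic stub that satisfies the new signature — illustration only, not a real embedding model, and the helper name is hypothetical:

```go
import (
	"context"

	chromem "github.com/philippgille/chromem-go"
)

// newStubEmbeddingFunc returns an EmbeddingFunc that derives a fixed-size
// vector from the text's characters. Useful in tests; not semantically
// meaningful, and real embeddings should come from a model.
func newStubEmbeddingFunc(dim int) chromem.EmbeddingFunc {
	return func(ctx context.Context, text string) ([]float32, error) {
		v := make([]float32, dim)
		for i, r := range text {
			v[i%dim] += float32(r)
		}
		return v, nil
	}
}
```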
8 changes: 4 additions & 4 deletions embed_compat.go
@@ -6,7 +6,7 @@ const (
embeddingModelMistral = "mistral-embed"
)

-// NewEmbeddingFuncMistral returns a function that creates embeddings for a document
+// NewEmbeddingFuncMistral returns a function that creates embeddings for a text
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// The Mistral API docs don't mention the `encoding_format` as optional,
@@ -25,7 +25,7 @@ const (
EmbeddingModelJina2BaseZH EmbeddingModelJina = "jina-embeddings-v2-base-zh"
)

-// NewEmbeddingFuncJina returns a function that creates embeddings for a document
+// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
@@ -46,15 +46,15 @@ const (
EmbeddingModelMixedbreadGTELargeZh EmbeddingModelMixedbread = "gte-large-zh"
)

-// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a document
+// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
}

const baseURLLocalAI = "http://localhost:8080/v1"

-// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a document
+// NewEmbeddingFuncLocalAI returns a function that creates embeddings for a text
// using the LocalAI API.
// You can start a LocalAI instance like this:
//
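These thin wrappers all delegate to NewEmbeddingFuncOpenAICompat, so usage is uniform across providers. A sketch using the constructors and model constants visible in this diff (API keys assumed to be in the environment):

```go
// Mistral: the model is fixed to "mistral-embed" internally.
embedMistral := chromem.NewEmbeddingFuncMistral(os.Getenv("MISTRAL_API_KEY"))

// Jina, with one of the exported model constants:
embedJina := chromem.NewEmbeddingFuncJina(os.Getenv("JINA_API_KEY"), chromem.EmbeddingModelJina2BaseZH)

// Either can then be passed where an EmbeddingFunc is expected,
// e.g. (assumed signature): db.CreateCollection("docs", nil, embedJina)
_, _ = embedMistral, embedJina
```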
8 changes: 4 additions & 4 deletions embed_ollama.go
@@ -16,21 +16,21 @@ type ollamaResponse struct {
Embedding []float32 `json:"embedding"`
}

-// NewEmbeddingFuncOllama returns a function that creates embeddings for a document
+// NewEmbeddingFuncOllama returns a function that creates embeddings for a text
// using Ollama's embedding API. You can pass any model that Ollama supports and
// that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
// See https://ollama.com/library/nomic-embed-text
func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
-// and it might have to be a long timeout, depending on the document size.
+// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

-return func(ctx context.Context, document string) ([]float32, error) {
+return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"model": model,
"prompt": document,
"prompt": text,
})
if err != nil {
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
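As the comment above notes, the HTTP client has no timeout, so the caller bounds each call through the context. A sketch, assuming Ollama is running locally with the model pulled:

```go
embed := chromem.NewEmbeddingFuncOllama("nomic-embed-text")

// Embedding can be slow for long texts, so set the deadline per call:
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

vec, err := embed(ctx, "The sky is blue because of Rayleigh scattering.")
if err != nil {
	log.Fatal(err)
}
_ = vec // e.g. 768 dimensions for nomic-embed-text
```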
14 changes: 7 additions & 7 deletions embed_openai.go
@@ -27,22 +27,22 @@ type openAIResponse struct {
} `json:"data"`
}

-// NewEmbeddingFuncDefault returns a function that creates embeddings for a document
+// NewEmbeddingFuncDefault returns a function that creates embeddings for a text
// using OpenAI`s "text-embedding-3-small" model via their API.
-// The model supports a maximum document length of 8191 tokens.
+// The model supports a maximum text length of 8191 tokens.
// The API key is read from the environment variable "OPENAI_API_KEY".
func NewEmbeddingFuncDefault() EmbeddingFunc {
apiKey := os.Getenv("OPENAI_API_KEY")
return NewEmbeddingFuncOpenAI(apiKey, EmbeddingModelOpenAI3Small)
}

-// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a document
+// NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text
// using the OpenAI API.
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model))
}

-// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a document
+// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
// using an OpenAI compatible API. For example:
// - Azure OpenAI: https://azure.microsoft.com/en-us/products/ai-services/openai-service
// - LitLLM: https://github.com/BerriAI/litellm
@@ -51,13 +51,13 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
-// and it might have to be a long timeout, depending on the document size.
+// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

-return func(ctx context.Context, document string) ([]float32, error) {
+return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
"input": document,
"input": text,
"model": model,
})
if err != nil {
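The compat constructor works against any server implementing the OpenAI embeddings endpoint. For example, pointing it at a local instance on the baseURLLocalAI address from embed_compat.go (the model name here is a placeholder, not a real default):

```go
// An empty API key is assumed to be acceptable for a local server.
embed := chromem.NewEmbeddingFuncOpenAICompat("http://localhost:8080/v1", "", "your-embedding-model")

vec, err := embed(context.Background(), "hello world")
if err != nil {
	log.Fatal(err)
}
fmt.Println(len(vec)) // dimensionality depends on the chosen model
```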
6 changes: 3 additions & 3 deletions embed_openai_test.go
@@ -24,10 +24,10 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
apiKey := "secret"
model := "model-small"
baseURLSuffix := "/v1"
document := "hello world"
input := "hello world"

wantBody, err := json.Marshal(map[string]string{
"input": document,
"input": input,
"model": model,
})
if err != nil {
@@ -76,7 +76,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
baseURL := ts.URL + baseURLSuffix

f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
-res, err := f(context.Background(), document)
+res, err := f(context.Background(), input)
if err != nil {
t.Error("expected nil, got", err)
}