Skip to content

Commit f971ad9

Browse files
authored
Merge pull request #61 from philippgille/add-cohere-embedding-provider
Add Cohere embedding provider
2 parents bb2271e + 8854814 commit f971ad9

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed

‎embed_cohere.go‎

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package chromem
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"strings"
12+
"sync"
13+
)
14+
15+
type EmbeddingModelCohere string
16+
17+
const (
18+
EmbeddingModelCohereMultilingualV2 EmbeddingModelCohere = "embed-multilingual-v2.0"
19+
EmbeddingModelCohereEnglishLightV2 EmbeddingModelCohere = "embed-english-light-v2.0"
20+
EmbeddingModelCohereEnglishV2 EmbeddingModelCohere = "embed-english-v2.0"
21+
EmbeddingModelCohereMultilingualLightV3 EmbeddingModelCohere = "embed-multilingual-light-v3.0"
22+
EmbeddingModelCohereEnglishLightV3 EmbeddingModelCohere = "embed-english-light-v3.0"
23+
EmbeddingModelCohereMultilingualV3 EmbeddingModelCohere = "embed-multilingual-v3.0"
24+
EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0"
25+
)
26+
27+
// Prefixes for external use.
28+
const (
29+
InputTypeCohereSearchDocumentPrefix string = "search_document: "
30+
InputTypeCohereSearchQueryPrefix string = "search_query: "
31+
InputTypeCohereClassificationPrefix string = "classification: "
32+
InputTypeCohereClusteringPrefix string = "clustering: "
33+
)
34+
35+
// Input types for internal use.
36+
const (
37+
inputTypeCohereSearchDocument string = "search_document"
38+
inputTypeCohereSearchQuery string = "search_query"
39+
inputTypeCohereClassification string = "classification"
40+
inputTypeCohereClustering string = "clustering"
41+
)
42+
43+
const baseURLCohere = "https://api.cohere.ai/v1"
44+
45+
var validInputTypesCohere = map[string]string{
46+
inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
47+
inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix,
48+
inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
49+
inputTypeCohereClustering: InputTypeCohereClusteringPrefix,
50+
}
51+
52+
type cohereResponse struct {
53+
Embeddings [][]float32 `json:"embeddings"`
54+
}
55+
56+
// NewEmbeddingFuncCohere returns a function that creates embeddings for a text
57+
// using Cohere's API. One important difference to OpenAI's and other's APIs is
58+
// that Cohere differentiates between document embeddings and search/query embeddings.
59+
// In order for this embedding func to do the differentiation, you have to prepend
60+
// the text with either "search_document" or "search_query". We'll cut off that
61+
// prefix before sending the document/query body to the API, we'll just use the
62+
// prefix to choose the right "input type" as they call it.
63+
//
64+
// When you set up a chromem-go collection with this embedding function, you might
65+
// want to create the document separately with [NewDocument] and then cut off the
66+
// prefix before adding the document to the collection. Otherwise when you query
67+
// the collection, the returned documents will still have the prefix in their content.
68+
//
69+
// cohereFunc := chromem.NewEmbeddingFuncCohere(cohereApiKey, chromem.EmbeddingModelCohereEnglishV3)
70+
// content := "The sky is blue because of Rayleigh scattering."
71+
// // Create the document with the prefix.
72+
// contentWithPrefix := chromem.InputTypeCohereSearchDocumentPrefix + content
73+
// doc, _ := NewDocument(ctx, id, metadata, nil, contentWithPrefix, cohereFunc)
74+
// // Remove the prefix so that later query results don't have it.
75+
// doc.Content = content
76+
// _ = collection.AddDocument(ctx, doc)
77+
//
78+
// This is not necessary if you don't keep the content in the documents, as chromem-go
79+
// also works when documents only have embeddings.
80+
// You can also keep the prefix in the document, and only remove it after querying.
81+
//
82+
// We plan to improve this in the future.
83+
func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) EmbeddingFunc {
84+
// We don't set a default timeout here, although it's usually a good idea.
85+
// In our case though, the library user can set the timeout on the context,
86+
// and it might have to be a long timeout, depending on the text length.
87+
client := &http.Client{}
88+
89+
var checkedNormalized bool
90+
checkNormalized := sync.Once{}
91+
92+
return func(ctx context.Context, text string) ([]float32, error) {
93+
var inputType string
94+
for validInputType, validInputTypePrefix := range validInputTypesCohere {
95+
if strings.HasPrefix(text, validInputTypePrefix) {
96+
inputType = validInputType
97+
text = strings.TrimPrefix(text, validInputTypePrefix)
98+
break
99+
}
100+
}
101+
if inputType == "" {
102+
return nil, errors.New("text must start with a valid input type plus colon and space")
103+
}
104+
105+
// Prepare the request body.
106+
reqBody, err := json.Marshal(map[string]any{
107+
"model": model,
108+
"texts": []string{text},
109+
"input_type": inputType,
110+
})
111+
if err != nil {
112+
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
113+
}
114+
115+
// Create the request. Creating it with context is important for a timeout
116+
// to be possible, because the client is configured without a timeout.
117+
req, err := http.NewRequestWithContext(ctx, "POST", baseURLCohere+"/embed", bytes.NewBuffer(reqBody))
118+
if err != nil {
119+
return nil, fmt.Errorf("couldn't create request: %w", err)
120+
}
121+
req.Header.Set("Accept", "application/json")
122+
req.Header.Set("Content-Type", "application/json")
123+
req.Header.Set("Authorization", "Bearer "+apiKey)
124+
125+
// Send the request.
126+
resp, err := client.Do(req)
127+
if err != nil {
128+
return nil, fmt.Errorf("couldn't send request: %w", err)
129+
}
130+
defer resp.Body.Close()
131+
132+
// Check the response status.
133+
if resp.StatusCode != http.StatusOK {
134+
return nil, errors.New("error response from the embedding API: " + resp.Status)
135+
}
136+
137+
// Read and decode the response body.
138+
body, err := io.ReadAll(resp.Body)
139+
if err != nil {
140+
return nil, fmt.Errorf("couldn't read response body: %w", err)
141+
}
142+
var embeddingResponse cohereResponse
143+
err = json.Unmarshal(body, &embeddingResponse)
144+
if err != nil {
145+
return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
146+
}
147+
148+
// Check if the response contains embeddings.
149+
if len(embeddingResponse.Embeddings) == 0 || len(embeddingResponse.Embeddings[0]) == 0 {
150+
return nil, errors.New("no embeddings found in the response")
151+
}
152+
153+
v := embeddingResponse.Embeddings[0]
154+
checkNormalized.Do(func() {
155+
if isNormalized(v) {
156+
checkedNormalized = true
157+
} else {
158+
checkedNormalized = false
159+
}
160+
})
161+
if !checkedNormalized {
162+
v = normalizeVector(v)
163+
}
164+
165+
return v, nil
166+
}
167+
}

0 commit comments

Comments
 (0)