Skip to content

Commit f668a5b

Browse files
authored
Merge pull request #32 from philippgille/add-ollama-embedding
Add Ollama embedding
2 parents cabf8cc + 77245c1 commit f668a5b

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

‎embed_ollama.go‎

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package chromem
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
)
12+
13+
const baseURLOllama = "http://localhost:11434/api"
14+
15+
type ollamaResponse struct {
16+
Embedding []float32 `json:"embedding"`
17+
}
18+
19+
// NewEmbeddingFuncOllama returns a function that creates embeddings for a document
20+
// using Ollama's embedding API. You can pass any model that Ollama supports and
21+
// that supports embeddings. A good one as of 2024-03-02 is "nomic-embed-text".
22+
// See https://ollama.com/library/nomic-embed-text
23+
func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
24+
// We don't set a default timeout here, although it's usually a good idea.
25+
// In our case though, the library user can set the timeout on the context,
26+
// and it might have to be a long timeout, depending on the document size.
27+
client := &http.Client{}
28+
29+
return func(ctx context.Context, document string) ([]float32, error) {
30+
// Prepare the request body.
31+
reqBody, err := json.Marshal(map[string]string{
32+
"model": model,
33+
"prompt": document,
34+
})
35+
if err != nil {
36+
return nil, fmt.Errorf("couldn't marshal request body: %w", err)
37+
}
38+
39+
// Create the request. Creating it with context is important for a timeout
40+
// to be possible, because the client is configured without a timeout.
41+
req, err := http.NewRequestWithContext(ctx, "POST", baseURLOllama+"/embeddings", bytes.NewBuffer(reqBody))
42+
if err != nil {
43+
return nil, fmt.Errorf("couldn't create request: %w", err)
44+
}
45+
req.Header.Set("Content-Type", "application/json")
46+
47+
// Send the request.
48+
resp, err := client.Do(req)
49+
if err != nil {
50+
return nil, fmt.Errorf("couldn't send request: %w", err)
51+
}
52+
defer resp.Body.Close()
53+
54+
// Check the response status.
55+
if resp.StatusCode != http.StatusOK {
56+
return nil, errors.New("error response from the embedding API: " + resp.Status)
57+
}
58+
59+
// Read and decode the response body.
60+
body, err := io.ReadAll(resp.Body)
61+
if err != nil {
62+
return nil, fmt.Errorf("couldn't read response body: %w", err)
63+
}
64+
var embeddingResponse ollamaResponse
65+
err = json.Unmarshal(body, &embeddingResponse)
66+
if err != nil {
67+
return nil, fmt.Errorf("couldn't unmarshal response body: %w", err)
68+
}
69+
70+
// Check if the response contains embeddings.
71+
if len(embeddingResponse.Embedding) == 0 {
72+
return nil, errors.New("no embeddings found in the response")
73+
}
74+
75+
return embeddingResponse.Embedding, nil
76+
}
77+
}

0 commit comments

Comments
 (0)