Skip to content

Commit 38898e0

Browse files
authored
Merge pull request #20 from philippgille/add-unit-tests
Add unit tests
2 parents 77094ea + 3238764 commit 38898e0

File tree

3 files changed

+141
-0
lines changed

3 files changed

+141
-0
lines changed

‎collection_test.go‎

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package chromem_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/philippgille/chromem-go"
8+
)
9+
10+
func TestCollection_Add(t *testing.T) {
11+
// Create collection
12+
db := chromem.NewDB()
13+
name := "test"
14+
metadata := map[string]string{"foo": "bar"}
15+
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
16+
return []float32{-0.1, 0.1, 0.2}, nil
17+
}
18+
c := db.CreateCollection(name, metadata, embeddingFunc)
19+
if c == nil {
20+
t.Error("expected collection, got nil")
21+
}
22+
23+
// Add document
24+
ids := []string{"1", "2"}
25+
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
26+
documents := []string{"hello world", "hallo welt"}
27+
err := c.Add(context.Background(), ids, nil, metadatas, documents)
28+
if err != nil {
29+
t.Error("expected nil, got", err)
30+
}
31+
32+
// TODO: Check expectations when documents become accessible
33+
}

‎db_test.go‎

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,28 @@ import (
77
"github.com/philippgille/chromem-go"
88
)
99

10+
func TestDB_CreateCollection(t *testing.T) {
11+
// Values in the collection
12+
name := "test"
13+
metadata := map[string]string{"foo": "bar"}
14+
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
15+
return []float32{-0.1, 0.1, 0.2}, nil
16+
}
17+
18+
// Create collection
19+
db := chromem.NewDB()
20+
c := db.CreateCollection(name, metadata, embeddingFunc)
21+
if c == nil {
22+
t.Error("expected collection, got nil")
23+
}
24+
25+
// Check expectations
26+
if c.Name != name {
27+
t.Error("expected name", name, "got", c.Name)
28+
}
29+
// TODO: Check metadata etc when they become accessible
30+
}
31+
1032
func TestDB_ListCollections(t *testing.T) {
1133
// Values in the collection
1234
name := "test"

‎embed_openai_test.go‎

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package chromem_test
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"net/http/httptest"
10+
"slices"
11+
"strings"
12+
"testing"
13+
14+
"github.com/philippgille/chromem-go"
15+
)
16+
17+
type openAIResponse struct {
18+
Data []struct {
19+
Embedding []float32 `json:"embedding"`
20+
} `json:"data"`
21+
}
22+
23+
func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
24+
apiKey := "secret"
25+
model := "model-small"
26+
baseURLSuffix := "/v1"
27+
document := "hello world"
28+
29+
wantBody, err := json.Marshal(map[string]string{
30+
"input": document,
31+
"model": model,
32+
})
33+
if err != nil {
34+
t.Error("unexpected error:", err)
35+
}
36+
wantRes := []float32{-0.1, 0.1, 0.2}
37+
38+
// Mock server
39+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
40+
// Check URL
41+
if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") {
42+
t.Error("expected URL", baseURLSuffix+"/embedding", "got", r.URL.Path)
43+
}
44+
// Check method
45+
if r.Method != "POST" {
46+
t.Error("expected method POST, got", r.Method)
47+
}
48+
// Check headers
49+
if r.Header.Get("Authorization") != "Bearer "+apiKey {
50+
t.Error("expected Authorization header", "Bearer "+apiKey, "got", r.Header.Get("Authorization"))
51+
}
52+
if r.Header.Get("Content-Type") != "application/json" {
53+
t.Error("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type"))
54+
}
55+
// Check body
56+
body, err := io.ReadAll(r.Body)
57+
if err != nil {
58+
t.Error("unexpected error:", err)
59+
}
60+
if !bytes.Equal(body, wantBody) {
61+
t.Error("expected body", wantBody, "got", body)
62+
}
63+
64+
// Write response
65+
resp := openAIResponse{
66+
Data: []struct {
67+
Embedding []float32 `json:"embedding"`
68+
}{
69+
{Embedding: wantRes},
70+
},
71+
}
72+
w.WriteHeader(http.StatusOK)
73+
_ = json.NewEncoder(w).Encode(resp)
74+
}))
75+
defer ts.Close()
76+
baseURL := ts.URL + baseURLSuffix
77+
78+
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
79+
res, err := f(context.Background(), document)
80+
if err != nil {
81+
t.Error("expected nil, got", err)
82+
}
83+
if slices.Compare[[]float32](wantRes, res) != 0 {
84+
t.Error("expected res", wantRes, "got", res)
85+
}
86+
}

0 commit comments

Comments
 (0)