Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
change: move OpenAICompat with custom headers/query-params to unexpor…
…ted function
  • Loading branch information
iwilltry42 committed May 13, 2024
commit 88f1efa639f5d41b0bc29f0d93cb0942ddb0643d
10 changes: 5 additions & 5 deletions embed_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {

// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized, nil, nil)
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
}

const baseURLJina = "https://api.jina.ai/v1"
Expand All @@ -32,7 +32,7 @@ const (
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil, nil, nil)
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
}

const baseURLMixedbread = "https://api.mixedbread.ai"
Expand All @@ -53,7 +53,7 @@ const (
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil, nil, nil)
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
}

const baseURLLocalAI = "http://localhost:8080/v1"
Expand All @@ -68,7 +68,7 @@ const baseURLLocalAI = "http://localhost:8080/v1"
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil, nil, nil)
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
}

const (
Expand All @@ -83,5 +83,5 @@ func NewEmbeddingFuncAzureOpenAI(apiKey string, deploymentURL string, apiVersion
if apiVersion == "" {
apiVersion = azureDefaultAPIVersion
}
return NewEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion})
return newEmbeddingFuncOpenAICompat(deploymentURL, apiKey, model, nil, map[string]string{"api-key": apiKey}, map[string]string{"api-version": apiVersion})
}
17 changes: 15 additions & 2 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
// OpenAI embeddings are normalized
normalized := true
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized, nil, nil)
return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
}

// NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
Expand All @@ -56,7 +56,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
// model are already normalized, as is the case for OpenAI's and Mistral's models.
// The flag is optional. If it's nil, it will be autodetected on the first request
// (which bears a small risk that the vector just happens to have a length of 1).
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc {
func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
return newEmbeddingFuncOpenAICompat(baseURL, apiKey, model, normalized, nil, nil)
}

// newEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
// using an OpenAI compatible API.
// It offers options to set request headers and query parameters
// e.g. to pass the `api-key` header and the `api-version` query parameter for Azure OpenAI.
//
// The `normalized` parameter indicates whether the vectors returned by the embedding
// model are already normalized, as is the case for OpenAI's and Mistral's models.
// The flag is optional. If it's nil, it will be autodetected on the first request
// (which bears a small risk that the vector just happens to have a length of 1).
func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool, headers map[string]string, queryParams map[string]string) EmbeddingFunc {
// We don't set a default timeout here, although it's usually a good idea.
// In our case though, the library user can set the timeout on the context,
// and it might have to be a long timeout, depending on the text length.
Expand Down
2 changes: 1 addition & 1 deletion embed_openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
defer ts.Close()
baseURL := ts.URL + baseURLSuffix

f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil, nil, nil)
f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
res, err := f(context.Background(), input)
if err != nil {
t.Fatal("expected nil, got", err)
Expand Down