Skip to content
Prev Previous commit
Next Next commit
Change exported input type and internal handling
- External: So users don't have to cast between the "enum" type and string
- Internal: So we don't have to add the colon and space each time
  • Loading branch information
philippgille committed Mar 23, 2024
commit 0ca83330759ac9c9e901f54c403a962d76e134fa
33 changes: 20 additions & 13 deletions embed_cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,29 @@ const (
EmbeddingModelCohereEnglishV3 EmbeddingModelCohere = "embed-english-v3.0"
)

type InputTypeCohere string
// Prefixes for external use.
const (
InputTypeCohereSearchDocumentPrefix string = "search_document: "
InputTypeCohereSearchQueryPrefix string = "search_query: "
InputTypeCohereClassificationPrefix string = "classification: "
InputTypeCohereClusteringPrefix string = "clustering: "
)

// Input types for internal use.
const (
InputTypeCohereSearchDocument InputTypeCohere = "search_document"
InputTypeCohereSearchQuery InputTypeCohere = "search_query"
InputTypeCohereClassification InputTypeCohere = "classification"
InputTypeCohereClustering InputTypeCohere = "clustering"
inputTypeCohereSearchDocument string = "search_document"
inputTypeCohereSearchQuery string = "search_query"
inputTypeCohereClassification string = "classification"
inputTypeCohereClustering string = "clustering"
)

const baseURLCohere = "https://api.cohere.ai/v1"

var validInputTypesCohere = []string{
string(InputTypeCohereSearchDocument),
string(InputTypeCohereSearchQuery),
string(InputTypeCohereClassification),
string(InputTypeCohereClustering),
var validInputTypesCohere = map[string]string{
inputTypeCohereSearchDocument: InputTypeCohereSearchDocumentPrefix,
inputTypeCohereSearchQuery: InputTypeCohereSearchQueryPrefix,
inputTypeCohereClassification: InputTypeCohereClassificationPrefix,
inputTypeCohereClustering: InputTypeCohereClusteringPrefix,
}

type cohereResponse struct {
Expand Down Expand Up @@ -71,10 +78,10 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding

return func(ctx context.Context, text string) ([]float32, error) {
var inputType string
for _, validInputType := range validInputTypesCohere {
if strings.HasPrefix(text, validInputType+": ") {
for validInputType, validInputTypePrefix := range validInputTypesCohere {
if strings.HasPrefix(text, validInputTypePrefix) {
inputType = validInputType
text = strings.TrimPrefix(text, validInputType+": ")
text = strings.TrimPrefix(text, validInputTypePrefix)
break
}
}
Expand Down