Add Azure AI Rerank support #129848
# Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
    return completionModel;
}

if (taskType == TaskType.RERANK) {
There was a problem hiding this comment.
Can we simplify the logic in this method with a switch statement to make the model and then a single call to checkProviderAndEndpointTypeForTask?
There was a problem hiding this comment.
It will require casting the service settings in checkProviderAndEndpointTypeForTask, so unfortunately it won't be a single call.
There was a problem hiding this comment.
Can we not cast it to AzureAiStudioServiceSettings? Something like:
AzureAiStudioModel model;
switch (taskType) {
    case TEXT_EMBEDDING -> {
        model = new AzureAiStudioEmbeddingsModel(
            inferenceEntityId,
            taskType,
            NAME,
            serviceSettings,
            taskSettings,
            chunkingSettings,
            secretSettings,
            context
        );
    }
    case COMPLETION -> {
        ...
    }
    default -> throw new ElasticsearchStatusException(
        failureMessage,
        RestStatus.BAD_REQUEST
    );
}
AzureAiStudioServiceSettings azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
checkProviderAndEndpointTypeForTask(
    taskType,
    azureAiStudioServiceSettings.provider(),
    azureAiStudioServiceSettings.endpointType()
);
There was a problem hiding this comment.
Well. My bad. Refactored.
Thank you
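For readers following the thread, the shape of the suggested refactoring can be sketched in isolation: build the model in one switch, then validate once through the shared base type. This is a minimal sketch, not the code that was merged. Every type here (`SketchTaskType`, `SketchServiceSettings`, `SketchModel` and its implementations, `SketchModelFactory`) is a hypothetical stand-in for the real Elasticsearch classes, and the validation body is a placeholder.

```java
// Hypothetical stand-ins for the real classes; names and fields are illustrative only.
enum SketchTaskType { TEXT_EMBEDDING, COMPLETION, RERANK, SPARSE_EMBEDDING }

record SketchServiceSettings(String provider, String endpointType) {}

// Each task-specific model exposes its settings through the shared base type,
// which is what lets the validation happen in a single place.
interface SketchModel {
    SketchServiceSettings serviceSettings();
}

record SketchEmbeddingsModel(SketchServiceSettings serviceSettings) implements SketchModel {}

record SketchCompletionModel(SketchServiceSettings serviceSettings) implements SketchModel {}

record SketchRerankModel(SketchServiceSettings serviceSettings) implements SketchModel {}

class SketchModelFactory {

    static SketchModel createModel(SketchTaskType taskType, SketchServiceSettings settings) {
        // One switch builds the model; unsupported task types fail fast.
        SketchModel model = switch (taskType) {
            case TEXT_EMBEDDING -> new SketchEmbeddingsModel(settings);
            case COMPLETION -> new SketchCompletionModel(settings);
            case RERANK -> new SketchRerankModel(settings);
            default -> throw new IllegalArgumentException("Unsupported task type [" + taskType + "]");
        };

        // A single validation call, regardless of which branch built the model.
        checkProviderAndEndpointType(model.serviceSettings().provider(), model.serviceSettings().endpointType());
        return model;
    }

    // Placeholder validation: the real check compares provider and endpoint type per task.
    static void checkProviderAndEndpointType(String provider, String endpointType) {
        if (provider == null || endpointType == null) {
            throw new IllegalArgumentException("provider and endpoint type are required");
        }
    }

    public static void main(String[] args) {
        SketchModel model = createModel(SketchTaskType.RERANK, new SketchServiceSettings("cohere", "token"));
        System.out.println(model);
    }
}
```

The point of the pattern is that adding a new task type touches one `case` line, while the validation call stays untouched.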
// these providers have chat completion inference (all providers at the moment)
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());

// these providers allow token ("pay as you go") embeddings endpoints
There was a problem hiding this comment.
Can you clarify what this comment means? Why do they have to support token ("pay as you go") billing to be valid for rerank? Why are they embeddings endpoints instead of rerank endpoints? Can this just say // these providers have rerank inference?
There was a problem hiding this comment.
Cohere Rerank billing is based on search_units, so effectively we pay for what we use. That's why I think "pay as you go" is applicable. In any case, this comment is no longer relevant, since the constant it described has been deleted as suggested in another comment. Thanks
public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> tokenRerankProviders = List.of(AzureAiStudioProvider.COHERE);
There was a problem hiding this comment.
Why does this need to be split from rerankProviders? Do we expect there to be rerankProviders that don't offer token rerank capabilities?
There was a problem hiding this comment.
Currently we only use Cohere for reranking, so it's wiser not to split. Thank you. Done
public static final List<AzureAiStudioProvider> tokenRerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow realtime rerank endpoints (none at the moment)
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
There was a problem hiding this comment.
Do we expect these to be added at some point? If not, do we need to have this in code until any are added?
There was a problem hiding this comment.
I think it should be left.
It's used to check whether the realtime type is valid for a provider, which lets us show a more descriptive error.
For example, if we try to set realtime as the endpoint_type while creating an inference endpoint, we get:
"The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available"
There was a problem hiding this comment.
Makes sense, thanks for clarifying!
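To illustrate why keeping the empty list is useful, here is a minimal sketch of the kind of check being discussed. The enums and the `SketchRerankEndpointValidator` class are hypothetical stand-ins, and the exception type is an assumption; only the list contents and the error message text come from this thread.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical stand-ins; the real enums have more members than shown here.
enum SketchProvider { COHERE, OPENAI }

enum SketchEndpointType { TOKEN, REALTIME }

class SketchRerankEndpointValidator {

    // Providers that offer rerank on token endpoints (only Cohere, per this thread).
    static final List<SketchProvider> RERANK_PROVIDERS = List.of(SketchProvider.COHERE);

    // Providers that offer rerank on realtime endpoints (none at the moment).
    static final List<SketchProvider> REALTIME_RERANK_PROVIDERS = List.of();

    static void checkRerank(SketchProvider provider, SketchEndpointType endpointType) {
        boolean supported = endpointType == SketchEndpointType.TOKEN
            ? RERANK_PROVIDERS.contains(provider)
            : REALTIME_RERANK_PROVIDERS.contains(provider);

        if (supported == false) {
            // The empty realtime list makes this branch reachable, so the caller
            // sees which combination was rejected instead of a generic failure.
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "The [%s] endpoint type with [rerank] task type for provider [%s] is not available",
                    endpointType.name().toLowerCase(Locale.ROOT),
                    provider.name().toLowerCase(Locale.ROOT)
                )
            );
        }
    }

    public static void main(String[] args) {
        checkRerank(SketchProvider.COHERE, SketchEndpointType.TOKEN); // accepted
        try {
            checkRerank(SketchProvider.COHERE, SketchEndpointType.REALTIME);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

If a provider later gains realtime rerank support, only the second list needs to change.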
}

private AzureAiStudioRerankRequestEntity createRequestEntity() {
    var taskSettings = rerankModel.getTaskSettings();
There was a problem hiding this comment.
Nit: Do we need this stored in a separate variable if it's only referenced once?
There was a problem hiding this comment.
Agreed. It's a bit cleaner not to introduce it. Corrected, thanks.
private static AzureAiStudioRerankTaskSettings createRandom() {
    return new AzureAiStudioRerankTaskSettings(
        randomFrom(randomFrom(new Boolean[] { null, randomBoolean() })),
There was a problem hiding this comment.
Do we need the second randomFrom here?
There was a problem hiding this comment.
Yes, that was overkill. I removed the unnecessary randomFrom. Thank you
        ? tokenEmbeddingsProviders.contains(provider)
        : realtimeEmbeddingsProviders.contains(provider);
}
case RERANK -> {
There was a problem hiding this comment.
Is this logic covered in testing anywhere right now?
There was a problem hiding this comment.
Yep. It is covered here:
AzureAiStudioServiceTests#testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider
There was a problem hiding this comment.
My mistake, thanks for clarifying.
    assertThat(settings, is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS));
}

public void testFromMap_ReturnsDoSample() {
There was a problem hiding this comment.
What does this test name mean by DoSample?
There was a problem hiding this comment.
I used AzureAiStudioChatCompletionRequestTaskSettingsTests as a reference and forgot to rename the test method. Apologies. Renamed
public void testFromMap_ReturnsDoSample() {
    final var settings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, true)));
    assertThat(settings.returnDocuments(), is(true));
There was a problem hiding this comment.
Should we just compare the settings here to an expected settings object similar to the tests above?
There was a problem hiding this comment.
That does seem better. Done, thank you
    assertThat(requestMap.get(INPUT), is(List.of(input)));
}

public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException {
There was a problem hiding this comment.
Can we have a test for creating the request with the other parameter?
There was a problem hiding this comment.
Sure. Added the test for the return documents parameter. Thanks
Pinging @elastic/ml-core (Team:ML)
public void testUpdatedTaskSettings_WithAllValues() {
    final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
    AzureAiStudioRerankTaskSettings newSettings;
    int retries = 0;
There was a problem hiding this comment.
Instead of running retries, which can require multiple loops, can we just create the newSettings object using randomValueOtherThan ourselves:
AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(
    randomValueOtherThan(initialSettings.returnDocuments(), () -> randomFrom(new Boolean[] { null, randomBoolean() })),
    randomValueOtherThan(initialSettings.topN(), () -> randomFrom(new Integer[] { null, randomNonNegativeInt() }))
);
Adjust the above based on whether we want to allow null values. The same comment applies to the other tests below. If we find we're reusing this, we can create a helper function to do it for us across the various tests.
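As a rough illustration of the helper idea, the sketch below builds settings that are guaranteed to differ from the initial ones in every field, with no retry counter. It is self-contained, so it does not use the real test framework: `randomValueOtherThan` here is a local stand-in written for this sketch, and `SketchRerankTaskSettings` and `SketchSettingsHelper` are hypothetical names. The field order (returnDocuments, then topN) follows the snippet above.

```java
import java.util.Objects;
import java.util.Random;
import java.util.function.Supplier;

// Hypothetical stand-in for the real task settings class.
record SketchRerankTaskSettings(Boolean returnDocuments, Integer topN) {}

class SketchSettingsHelper {

    private static final Random RANDOM = new Random();

    // Local stand-in: keeps drawing from the supplier until the value differs from the input.
    // The supplier must be able to produce at least two distinct values, or this never returns.
    static <T> T randomValueOtherThan(T input, Supplier<T> supplier) {
        T value;
        do {
            value = supplier.get();
        } while (Objects.equals(value, input));
        return value;
    }

    // Returns null roughly one time in three, otherwise a random boolean.
    static Boolean randomNullableBoolean() {
        return RANDOM.nextInt(3) == 0 ? null : RANDOM.nextBoolean();
    }

    // Returns null roughly one time in three, otherwise a non-negative int.
    static Integer randomNullableTopN() {
        return RANDOM.nextInt(3) == 0 ? null : RANDOM.nextInt(Integer.MAX_VALUE);
    }

    static SketchRerankTaskSettings createRandom() {
        return new SketchRerankTaskSettings(randomNullableBoolean(), randomNullableTopN());
    }

    // Every field of the result differs from the corresponding field of the initial settings.
    static SketchRerankTaskSettings createDifferentFrom(SketchRerankTaskSettings initial) {
        return new SketchRerankTaskSettings(
            randomValueOtherThan(initial.returnDocuments(), SketchSettingsHelper::randomNullableBoolean),
            randomValueOtherThan(initial.topN(), SketchSettingsHelper::randomNullableTopN)
        );
    }

    public static void main(String[] args) {
        SketchRerankTaskSettings initial = createRandom();
        System.out.println(initial + " -> " + createDifferentFrom(initial));
    }
}
```

Drop the null branches from the two suppliers if null values should not be allowed.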
As of now, it appears that only the Cohere provider is applicable for reranking.
Azure AI Foundry Models available for standard deployment
Cohere docs
PUT {{base-url}}/_inference/rerank/cohere
{
  "service": "azureaistudio",
  "service_settings": {
    "target": "https://Cohere-rerank-v3-5-samwq.swedencentral.models.ai.azure.com",
    "provider": "COHERE",
    "endpoint_type": "token",
    "api_key": "{{cohere-api-key}}"
  },
  "task_settings": {
    "top_n": 2,
    "return_documents": true
  }
}
POST {{base-url}}/_inference/rerank/cohere
{
  "input": ["Luke", "like", "leia", "chewy", "r2d2", "star", "wars"],
  "query": "star wars main character",
  "top_n": 7,
  "return_documents": true
}