Skip to content

Add Azure AI Rerank support#129848

Merged
dan-rubinstein merged 15 commits intoelastic:mainfrom
Evgenii-Kazannik:Add-Azure-AI-Foundry-Rerank-support
Jul 17, 2025
Merged

Add Azure AI Rerank support#129848
dan-rubinstein merged 15 commits intoelastic:mainfrom
Evgenii-Kazannik:Add-Azure-AI-Foundry-Rerank-support

Conversation

@Evgenii-Kazannik
Copy link
Contributor

@Evgenii-Kazannik Evgenii-Kazannik commented Jun 23, 2025

As of now, it appearsthat only the Cohere provider is applicable for reranking
Azure AI Foundry Models available for standard deployment
Cohere docs

PUT {{base-url}}/_inference/rerank/cohere
{
"service": "azureaistudio",
"service_settings": {
"target": "https://Cohere-rerank-v3-5-samwq.swedencentral.models.ai.azure.com",
"provider": "COHERE",
"endpoint_type": "token",
"api_key": "{{cohere-api-key}}"
},
"task_settings": {
"top_n": 2,
"return_documents": true
}
}


POST {{base-url}}/_inference/rerank/cohere
{

"input": ["Luke", "like", "leia", "chewy","r2d2", "star", "wars"],
"query": "star wars main character",
"top_n": 7,
"return_documents": true
}

@elasticsearchmachine elasticsearchmachine added v9.1.0 external-contributor Pull request authored by a developer outside the Elasticsearch team labels Jun 23, 2025
@Evgenii-Kazannik Evgenii-Kazannik force-pushed the Add-Azure-AI-Foundry-Rerank-support branch from 3c5f0eb to 9ba481d Compare June 25, 2025 09:02
@Evgenii-Kazannik Evgenii-Kazannik marked this pull request as ready for review June 25, 2025 09:16
@elasticsearchmachine elasticsearchmachine added needs:triage Requires assignment of a team area label v9.2.0 and removed v9.1.0 labels Jun 25, 2025
# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
return completionModel;
}

if (taskType == TaskType.RERANK) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify the logic in this method with a switch statement to make the model and then a single call to checkProviderAndEndpointTypeForTask?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will require casting for the service settings in the checkProviderAndEndpointTypeForTask so it won't be one single call unfortunately

Copy link
Member

@dan-rubinstein dan-rubinstein Jul 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not cast it to AzureAiStudioServiceSettings? Something like:

        AzureAiStudioModel model;
        switch(taskType) {
            case TEXT_EMBEDDING -> {
                model = new AzureAiStudioEmbeddingsModel(
                    inferenceEntityId,
                    taskType,
                    NAME,
                    serviceSettings,
                    taskSettings,
                    chunkingSettings,
                    secretSettings,
                    context
                );
            }
            case COMPLETION -> {
                ...
            }
            default -> throw new ElasticsearchStatusException(
                failureMessage,
                RestStatus.BAD_REQUEST
            );
        }
        AzureAiStudioServiceSettings azureAiStudioServiceSettings = (AzureAiStudioServiceSettings) model.getServiceSettings();
        checkProviderAndEndpointTypeForTask(
            taskType,
            azureAiStudioServiceSettings.provider(),
            azureAiStudioServiceSettings.endpointType()
        );
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well. My bad. Refactored.
Thank you

// these providers have chat completion inference (all providers at the moment)
public static final List<AzureAiStudioProvider> chatCompletionProviders = List.of(AzureAiStudioProvider.values());

// these providers allow token ("pay as you go") embeddings endpoints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what this comment means? Why do they have to support token ("pay as you go") billing to be valid for rerank? Why are they embeddings endpoints instead of rerank endpoints? Can this just say // these providers have rerank inference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cohere Rerank billing is based on search_units. So effectively we pay for what we use. That's why I think ""pay as you go" is applicable. All in all this comment is not relevant since the constant with it has been deleted as suggested in other comment. Thanks

public static final List<AzureAiStudioProvider> rerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow token ("pay as you go") embeddings endpoints
public static final List<AzureAiStudioProvider> tokenRerankProviders = List.of(AzureAiStudioProvider.COHERE);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be split from rerankProviders? Do we expect there to be rerankProviders that don't offer token rerank capabilities?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we use only Cohere for reranking so that's wiser not to split. Thank you. Done

public static final List<AzureAiStudioProvider> tokenRerankProviders = List.of(AzureAiStudioProvider.COHERE);

// these providers allow realtime rerank endpoints (none at the moment)
public static final List<AzureAiStudioProvider> realtimeRerankProviders = List.of();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we suspect these will be added at some point? If not do we need to have this in code until any are added?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be left.
It's used to check if the realtime type is valid for a provider allowing to show more descriptive error
e.g. we try to set realtime as an endpoint_type while creating an inference endpoint:
"The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks for clarifying!

}

private AzureAiStudioRerankRequestEntity createRequestEntity() {
var taskSettings = rerankModel.getTaskSettings();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Do we need this stored in a separate variable if it's only referenced once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. That's a bit cleaner not to introduce it, corrected. Thx


private static AzureAiStudioRerankTaskSettings createRandom() {
return new AzureAiStudioRerankTaskSettings(
randomFrom(randomFrom(new Boolean[] { null, randomBoolean() })),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the second randomFrom here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well. That's overdo. I removed unnecessary randomFrom. Thank you

? tokenEmbeddingsProviders.contains(provider)
: realtimeEmbeddingsProviders.contains(provider);
}
case RERANK -> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this logic covered in testing anywhere right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. It is covered here:

AzureAiStudioServiceTests#testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My mistake, thanks for clarifying.

assertThat(settings, is(AzureAiStudioRerankRequestTaskSettings.EMPTY_SETTINGS));
}

public void testFromMap_ReturnsDoSample() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this test name mean by DoSample?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I referenced AzureAiStudioChatCompletionRequestTaskSettingsTests and missed to rename the test method. Apologies. Renamed


public void testFromMap_ReturnsDoSample() {
final var settings = AzureAiStudioRerankRequestTaskSettings.fromMap(new HashMap<>(Map.of(RETURN_DOCUMENTS_FIELD, true)));
assertThat(settings.returnDocuments(), is(true));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just compare the settings here to an expected settings object similar to the tests above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems better. I did it. Thank you

assertThat(requestMap.get(INPUT), is(List.of(input)));
}

public void testCreateRequest_WithCohereProviderTokenEndpoint_WithTopNParam() throws IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a test for creating the request with the other parameter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Added the test for the return documents parameter. Thanks

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
@PeteGillinElastic PeteGillinElastic added :ml Machine learning and removed needs:triage Requires assignment of a team area label labels Jul 11, 2025
@elasticsearchmachine elasticsearchmachine added the Team:ML Meta label for the ML team label Jul 11, 2025
@elasticsearchmachine
Copy link
Collaborator

Pinging @elastic/ml-core (Team:ML)

public void testUpdatedTaskSettings_WithAllValues() {
final AzureAiStudioRerankTaskSettings initialSettings = createRandom();
AzureAiStudioRerankTaskSettings newSettings;
int retries = 0;
Copy link
Member

@dan-rubinstein dan-rubinstein Jul 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of running retries which can require multiple loops can we just create the newSettings objects using randomValueOtherThan ourselves:

AzureAiStudioRerankTaskSettings newSettings = new AzureAiStudioRerankTaskSettings(randomValueOtherThan(intialSettings.returnDocuments(), () ->             randomFrom(new Boolean[] { null, randomBoolean() }), randomValueOtherThan(initialSettings.topN(), () ->
            randomFrom(new Integer[] { null, randomNonNegativeInt() })));

Adjust the above based on whether we want to allow null values. The same comment applies to the other tests below. If we find we're reusing this we can create a helper function to do this for us across the various tests.

Copy link
Member

@dan-rubinstein dan-rubinstein left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
@dan-rubinstein dan-rubinstein merged commit d06b0c8 into elastic:main Jul 17, 2025
35 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

>enhancement external-contributor Pull request authored by a developer outside the Elasticsearch team :ml Machine learning Team:ML Meta label for the ML team v9.2.0

4 participants