Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/140331.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 140331
summary: "[Inference API] Include rerank in supported tasks for IBM watsonx integration"
area: Inference
type: bug
issues:
- 140328
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"options": [
"text_embedding",
"chat_completion",
"completion"
"completion",
"rerank"
]
},
"watsonx_inference_id": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public int count() {
protected void writeLenAndValues(BytesStreamOutput out) throws IOException {
// sort the ArrayList variant of the collection prior to serializing it into a binary array
if (values instanceof ArrayList<BytesRef> list) {
list.sort(Comparator.naturalOrder());
list.sort(Comparator.naturalOrder());
}

for (BytesRef value : values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
"openshift_ai",
"test_reranking_service",
"voyageai",
"watsonxai",
"hugging_face",
"amazon_sagemaker",
"elastic"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
Expand Down Expand Up @@ -68,21 +69,28 @@
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID;

public class IbmWatsonxService extends SenderService {

public static final String NAME = "watsonxai";
public class IbmWatsonxService extends SenderService implements RerankingInferenceService {

private static final String SERVICE_NAME = "IBM watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
TaskType.CHAT_COMPLETION,
TaskType.RERANK
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
"IBM watsonx chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

public static final String NAME = "watsonxai";

// IBM watsonx has a single rerank model with a token limit of 512
// (see https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx#reranker-overview)
// Using 1 token = 0.75 words as a rough estimate, we get 384 words
// allowing for some headroom, we set the window size below 384 words
public static final int RERANK_WINDOW_SIZE = 350;

public IbmWatsonxService(
HttpRequestSender.Factory factory,
ServiceComponents serviceComponents,
Expand Down Expand Up @@ -362,6 +370,11 @@ protected IbmWatsonxActionCreator getActionCreator(Sender sender, ServiceCompone
return new IbmWatsonxActionCreator(getSender(), getServiceComponents());
}

/**
 * Returns the rerank context window size (in words) to use for the given model.
 * <p>
 * IBM watsonx currently exposes a single rerank model (512-token limit), so the
 * same {@code RERANK_WINDOW_SIZE} is returned regardless of {@code modelId}; see
 * the derivation comment on that constant.
 *
 * @param modelId the inference model id (currently ignored — single rerank model)
 * @return the rerank window size in words
 */
@Override
public int rerankerWindowSize(String modelId) {
return RERANK_WINDOW_SIZE;
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.RerankingInferenceService;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockResponse;
Expand Down Expand Up @@ -82,6 +83,7 @@
import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService.RERANK_WINDOW_SIZE;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
Expand Down Expand Up @@ -943,7 +945,7 @@ public void testGetConfiguration() throws Exception {
{
"service": "watsonxai",
"name": "IBM watsonx",
"task_types": ["text_embedding", "completion", "chat_completion"],
"task_types": ["text_embedding", "rerank", "completion", "chat_completion"],
"configurations": {
"project_id": {
"description": "",
Expand All @@ -952,7 +954,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"model_id": {
"description": "The name of the model to use for the inference task.",
Expand All @@ -961,7 +963,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"api_version": {
"description": "The IBM watsonx API version ID to use.",
Expand All @@ -970,7 +972,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
},
"max_input_tokens": {
"description": "Allows you to specify the maximum number of tokens per input.",
Expand All @@ -988,7 +990,7 @@ public void testGetConfiguration() throws Exception {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
"supported_task_types": ["text_embedding", "rerank", "completion", "chat_completion"]
}
}
}
Expand Down Expand Up @@ -1050,6 +1052,11 @@ public InferenceService createInferenceService() {
return createIbmWatsonxService();
}

/**
 * Asserts that the watsonx service advertises the expected rerank window size;
 * the model id is irrelevant because the service has a single rerank model.
 */
@Override
protected void assertRerankerWindowSize(RerankingInferenceService rerankingInferenceService) {
final int reportedWindowSize = rerankingInferenceService.rerankerWindowSize("any model");
assertThat(reportedWindowSize, is(RERANK_WINDOW_SIZE));
}

private static class IbmWatsonxServiceWithoutAuth extends IbmWatsonxService {
IbmWatsonxServiceWithoutAuth(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents, mockClusterServiceEmpty());
Expand Down