Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
78ab1da
Add Ibm Granite Completion and Chat Completion support
Evgenii-Kazannik May 28, 2025
f92f348
Apply suggestions
Evgenii-Kazannik Jun 10, 2025
510e3c5
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 13, 2025
d6d19be
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 17, 2025
a6eaec6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 23, 2025
9faf6f6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jun 30, 2025
b23bdfb
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
136416d
remove ibm watsonx transport version constant
Evgenii-Kazannik Jul 2, 2025
ff6ccf5
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
80537a4
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
b44bab6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
b1a76c3
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
1bf81ed
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
e70752f
Merge remote-tracking branch 'origin/Add-IBM-Granite-support-for-comp…
Evgenii-Kazannik Jul 2, 2025
b219e72
update transport version
Evgenii-Kazannik Jul 2, 2025
c950380
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
bf882a0
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
f9b086f
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
8e08b9e
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
08ab2f6
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
4ed865c
Merge branch 'main' into Add-IBM-Granite-support-for-completion-and-c…
Evgenii-Kazannik Jul 2, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129146.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129146
summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ static TransportVersion def(int id) {
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);

public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"completion_test_service",
"hugging_face",
"amazon_sagemaker",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);
Expand All @@ -169,7 +170,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {
"hugging_face",
"amazon_sagemaker",
"googlevertexai",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
Expand Down Expand Up @@ -469,6 +470,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxChatCompletionServiceSettings.NAME,
IbmWatsonxChatCompletionServiceSettings::new
)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public DefaultSecretSettings getSecretSettings() {

/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param visitor Interface for creating {@link ExecutableAction} instances for Cohere models.
* @param taskSettings Settings in the request to override the model's defaults
* @return the rerank action
*/
@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

import java.util.Locale;

/**
 * Streaming-aware response handler for IBM watsonx chat completion endpoints.
 * Builds on the OpenAI unified handler (watsonx exposes an OpenAI-compatible
 * chat API) but parses error bodies with the watsonx-specific
 * {@link IbmWatsonxErrorResponseEntity} schema.
 * NOTE(review): the class name is missing the trailing "x" of "Watsonx"
 * (compare {@code IbmWatsonxErrorResponseEntity}); renaming would touch all
 * call sites, so it is left as-is here — confirm whether a rename is wanted.
 */
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

    // Error "type" reported when the response body parsed as a watsonx error.
    private static final String WATSONX_ERROR = "watsonx_error";

    public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
    }

    /**
     * Converts a failed HTTP result into an exception. Streaming requests are
     * mapped to {@link UnifiedChatCompletionException} carrying the response's
     * REST status; non-streaming requests fall back to the parent's format.
     */
    @Override
    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
        assert request.isStreaming() : "Only streaming requests support this format";
        if (request.isStreaming() == false) {
            // Defensive fallback: only streaming requests are expected here (see assert above).
            return super.buildError(message, request, result, errorResponse);
        }
        var statusCode = result.response().getStatusLine().getStatusCode();
        var restStatus = toRestStatus(statusCode);
        var formattedMessage = errorMessage(message, request, result, errorResponse, statusCode);
        var statusName = restStatus.name().toLowerCase(Locale.ROOT);
        // Use the watsonx-specific error type when the body matched the watsonx
        // schema; otherwise derive a generic type from the parsed error response.
        if (errorResponse instanceof IbmWatsonxErrorResponseEntity) {
            return new UnifiedChatCompletionException(restStatus, formattedMessage, WATSONX_ERROR, statusName);
        }
        return new UnifiedChatCompletionException(restStatus, formattedMessage, createErrorType(errorResponse), statusName);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;

/**
 * Response handler for IBM watsonx completion (non-chat) responses.
 * Reuses the OpenAI-compatible completion handling but parses error bodies
 * with the watsonx-specific {@link IbmWatsonxErrorResponseEntity} schema.
 */
public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {

    /**
     * Constructs an IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
     *
     * @param requestType The type of request being handled (e.g., "IBM watsonx completions").
     * @param parseFunction The function to parse the response.
     */
    public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
private static final ResponseHandler HANDLER = createEmbeddingsHandler();

private static ResponseHandler createEmbeddingsHandler() {
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
}

private final IbmWatsonxEmbeddingsModel model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@

package org.elasticsearch.xpack.inference.services.ibmwatsonx;

import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Map;
import java.util.Objects;

public abstract class IbmWatsonxModel extends Model {
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify why this needs to be a RateLimitGroupingModel?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type needs to be used in GenericRequestManager
which I believe is also going to handle the requests for other tasks in the future


private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;

Expand Down Expand Up @@ -49,4 +50,14 @@ public IbmWatsonxModel(IbmWatsonxModel model, TaskSettings taskSettings) {
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

    /**
     * Grouping key for rate limiting, derived from this model's rate limit
     * service settings. Presumably models with equal settings share a rate
     * limit bucket — confirm against {@code RateLimitGroupingModel} usage.
     */
    @Override
    public int rateLimitGroupingHash() {
        return Objects.hash(this.rateLimitServiceSettings);
    }

    /**
     * Exposes the rate limit settings carried by this model's
     * {@code IbmWatsonxRateLimitServiceSettings}.
     */
    @Override
    public RateLimitSettings rateLimitSettings() {
        return this.rateLimitServiceSettings().rateLimitSettings();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {

private static ResponseHandler createIbmWatsonxResponseHandler() {
return new IbmWatsonxResponseHandler(
"ibm watsonx rerank",
"IBM watsonx rerank",
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
Expand All @@ -40,14 +43,18 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
Expand All @@ -56,7 +63,6 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
Expand All @@ -66,8 +72,16 @@ public class IbmWatsonxService extends SenderService {

public static final String NAME = "watsonxai";

private static final String SERVICE_NAME = "IBM Watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
private static final String SERVICE_NAME = "IBM watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
"IBM watsonx chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
Expand Down Expand Up @@ -148,6 +162,14 @@ private static IbmWatsonxModel createModel(
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down Expand Up @@ -236,6 +258,11 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_16_0;
}

    /**
     * Only COMPLETION and CHAT_COMPLETION results can be streamed by this
     * service; other supported task types (e.g. text embedding) are not listed.
     */
    @Override
    public Set<TaskType> supportedStreamingTasks() {
        return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
    }

@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
Expand Down Expand Up @@ -291,7 +318,24 @@ protected void doUnifiedCompletionInfer(
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof IbmWatsonxChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}

IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
var action = new SenderExecutableAction(getSender(), manager, errorMessage);

action.execute(inputs, timeout, listener);
}

@Override
Expand Down Expand Up @@ -331,7 +375,7 @@ public static InferenceServiceConfiguration get() {

configurationMap.put(
API_VERSION,
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.")
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM watsonx API version ID to use.")
.setLabel("API Version")
.setRequired(true)
.setSensitive(false)
Expand Down
Loading
Loading