Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/138632.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138632
summary: Correctly handle empty inputs in `chunkedInfer()`
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ public void chunkedInfer(
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
if (supportsChunkedInfer()) {
if (input.isEmpty()) {
chunkedInferListener.onResponse(List.of());
} else {
// a non-null query is not supported and is dropped by all providers
doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
}
} else {
chunkedInferListener.onFailure(
new UnsupportedOperationException(Strings.format("%s service does not support chunked inference", name()))
);
}
}).addListener(listener);
}

Expand Down Expand Up @@ -183,6 +192,10 @@ protected abstract void doChunkedInfer(
ActionListener<List<ChunkedInference>> listener
);

/**
 * Whether this service supports {@code chunkedInfer()}. Defaults to {@code true}; services
 * that do not implement chunked inference override this to return {@code false}, in which
 * case {@code chunkedInfer()} fails the listener with an {@link UnsupportedOperationException}
 * before ever reaching {@code doChunkedInfer()}.
 */
protected boolean supportsChunkedInfer() {
return true;
}

public void start(Model model, ActionListener<Boolean> listener) {
SubscribableListener.newForked(this::init)
.<Boolean>andThen((doStartListener) -> doStart(model, doStartListener))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
throw new UnsupportedOperationException("AI21 service does not support chunked inference");
}

@Override
protected boolean supportsChunkedInfer() {
// AI21 does not support chunked inference ("AI21 service does not support chunked
// inference"); the base class now rejects the call before doChunkedInfer() is reached.
return false;
}

@Override
public InferenceServiceConfiguration getConfiguration() {
return Configuration.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
throw new UnsupportedOperationException("Anthropic service does not support chunked inference");
}

@Override
protected boolean supportsChunkedInfer() {
// Anthropic does not support chunked inference ("Anthropic service does not support
// chunked inference"); the base class rejects the call before doChunkedInfer() is reached.
return false;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_15_0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
listener.onFailure(new ElasticsearchStatusException("Chunked inference is not supported for rerank task", RestStatus.BAD_REQUEST));
}

@Override
protected boolean supportsChunkedInfer() {
// This service only handles the rerank task, so chunked inference is rejected.
// NOTE(review): doChunkedInfer() previously failed with an ElasticsearchStatusException
// (BAD_REQUEST); the base-class rejection surfaces as an UnsupportedOperationException
// instead — confirm the change in reported status code is intended.
return false;
}

@Override
protected void doUnifiedCompletionInfer(
Model model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,15 @@ protected void doChunkedInfer(
TimeValue timeout,
ActionListener<List<ChunkedInference>> listener
) {
// Should never be called
listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME)));
}

@Override
protected boolean supportsChunkedInfer() {
// This service only supports unified completion; chunked inference is rejected
// by the base class before doChunkedInfer() is reached.
return false;
}

@Override
public String name() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ public void chunkedInfer(
listener.onFailure(createInvalidModelException(model));
return;
}
// An empty input must short-circuit: answer with an empty result AND return.
// Without the return, execution falls through into the request-building code below,
// performing a pointless inference call and completing the listener a second time.
if (input.isEmpty()) {
    listener.onResponse(List.of());
    return;
}
try {
var sageMakerModel = ((SageMakerModel) model).override(taskSettings);
var batchedRequests = new EmbeddingRequestChunker<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -491,6 +492,27 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx
testChunkedInfer(TaskType.SPARSE_EMBEDDING, null);
}

public void testChunkedInfer_noInputs() throws IOException {
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

    PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
    try (var service = new AlibabaCloudSearchService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
        var model = createModelForTaskType(randomFrom(TaskType.SPARSE_EMBEDDING, TaskType.TEXT_EMBEDDING), null);

        service.chunkedInfer(
            model,
            null,
            List.of(),
            new HashMap<>(),
            InputTypeTests.randomWithIngestAndSearch(),
            InferenceAction.Request.DEFAULT_TIMEOUT,
            listener
        );

        // Resolve the future while the service is still open (the original asserted after the
        // try-with-resources had closed the service), matching the sibling chunked-infer tests.
        assertThat(listener.actionGet(TIMEOUT), empty());
    }
}

private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException {
var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -1323,6 +1324,50 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

public void testChunkedInfer_noInputs() throws IOException {
    // An empty input list should resolve to an empty result without contacting Bedrock.
    var embeddingsModel = AmazonBedrockEmbeddingsModelTests.createModel(
        "id",
        "region",
        "model",
        AmazonBedrockProvider.AMAZONTITAN,
        null,
        "access",
        "secret"
    );

    var mockSender = createMockSender();
    var senderFactory = mock(HttpRequestSender.Factory.class);
    when(senderFactory.createSender()).thenReturn(mockSender);

    var bedrockRequestSenderFactory = new AmazonBedrockMockRequestSender.Factory(
        ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY),
        mockClusterServiceEmpty()
    );

    try (
        var service = new AmazonBedrockService(
            senderFactory,
            bedrockRequestSenderFactory,
            createWithEmptySettings(threadPool),
            mockClusterServiceEmpty()
        )
    ) {
        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(
            embeddingsModel,
            null,
            List.of(),
            new HashMap<>(),
            InputType.INTERNAL_INGEST,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            future
        );

        assertThat(future.actionGet(TIMEOUT), empty());
    }
}

private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOException {
var sender = createMockSender();
var factory = mock(HttpRequestSender.Factory.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import static org.elasticsearch.xpack.inference.services.azureaistudio.request.AzureAiStudioRequestFields.API_KEY_HEADER;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -1294,6 +1295,27 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

public void testChunkedInfer_noInputs() throws IOException {
    // Chunked inference over zero inputs should complete with an empty list
    // and must not issue any HTTP requests.
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
    var embeddingsModel = AzureAiStudioEmbeddingsModelTests.createModel(
        "id",
        getUrl(webServer),
        AzureAiStudioProvider.OPENAI,
        AzureAiStudioEndpointType.TOKEN,
        "apikey"
    );

    try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(
            embeddingsModel,
            null,
            List.of(),
            new HashMap<>(),
            InputType.INGEST,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            future
        );

        assertThat(future.actionGet(TIMEOUT), empty());
        assertThat(webServer.requests(), empty());
    }
}

private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -953,6 +954,32 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException, URISyn
testChunkedInfer(model);
}

public void testChunkedInfer_noInputs() throws IOException, URISyntaxException {
    // Chunked inference over zero inputs should complete with an empty list
    // and must not issue any HTTP requests.
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
    var embeddingsModel = AzureOpenAiEmbeddingsModelTests.createModel(
        "resource",
        "deployment",
        "apiversion",
        "user",
        null,
        "apikey",
        null,
        "id"
    );

    try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
        embeddingsModel.setUri(new URI(getUrl(webServer)));

        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(
            embeddingsModel,
            null,
            List.of(),
            new HashMap<>(),
            InputType.INTERNAL_INGEST,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            future
        );

        assertThat(future.actionGet(TIMEOUT), empty());
        assertThat(webServer.requests(), empty());
    }
}

private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOException, URISyntaxException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -1377,6 +1378,28 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
testChunkedInfer(model);
}

public void testChunkedInfer_noInputs() throws IOException {
    // Chunked inference over zero inputs should complete with an empty list
    // and must not issue any HTTP requests.
    var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
    var embeddingsModel = CohereEmbeddingsModelTests.createModel(getUrl(webServer), "secret", 1024, "model", null);

    try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool), mockClusterServiceEmpty())) {
        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(
            embeddingsModel,
            null,
            List.of(),
            new HashMap<>(),
            InputType.UNSPECIFIED,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            future
        );

        assertThat(future.actionGet(TIMEOUT), empty());
        assertThat(webServer.requests(), empty());
    }
}

private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException {
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import static org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser.RERANK_PARSER_SCORE;
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH;
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -822,4 +823,29 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
assertThat(requestMap.get("input"), is(List.of("a")));
}
}

public void testChunkedInfer_noInputs() throws IOException {
    // Chunked inference over zero inputs should complete with an empty list
    // and must not issue any HTTP requests.
    var embeddingModel = createInternalEmbeddingModel(
        new DenseEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
        getUrl(webServer)
    );

    try (var service = createService(threadPool, clientManager)) {
        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(
            embeddingModel,
            null,
            List.of(),
            new HashMap<>(),
            InputType.INTERNAL_INGEST,
            InferenceAction.Request.DEFAULT_TIMEOUT,
            future
        );

        assertThat(future.actionGet(TIMEOUT), empty());
        assertThat(webServer.requests(), empty());
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
Expand Down Expand Up @@ -349,12 +351,21 @@ public void testUnifiedCompletionMalformedError() throws Exception {
}}""");
}

public void testDoChunkedInferAlwaysFails() throws IOException {
public void testChunkedInferFails() throws IOException {
try (var service = createService()) {
service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> {
assertThat(e, isA(UnsupportedOperationException.class));
assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion"));
}));
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
service.chunkedInfer(mock(), null, List.of(new ChunkInferenceInput("a")), Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener);
var exception = expectThrows(UnsupportedOperationException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.getMessage(), is("deepseek service does not support chunked inference"));
}
}

public void testChunkedInferFails_noInputs() throws IOException {
    // Even with zero inputs, a service that does not support chunking must fail the call.
    try (var service = createService()) {
        var future = new PlainActionFuture<List<ChunkedInference>>();
        service.chunkedInfer(mock(), null, List.of(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, future);

        var thrown = expectThrows(UnsupportedOperationException.class, () -> future.actionGet(TIMEOUT));
        assertThat(thrown.getMessage(), is("deepseek service does not support chunked inference"));
    }
}

Expand Down
Loading