Skip to content

Commit cdb2d7a

Browse files
authored
[9.2] [Inference API] Support chunking settings for sparse embeddings in custom service (#138776) (#145419)
1 parent 175b3aa commit cdb2d7a

3 files changed

Lines changed: 291 additions & 9 deletions

File tree

docs/changelog/138776.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 138776
2+
summary: "[Inference API] Support chunking settings for sparse embeddings in custom\
3+
\ service"
4+
area: Machine Learning
5+
type: bug
6+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public void parseRequestConfig(
108108
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
109109

110110
ChunkingSettings chunkingSettings = null;
111-
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
111+
if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) {
112112
chunkingSettings = ChunkingSettingsBuilder.fromMap(
113113
removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
114114
);
@@ -241,7 +241,7 @@ public CustomModel parsePersistedConfigWithSecrets(
241241
}
242242

243243
private static ChunkingSettings extractPersistentChunkingSettings(Map<String, Object> config, TaskType taskType) {
244-
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
244+
if (TaskType.SPARSE_EMBEDDING.equals(taskType) || TaskType.TEXT_EMBEDDING.equals(taskType)) {
245245
/*
246246
* There's a subtle difference between how the chunking settings are parsed for the request context vs the persistent context.
247247
* For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

Lines changed: 283 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.inference.ChunkInferenceInput;
1717
import org.elasticsearch.inference.ChunkedInference;
1818
import org.elasticsearch.inference.ChunkingSettings;
19+
import org.elasticsearch.inference.ChunkingStrategy;
1920
import org.elasticsearch.inference.InferenceServiceResults;
2021
import org.elasticsearch.inference.InputType;
2122
import org.elasticsearch.inference.Model;
@@ -31,7 +32,9 @@
3132
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
3233
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
3334
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
35+
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsOptions;
3436
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
37+
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
3538
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
3639
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
3740
import org.elasticsearch.xpack.inference.services.AbstractInferenceServiceTests;
@@ -64,6 +67,7 @@
6467
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_TOKEN_PATH;
6568
import static org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser.SPARSE_EMBEDDING_WEIGHT_PATH;
6669
import static org.hamcrest.Matchers.empty;
70+
import static org.hamcrest.Matchers.equalTo;
6771
import static org.hamcrest.Matchers.hasSize;
6872
import static org.hamcrest.Matchers.instanceOf;
6973
import static org.hamcrest.Matchers.is;
@@ -305,7 +309,12 @@ private static CustomModel createInternalEmbeddingModel(
305309
);
306310
}
307311

308-
private static CustomModel createCustomModel(TaskType taskType, CustomResponseParser customResponseParser, String url) {
312+
private static CustomModel createCustomModel(
313+
TaskType taskType,
314+
CustomResponseParser customResponseParser,
315+
String url,
316+
@Nullable ChunkingSettings chunkingSettings
317+
) {
309318
return new CustomModel(
310319
"model_id",
311320
taskType,
@@ -320,7 +329,8 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa
320329
new RateLimitSettings(10_000)
321330
),
322331
new CustomTaskSettings(Map.of("key", "test_value")),
323-
new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray())))
332+
new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))),
333+
chunkingSettings
324334
);
325335
}
326336

@@ -467,7 +477,8 @@ public void testInfer_HandlesRerankRequest_Cohere_Format() throws IOException {
467477
var model = createCustomModel(
468478
TaskType.RERANK,
469479
new RerankResponseParser("$.results[*].relevance_score", "$.results[*].index", "$.results[*].document.text"),
470-
getUrl(webServer)
480+
getUrl(webServer),
481+
null
471482
);
472483

473484
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -538,7 +549,8 @@ public void testInfer_HandlesCompletionRequest_OpenAI_Format() throws IOExceptio
538549
var model = createCustomModel(
539550
TaskType.COMPLETION,
540551
new CompletionResponseParser("$.choices[*].message.content"),
541-
getUrl(webServer)
552+
getUrl(webServer),
553+
null
542554
);
543555

544556
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -603,7 +615,8 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx
603615
"$.result.sparse_embeddings[*].embedding[*].tokenId",
604616
"$.result.sparse_embeddings[*].embedding[*].weight"
605617
),
606-
getUrl(webServer)
618+
getUrl(webServer),
619+
null
607620
);
608621

609622
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -674,7 +687,45 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot
674687
}
675688
}
676689

677-
public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
690+
public void testParseRequestConfig_DoesNotThrow_WhenChunkingSettingsArePresentForSparseEmbeddings() throws IOException {
691+
try (var service = createService(threadPool, clientManager)) {
692+
Map<String, Object> serviceSettingsMap = new HashMap<>(
693+
Map.of(
694+
CustomServiceSettings.URL,
695+
"http://www.abc.com",
696+
CustomServiceSettings.HEADERS,
697+
Map.of("key", "value"),
698+
QueryParameters.QUERY_PARAMETERS,
699+
List.of(List.of("key", "value")),
700+
CustomServiceSettings.REQUEST,
701+
"request body",
702+
CustomServiceSettings.RESPONSE,
703+
new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(TaskType.SPARSE_EMBEDDING)))
704+
)
705+
);
706+
707+
Map<String, Object> chunkingSettingsMap = new HashMap<>();
708+
chunkingSettingsMap.put(ChunkingSettingsOptions.STRATEGY.toString(), "sentence");
709+
chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), 40);
710+
chunkingSettingsMap.put(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), 0);
711+
712+
var config = getRequestConfigMap(serviceSettingsMap, createTaskSettingsMap(), chunkingSettingsMap, createSecretSettingsMap());
713+
var listener = new PlainActionFuture<Model>();
714+
715+
service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, listener);
716+
717+
// Check chunking settings
718+
CustomModel model = (CustomModel) listener.actionGet(TIMEOUT);
719+
ChunkingSettings chunkingSettings = model.getConfigurations().getChunkingSettings();
720+
721+
assertThat(chunkingSettings, instanceOf(SentenceBoundaryChunkingSettings.class));
722+
assertThat(chunkingSettings.getChunkingStrategy(), equalTo(ChunkingStrategy.SENTENCE));
723+
assertThat(chunkingSettings.asMap().get(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString()), equalTo(40));
724+
assertThat(chunkingSettings.asMap().get(ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()), equalTo(0));
725+
}
726+
}
727+
728+
public void testChunkedInfer_DenseEmbeddings_ChunkingSettingsSet() throws IOException {
678729
var model = createInternalEmbeddingModel(
679730
SimilarityMeasure.DOT_PRODUCT,
680731
new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
@@ -761,7 +812,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
761812
}
762813
}
763814

764-
public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
815+
public void testChunkedInfer_DenseEmbeddings_ChunkingSettingsNotSet() throws IOException {
765816
var model = createInternalEmbeddingModel(
766817
new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),
767818
getUrl(webServer)
@@ -824,6 +875,231 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
824875
}
825876
}
826877

878+
public void testChunkedInfer_SparseEmbeddings_ChunkingSettingsSet() throws IOException {
879+
var model = createCustomModel(
880+
TaskType.SPARSE_EMBEDDING,
881+
new SparseEmbeddingResponseParser(
882+
"$.result.sparse_embeddings[*].embedding[*].tokenId",
883+
"$.result.sparse_embeddings[*].embedding[*].weight"
884+
),
885+
getUrl(webServer),
886+
ChunkingSettingsTests.createRandomChunkingSettings()
887+
);
888+
889+
String responseJson = """
890+
{
891+
"request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
892+
"latency": 22,
893+
"usage": {
894+
"token_count": 11
895+
},
896+
"result": {
897+
"sparse_embeddings": [
898+
{
899+
"index": 0,
900+
"embedding": [
901+
{
902+
"tokenId": 6,
903+
"weight": 0.101
904+
},
905+
{
906+
"tokenId": 163040,
907+
"weight": 0.28417
908+
}
909+
]
910+
},
911+
{
912+
"index": 1,
913+
"embedding": [
914+
{
915+
"tokenId": 4,
916+
"weight": 0.201
917+
},
918+
{
919+
"tokenId": 153040,
920+
"weight": 0.24417
921+
}
922+
]
923+
}
924+
]
925+
}
926+
}
927+
""";
928+
929+
try (var service = createService(threadPool, clientManager)) {
930+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
931+
932+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
933+
service.chunkedInfer(
934+
model,
935+
null,
936+
List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")),
937+
new HashMap<>(),
938+
InputType.INTERNAL_INGEST,
939+
InferenceAction.Request.DEFAULT_TIMEOUT,
940+
listener
941+
);
942+
943+
var results = listener.actionGet(TIMEOUT);
944+
assertThat(results, hasSize(2));
945+
946+
// Check first embedding
947+
{
948+
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
949+
var sparseEmbeddingResult = (ChunkedInferenceEmbedding) results.get(0);
950+
assertThat(sparseEmbeddingResult.chunks(), hasSize(1));
951+
assertEquals(new ChunkedInference.TextOffset(0, 1), sparseEmbeddingResult.chunks().get(0).offset());
952+
assertThat(sparseEmbeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(SparseEmbeddingResults.Embedding.class));
953+
assertThat(
954+
((SparseEmbeddingResults.Embedding) sparseEmbeddingResult.chunks().get(0).embedding()),
955+
equalTo(
956+
new SparseEmbeddingResults.Embedding(
957+
List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)),
958+
false
959+
)
960+
)
961+
);
962+
}
963+
964+
// Check second embedding
965+
{
966+
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
967+
var sparseEmbeddingResult = (ChunkedInferenceEmbedding) results.get(1);
968+
assertThat(sparseEmbeddingResult.chunks(), hasSize(1));
969+
assertEquals(new ChunkedInference.TextOffset(0, 2), sparseEmbeddingResult.chunks().get(0).offset());
970+
assertThat(sparseEmbeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(SparseEmbeddingResults.Embedding.class));
971+
assertThat(
972+
((SparseEmbeddingResults.Embedding) sparseEmbeddingResult.chunks().get(0).embedding()),
973+
equalTo(
974+
new SparseEmbeddingResults.Embedding(
975+
List.of(new WeightedToken("4", 0.201f), new WeightedToken("153040", 0.24417f)),
976+
false
977+
)
978+
)
979+
);
980+
}
981+
982+
assertThat(webServer.requests(), hasSize(1));
983+
984+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
985+
assertThat(requestMap.size(), is(1));
986+
assertThat(requestMap.get("input"), is(List.of("a", "bb")));
987+
}
988+
}
989+
990+
public void testChunkedInfer_SparseEmbeddings_ChunkingSettingsNotSet() throws IOException {
991+
var model = createCustomModel(
992+
TaskType.SPARSE_EMBEDDING,
993+
new SparseEmbeddingResponseParser(
994+
"$.result.sparse_embeddings[*].embedding[*].tokenId",
995+
"$.result.sparse_embeddings[*].embedding[*].weight"
996+
),
997+
getUrl(webServer),
998+
null // chunking not explicitly set
999+
);
1000+
1001+
String responseJson = """
1002+
{
1003+
"request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
1004+
"latency": 22,
1005+
"usage": {
1006+
"token_count": 11
1007+
},
1008+
"result": {
1009+
"sparse_embeddings": [
1010+
{
1011+
"index": 0,
1012+
"embedding": [
1013+
{
1014+
"tokenId": 6,
1015+
"weight": 0.101
1016+
},
1017+
{
1018+
"tokenId": 163040,
1019+
"weight": 0.28417
1020+
}
1021+
]
1022+
},
1023+
{
1024+
"index": 1,
1025+
"embedding": [
1026+
{
1027+
"tokenId": 4,
1028+
"weight": 0.201
1029+
},
1030+
{
1031+
"tokenId": 153040,
1032+
"weight": 0.24417
1033+
}
1034+
]
1035+
}
1036+
]
1037+
}
1038+
}
1039+
""";
1040+
1041+
try (var service = createService(threadPool, clientManager)) {
1042+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
1043+
1044+
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
1045+
service.chunkedInfer(
1046+
model,
1047+
null,
1048+
List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")),
1049+
new HashMap<>(),
1050+
InputType.INTERNAL_INGEST,
1051+
InferenceAction.Request.DEFAULT_TIMEOUT,
1052+
listener
1053+
);
1054+
1055+
var results = listener.actionGet(TIMEOUT);
1056+
assertThat(results, hasSize(2));
1057+
1058+
// Check first embedding
1059+
{
1060+
assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
1061+
var sparseEmbeddingResult = (ChunkedInferenceEmbedding) results.get(0);
1062+
assertThat(sparseEmbeddingResult.chunks(), hasSize(1));
1063+
assertEquals(new ChunkedInference.TextOffset(0, 1), sparseEmbeddingResult.chunks().get(0).offset());
1064+
assertThat(sparseEmbeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(SparseEmbeddingResults.Embedding.class));
1065+
assertThat(
1066+
((SparseEmbeddingResults.Embedding) sparseEmbeddingResult.chunks().get(0).embedding()),
1067+
equalTo(
1068+
new SparseEmbeddingResults.Embedding(
1069+
List.of(new WeightedToken("6", 0.101f), new WeightedToken("163040", 0.28417f)),
1070+
false
1071+
)
1072+
)
1073+
);
1074+
}
1075+
1076+
// Check second embedding
1077+
{
1078+
assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedInferenceEmbedding.class));
1079+
var sparseEmbeddingResult = (ChunkedInferenceEmbedding) results.get(1);
1080+
assertThat(sparseEmbeddingResult.chunks(), hasSize(1));
1081+
assertEquals(new ChunkedInference.TextOffset(0, 2), sparseEmbeddingResult.chunks().get(0).offset());
1082+
assertThat(sparseEmbeddingResult.chunks().get(0).embedding(), Matchers.instanceOf(SparseEmbeddingResults.Embedding.class));
1083+
assertThat(
1084+
((SparseEmbeddingResults.Embedding) sparseEmbeddingResult.chunks().get(0).embedding()),
1085+
equalTo(
1086+
new SparseEmbeddingResults.Embedding(
1087+
List.of(new WeightedToken("4", 0.201f), new WeightedToken("153040", 0.24417f)),
1088+
false
1089+
)
1090+
)
1091+
);
1092+
}
1093+
1094+
assertThat(webServer.requests(), hasSize(1));
1095+
1096+
// Check request
1097+
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
1098+
assertThat(requestMap.size(), is(1));
1099+
assertThat(requestMap.get("input"), is(List.of("a", "bb")));
1100+
}
1101+
}
1102+
8271103
public void testChunkedInfer_noInputs() throws IOException {
8281104
var model = createInternalEmbeddingModel(
8291105
new TextEmbeddingResponseParser("$.data[*].embedding", CustomServiceEmbeddingType.FLOAT),

0 commit comments

Comments (0)