1616import org .elasticsearch .inference .ChunkInferenceInput ;
1717import org .elasticsearch .inference .ChunkedInference ;
1818import org .elasticsearch .inference .ChunkingSettings ;
19+ import org .elasticsearch .inference .ChunkingStrategy ;
1920import org .elasticsearch .inference .InferenceServiceResults ;
2021import org .elasticsearch .inference .InputType ;
2122import org .elasticsearch .inference .Model ;
3132import org .elasticsearch .xpack .core .inference .results .RankedDocsResults ;
3233import org .elasticsearch .xpack .core .inference .results .SparseEmbeddingResults ;
3334import org .elasticsearch .xpack .core .inference .results .TextEmbeddingFloatResults ;
35+ import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsOptions ;
3436import org .elasticsearch .xpack .inference .chunking .ChunkingSettingsTests ;
37+ import org .elasticsearch .xpack .inference .chunking .SentenceBoundaryChunkingSettings ;
3538import org .elasticsearch .xpack .inference .external .http .HttpClientManager ;
3639import org .elasticsearch .xpack .inference .external .http .sender .HttpRequestSenderTests ;
3740import org .elasticsearch .xpack .inference .services .AbstractInferenceServiceTests ;
6467import static org .elasticsearch .xpack .inference .services .custom .response .SparseEmbeddingResponseParser .SPARSE_EMBEDDING_TOKEN_PATH ;
6568import static org .elasticsearch .xpack .inference .services .custom .response .SparseEmbeddingResponseParser .SPARSE_EMBEDDING_WEIGHT_PATH ;
6669import static org .hamcrest .Matchers .empty ;
70+ import static org .hamcrest .Matchers .equalTo ;
6771import static org .hamcrest .Matchers .hasSize ;
6872import static org .hamcrest .Matchers .instanceOf ;
6973import static org .hamcrest .Matchers .is ;
@@ -305,7 +309,12 @@ private static CustomModel createInternalEmbeddingModel(
305309 );
306310 }
307311
// Builds a CustomModel test fixture for the given task type / response parser / URL.
// This change adds an @Nullable ChunkingSettings parameter which is forwarded to the
// CustomModel constructor; callers that do not care about chunking pass null
// (see the updated call sites in the rerank/completion/sparse-embedding tests below).
308- private static CustomModel createCustomModel (TaskType taskType , CustomResponseParser customResponseParser , String url ) {
312+ private static CustomModel createCustomModel (
313+ TaskType taskType ,
314+ CustomResponseParser customResponseParser ,
315+ String url ,
316+ @ Nullable ChunkingSettings chunkingSettings
317+ ) {
309318 return new CustomModel (
310319 "model_id" ,
311320 taskType ,
@@ -320,7 +329,8 @@ private static CustomModel createCustomModel(TaskType taskType, CustomResponsePa
320329 new RateLimitSettings (10_000 )
321330 ),
322331 new CustomTaskSettings (Map .of ("key" , "test_value" )),
// The secret-settings argument is unchanged; only the trailing chunkingSettings
// constructor argument is new.
323- new CustomSecretSettings (Map .of ("test_key" , new SecureString ("test_value" .toCharArray ())))
332+ new CustomSecretSettings (Map .of ("test_key" , new SecureString ("test_value" .toCharArray ()))),
333+ chunkingSettings
324334 );
325335 }
326336
@@ -467,7 +477,8 @@ public void testInfer_HandlesRerankRequest_Cohere_Format() throws IOException {
467477 var model = createCustomModel (
468478 TaskType .RERANK ,
469479 new RerankResponseParser ("$.results[*].relevance_score" , "$.results[*].index" , "$.results[*].document.text" ),
470- getUrl (webServer )
480+ getUrl (webServer ),
481+ null
471482 );
472483
473484 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
@@ -538,7 +549,8 @@ public void testInfer_HandlesCompletionRequest_OpenAI_Format() throws IOExceptio
538549 var model = createCustomModel (
539550 TaskType .COMPLETION ,
540551 new CompletionResponseParser ("$.choices[*].message.content" ),
541- getUrl (webServer )
552+ getUrl (webServer ),
553+ null
542554 );
543555
544556 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
@@ -603,7 +615,8 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx
603615 "$.result.sparse_embeddings[*].embedding[*].tokenId" ,
604616 "$.result.sparse_embeddings[*].embedding[*].weight"
605617 ),
606- getUrl (webServer )
618+ getUrl (webServer ),
619+ null
607620 );
608621
609622 PlainActionFuture <InferenceServiceResults > listener = new PlainActionFuture <>();
@@ -674,7 +687,45 @@ public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNot
674687 }
675688 }
676689
677- public void testChunkedInfer_ChunkingSettingsSet () throws IOException {
// Verifies that parseRequestConfig accepts an explicit chunking_settings map for a
// SPARSE_EMBEDDING custom service, and that the parsed model exposes those settings
// as a SentenceBoundaryChunkingSettings with the strategy/size/overlap values supplied.
690+ public void testParseRequestConfig_DoesNotThrow_WhenChunkingSettingsArePresentForSparseEmbeddings () throws IOException {
691+ try (var service = createService (threadPool , clientManager )) {
// Minimal valid service settings: url, headers, query parameters, request body
// template, and a sparse-embedding response parser definition.
692+ Map <String , Object > serviceSettingsMap = new HashMap <>(
693+ Map .of (
694+ CustomServiceSettings .URL ,
695+ "http://www.abc.com" ,
696+ CustomServiceSettings .HEADERS ,
697+ Map .of ("key" , "value" ),
698+ QueryParameters .QUERY_PARAMETERS ,
699+ List .of (List .of ("key" , "value" )),
700+ CustomServiceSettings .REQUEST ,
701+ "request body" ,
702+ CustomServiceSettings .RESPONSE ,
703+ new HashMap <>(Map .of (CustomServiceSettings .JSON_PARSER , createResponseParserMap (TaskType .SPARSE_EMBEDDING )))
704+ )
705+ );
706+
// Sentence-boundary chunking: max 40 tokens per chunk, no sentence overlap.
707+ Map <String , Object > chunkingSettingsMap = new HashMap <>();
708+ chunkingSettingsMap .put (ChunkingSettingsOptions .STRATEGY .toString (), "sentence" );
709+ chunkingSettingsMap .put (ChunkingSettingsOptions .MAX_CHUNK_SIZE .toString (), 40 );
710+ chunkingSettingsMap .put (ChunkingSettingsOptions .SENTENCE_OVERLAP .toString (), 0 );
711+
712+ var config = getRequestConfigMap (serviceSettingsMap , createTaskSettingsMap (), chunkingSettingsMap , createSecretSettingsMap ());
713+ var listener = new PlainActionFuture <Model >();
714+
715+ service .parseRequestConfig ("id" , TaskType .SPARSE_EMBEDDING , config , listener );
716+
717+ // Check chunking settings
718+ CustomModel model = (CustomModel ) listener .actionGet (TIMEOUT );
719+ ChunkingSettings chunkingSettings = model .getConfigurations ().getChunkingSettings ();
720+
// The parsed settings must round-trip: correct concrete type, strategy enum,
// and the exact numeric values supplied in the request map.
721+ assertThat (chunkingSettings , instanceOf (SentenceBoundaryChunkingSettings .class ));
722+ assertThat (chunkingSettings .getChunkingStrategy (), equalTo (ChunkingStrategy .SENTENCE ));
723+ assertThat (chunkingSettings .asMap ().get (ChunkingSettingsOptions .MAX_CHUNK_SIZE .toString ()), equalTo (40 ));
724+ assertThat (chunkingSettings .asMap ().get (ChunkingSettingsOptions .SENTENCE_OVERLAP .toString ()), equalTo (0 ));
725+ }
726+ }
727+
728+ public void testChunkedInfer_DenseEmbeddings_ChunkingSettingsSet () throws IOException {
678729 var model = createInternalEmbeddingModel (
679730 SimilarityMeasure .DOT_PRODUCT ,
680731 new TextEmbeddingResponseParser ("$.data[*].embedding" , CustomServiceEmbeddingType .FLOAT ),
@@ -761,7 +812,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
761812 }
762813 }
763814
764- public void testChunkedInfer_ChunkingSettingsNotSet () throws IOException {
815+ public void testChunkedInfer_DenseEmbeddings_ChunkingSettingsNotSet () throws IOException {
765816 var model = createInternalEmbeddingModel (
766817 new TextEmbeddingResponseParser ("$.data[*].embedding" , CustomServiceEmbeddingType .FLOAT ),
767818 getUrl (webServer )
@@ -824,6 +875,231 @@ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
824875 }
825876 }
826877
// chunkedInfer for SPARSE_EMBEDDING with explicit (randomized) chunking settings.
// The two inputs ("a", "bb") are short enough that each produces exactly one chunk
// regardless of which random chunking strategy was generated, so the assertions on
// chunk count and text offsets hold for any settings from createRandomChunkingSettings().
878+ public void testChunkedInfer_SparseEmbeddings_ChunkingSettingsSet () throws IOException {
879+ var model = createCustomModel (
880+ TaskType .SPARSE_EMBEDDING ,
881+ new SparseEmbeddingResponseParser (
882+ "$.result.sparse_embeddings[*].embedding[*].tokenId" ,
883+ "$.result.sparse_embeddings[*].embedding[*].weight"
884+ ),
885+ getUrl (webServer ),
886+ ChunkingSettingsTests .createRandomChunkingSettings ()
887+ );
888+
// Canned Alibaba-style sparse-embedding response: one embedding per input,
// two (tokenId, weight) pairs each. Matched against the JSONPath parser above.
889+ String responseJson = """
890+ {
891+ "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
892+ "latency": 22,
893+ "usage": {
894+ "token_count": 11
895+ },
896+ "result": {
897+ "sparse_embeddings": [
898+ {
899+ "index": 0,
900+ "embedding": [
901+ {
902+ "tokenId": 6,
903+ "weight": 0.101
904+ },
905+ {
906+ "tokenId": 163040,
907+ "weight": 0.28417
908+ }
909+ ]
910+ },
911+ {
912+ "index": 1,
913+ "embedding": [
914+ {
915+ "tokenId": 4,
916+ "weight": 0.201
917+ },
918+ {
919+ "tokenId": 153040,
920+ "weight": 0.24417
921+ }
922+ ]
923+ }
924+ ]
925+ }
926+ }
927+ """ ;
928+
929+ try (var service = createService (threadPool , clientManager )) {
930+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
931+
932+ PlainActionFuture <List <ChunkedInference >> listener = new PlainActionFuture <>();
933+ service .chunkedInfer (
934+ model ,
935+ null ,
936+ List .of (new ChunkInferenceInput ("a" ), new ChunkInferenceInput ("bb" )),
937+ new HashMap <>(),
938+ InputType .INTERNAL_INGEST ,
939+ InferenceAction .Request .DEFAULT_TIMEOUT ,
940+ listener
941+ );
942+
943+ var results = listener .actionGet (TIMEOUT );
944+ assertThat (results , hasSize (2 ));
945+
946+ // Check first embedding
947+ {
948+ assertThat (results .get (0 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
949+ var sparseEmbeddingResult = (ChunkedInferenceEmbedding ) results .get (0 );
950+ assertThat (sparseEmbeddingResult .chunks (), hasSize (1 ));
// Offset (0, 1) covers the whole single-character input "a".
951+ assertEquals (new ChunkedInference .TextOffset (0 , 1 ), sparseEmbeddingResult .chunks ().get (0 ).offset ());
952+ assertThat (sparseEmbeddingResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (SparseEmbeddingResults .Embedding .class ));
953+ assertThat (
954+ ((SparseEmbeddingResults .Embedding ) sparseEmbeddingResult .chunks ().get (0 ).embedding ()),
955+ equalTo (
956+ new SparseEmbeddingResults .Embedding (
957+ List .of (new WeightedToken ("6" , 0.101f ), new WeightedToken ("163040" , 0.28417f )),
958+ false
959+ )
960+ )
961+ );
962+ }
963+
964+ // Check second embedding
965+ {
966+ assertThat (results .get (1 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
967+ var sparseEmbeddingResult = (ChunkedInferenceEmbedding ) results .get (1 );
968+ assertThat (sparseEmbeddingResult .chunks (), hasSize (1 ));
// Offset (0, 2) covers the whole two-character input "bb".
969+ assertEquals (new ChunkedInference .TextOffset (0 , 2 ), sparseEmbeddingResult .chunks ().get (0 ).offset ());
970+ assertThat (sparseEmbeddingResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (SparseEmbeddingResults .Embedding .class ));
971+ assertThat (
972+ ((SparseEmbeddingResults .Embedding ) sparseEmbeddingResult .chunks ().get (0 ).embedding ()),
973+ equalTo (
974+ new SparseEmbeddingResults .Embedding (
975+ List .of (new WeightedToken ("4" , 0.201f ), new WeightedToken ("153040" , 0.24417f )),
976+ false
977+ )
978+ )
979+ );
980+ }
981+
// Check request: both inputs were batched into a single HTTP call.
982+ assertThat (webServer .requests (), hasSize (1 ));
983+
984+ var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
985+ assertThat (requestMap .size (), is (1 ));
986+ assertThat (requestMap .get ("input" ), is (List .of ("a" , "bb" )));
987+ }
988+ }
989+
// Same scenario as testChunkedInfer_SparseEmbeddings_ChunkingSettingsSet but with a
// null ChunkingSettings on the model, exercising the service's default chunking path.
// Expected results are identical: one chunk per short input, same parsed embeddings,
// and a single batched request.
// NOTE(review): the response JSON is duplicated from the test above; consider a shared
// constant if these tests evolve together.
990+ public void testChunkedInfer_SparseEmbeddings_ChunkingSettingsNotSet () throws IOException {
991+ var model = createCustomModel (
992+ TaskType .SPARSE_EMBEDDING ,
993+ new SparseEmbeddingResponseParser (
994+ "$.result.sparse_embeddings[*].embedding[*].tokenId" ,
995+ "$.result.sparse_embeddings[*].embedding[*].weight"
996+ ),
997+ getUrl (webServer ),
998+ null // chunking not explicitly set
999+ );
1000+
1001+ String responseJson = """
1002+ {
1003+ "request_id": "75C50B5B-E79E-4930-****-F48DBB392231",
1004+ "latency": 22,
1005+ "usage": {
1006+ "token_count": 11
1007+ },
1008+ "result": {
1009+ "sparse_embeddings": [
1010+ {
1011+ "index": 0,
1012+ "embedding": [
1013+ {
1014+ "tokenId": 6,
1015+ "weight": 0.101
1016+ },
1017+ {
1018+ "tokenId": 163040,
1019+ "weight": 0.28417
1020+ }
1021+ ]
1022+ },
1023+ {
1024+ "index": 1,
1025+ "embedding": [
1026+ {
1027+ "tokenId": 4,
1028+ "weight": 0.201
1029+ },
1030+ {
1031+ "tokenId": 153040,
1032+ "weight": 0.24417
1033+ }
1034+ ]
1035+ }
1036+ ]
1037+ }
1038+ }
1039+ """ ;
1040+
1041+ try (var service = createService (threadPool , clientManager )) {
1042+ webServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (responseJson ));
1043+
1044+ PlainActionFuture <List <ChunkedInference >> listener = new PlainActionFuture <>();
1045+ service .chunkedInfer (
1046+ model ,
1047+ null ,
1048+ List .of (new ChunkInferenceInput ("a" ), new ChunkInferenceInput ("bb" )),
1049+ new HashMap <>(),
1050+ InputType .INTERNAL_INGEST ,
1051+ InferenceAction .Request .DEFAULT_TIMEOUT ,
1052+ listener
1053+ );
1054+
1055+ var results = listener .actionGet (TIMEOUT );
1056+ assertThat (results , hasSize (2 ));
1057+
1058+ // Check first embedding
1059+ {
1060+ assertThat (results .get (0 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
1061+ var sparseEmbeddingResult = (ChunkedInferenceEmbedding ) results .get (0 );
1062+ assertThat (sparseEmbeddingResult .chunks (), hasSize (1 ));
// Offset (0, 1) covers the whole single-character input "a".
1063+ assertEquals (new ChunkedInference .TextOffset (0 , 1 ), sparseEmbeddingResult .chunks ().get (0 ).offset ());
1064+ assertThat (sparseEmbeddingResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (SparseEmbeddingResults .Embedding .class ));
1065+ assertThat (
1066+ ((SparseEmbeddingResults .Embedding ) sparseEmbeddingResult .chunks ().get (0 ).embedding ()),
1067+ equalTo (
1068+ new SparseEmbeddingResults .Embedding (
1069+ List .of (new WeightedToken ("6" , 0.101f ), new WeightedToken ("163040" , 0.28417f )),
1070+ false
1071+ )
1072+ )
1073+ );
1074+ }
1075+
1076+ // Check second embedding
1077+ {
1078+ assertThat (results .get (1 ), CoreMatchers .instanceOf (ChunkedInferenceEmbedding .class ));
1079+ var sparseEmbeddingResult = (ChunkedInferenceEmbedding ) results .get (1 );
1080+ assertThat (sparseEmbeddingResult .chunks (), hasSize (1 ));
// Offset (0, 2) covers the whole two-character input "bb".
1081+ assertEquals (new ChunkedInference .TextOffset (0 , 2 ), sparseEmbeddingResult .chunks ().get (0 ).offset ());
1082+ assertThat (sparseEmbeddingResult .chunks ().get (0 ).embedding (), Matchers .instanceOf (SparseEmbeddingResults .Embedding .class ));
1083+ assertThat (
1084+ ((SparseEmbeddingResults .Embedding ) sparseEmbeddingResult .chunks ().get (0 ).embedding ()),
1085+ equalTo (
1086+ new SparseEmbeddingResults .Embedding (
1087+ List .of (new WeightedToken ("4" , 0.201f ), new WeightedToken ("153040" , 0.24417f )),
1088+ false
1089+ )
1090+ )
1091+ );
1092+ }
1093+
1094+ assertThat (webServer .requests (), hasSize (1 ));
1095+
1096+ // Check request
1097+ var requestMap = entityAsMap (webServer .requests ().get (0 ).getBody ());
1098+ assertThat (requestMap .size (), is (1 ));
1099+ assertThat (requestMap .get ("input" ), is (List .of ("a" , "bb" )));
1100+ }
1101+ }
1102+
8271103 public void testChunkedInfer_noInputs () throws IOException {
8281104 var model = createInternalEmbeddingModel (
8291105 new TextEmbeddingResponseParser ("$.data[*].embedding" , CustomServiceEmbeddingType .FLOAT ),
0 commit comments