Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
ccd64ae
fix boosting for knn
Samiul-TheSoccerFan Jun 11, 2025
9338cd5
Fixing for match query
Samiul-TheSoccerFan Jun 11, 2025
370931d
fixing for match subquery
Samiul-TheSoccerFan Jun 11, 2025
b85abda
fix for sparse vector query boost
Samiul-TheSoccerFan Jun 11, 2025
5db2686
fix linting issues
Samiul-TheSoccerFan Jun 11, 2025
2ce691e
Update docs/changelog/129282.yaml
Samiul-TheSoccerFan Jun 11, 2025
4100200
update changelog
Samiul-TheSoccerFan Jun 11, 2025
3406ae1
Copy constructor with match query
Samiul-TheSoccerFan Jun 12, 2025
d07952a
util function to create sparseVectorBuilder for sparse query
Samiul-TheSoccerFan Jun 12, 2025
f133632
util function for knn query to support boost
Samiul-TheSoccerFan Jun 12, 2025
a9048f0
adding unit tests for all intercepted query terms
Samiul-TheSoccerFan Jun 12, 2025
5a1dab9
Adding yaml test for match,sparse, and knn
Samiul-TheSoccerFan Jun 13, 2025
6cef441
Adding queryname support for nested query
Samiul-TheSoccerFan Jun 13, 2025
faa35ea
fix code styles
Samiul-TheSoccerFan Jun 13, 2025
675fb22
merge from main
Samiul-TheSoccerFan Jun 13, 2025
13e791e
Fix failed yaml tests
Samiul-TheSoccerFan Jun 13, 2025
3a5a30f
Update docs/changelog/129282.yaml
Samiul-TheSoccerFan Jun 13, 2025
016e448
update yaml tests to expand test scenarios
Samiul-TheSoccerFan Jun 16, 2025
70b228e
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jun 16, 2025
d5e7caa
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jun 16, 2025
efcf9c4
Updating knn to copy constructor
Samiul-TheSoccerFan Jun 25, 2025
00eb6ad
merge from main
Samiul-TheSoccerFan Jun 25, 2025
f449299
adding yaml tests for multiple indices
Samiul-TheSoccerFan Jun 27, 2025
6db0abf
refactoring match query to adjust boost and queryname and move to cop…
Samiul-TheSoccerFan Jun 27, 2025
daf2cb4
refactoring sparse query to adjust boost and queryname and move to co…
Samiul-TheSoccerFan Jun 27, 2025
b88b077
[CI] Auto commit changes from spotless
Jun 27, 2025
9e725cb
Refactor sparse vector to adjust boost and queryname in the top level
Samiul-TheSoccerFan Jul 2, 2025
651ee2b
Refactor knn vector to adjust boost and queryname in the top level
Samiul-TheSoccerFan Jul 2, 2025
a356b44
merge from main
Samiul-TheSoccerFan Jul 2, 2025
71eac8d
fix knn combined query
Samiul-TheSoccerFan Jul 2, 2025
d71bf2c
fix unit tests
Samiul-TheSoccerFan Jul 2, 2025
675463c
fix lint issues
Samiul-TheSoccerFan Jul 2, 2025
201d27c
remove unused code
Samiul-TheSoccerFan Jul 2, 2025
daf2f6e
Update inference feature name
Samiul-TheSoccerFan Jul 3, 2025
2521b48
Remove double boosting issue from match
Samiul-TheSoccerFan Jul 3, 2025
61f9445
Fix double boosting in match test yaml file
Samiul-TheSoccerFan Jul 3, 2025
f4cadaa
move to bool level for match semantic boost
Samiul-TheSoccerFan Jul 3, 2025
08909de
fix double boosting for sparse vector
Samiul-TheSoccerFan Jul 3, 2025
37bfc43
fix double boosting for sparse vector in yaml test
Samiul-TheSoccerFan Jul 3, 2025
fa5cfe7
fix knn combined query
Samiul-TheSoccerFan Jul 3, 2025
0640631
fix knn combined query
Samiul-TheSoccerFan Jul 3, 2025
404efcf
fix sparse combined query
Samiul-TheSoccerFan Jul 3, 2025
f73285d
fix knn yaml test for combined query
Samiul-TheSoccerFan Jul 3, 2025
96f5aa6
refactoring unit tests
Samiul-TheSoccerFan Jul 4, 2025
3065e5b
linting
Samiul-TheSoccerFan Jul 4, 2025
828d8c2
fix match query unit test
Samiul-TheSoccerFan Jul 4, 2025
6a2e0a5
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 4, 2025
bde54df
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 7, 2025
d08dbdd
adding copy constructor for match query
Samiul-TheSoccerFan Jul 8, 2025
fa955c3
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 8, 2025
916b1cc
refactor copy match builder to intercepter
Samiul-TheSoccerFan Jul 8, 2025
d9ef867
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 8, 2025
cde55d1
resolve conflicts from main
Samiul-TheSoccerFan Jul 9, 2025
8ddda3c
[CI] Auto commit changes from spotless
Jul 9, 2025
873efdb
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 9, 2025
44b8aa9
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 10, 2025
104f16b
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 11, 2025
2a96d52
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 11, 2025
5dcfc1b
fix unit tests
Samiul-TheSoccerFan Jul 11, 2025
469f598
update yaml tests
Samiul-TheSoccerFan Jul 11, 2025
375ae36
fix match yaml test
Samiul-TheSoccerFan Jul 11, 2025
c81f184
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 14, 2025
394f43a
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 14, 2025
768b8f6
fix yaml tests with 4 digits error margin
Samiul-TheSoccerFan Jul 16, 2025
98cba31
unit tests are now more randomized
Samiul-TheSoccerFan Jul 16, 2025
6c6e7cf
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 16, 2025
4b0a6fe
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 16, 2025
c233ee9
Merge branch 'main' into fix-semantic-query-rewrite-boost-issue
elasticmachine Jul 17, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/129282.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 129282
summary: "Fix query rewrite logic to preserve `boosts` and `queryName` for `match`,\
\ `knn`, and `sparse_vector` queries on semantic_text fields"
area: Search
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public class InferenceFeatures implements FeatureSpecification {
private static final NodeFeature TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS = new NodeFeature(
"test_rule_retriever.with_indices_that_dont_return_rank_docs"
);
private static final NodeFeature SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX = new NodeFeature(
"semantic_query_rewrite_interceptors.propagate_boost_and_query_name_fix"
);
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

Expand Down Expand Up @@ -68,7 +71,8 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
SEMANTIC_TEXT_INDEX_OPTIONS,
COHERE_V2_API,
SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS
SEMANTIC_TEXT_INDEX_OPTIONS_WITH_DEFAULTS,
SEMANTIC_QUERY_REWRITE_INTERCEPTORS_PROPAGATE_BOOST_AND_QUERY_NAME_FIX
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,20 @@ protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceI
assert (queryBuilder instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) queryBuilder;
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
QueryBuilder finalQueryBuilder;
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
Map.Entry<String, List<String>> inferenceIdIndex = inferenceIdsIndices.entrySet().iterator().next();
String searchInferenceId = inferenceIdIndex.getKey();
List<String> indices = inferenceIdIndex.getValue();
return buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
finalQueryBuilder = buildNestedQueryFromKnnVectorQuery(knnVectorQueryBuilder, indices, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(knnVectorQueryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.boost(queryBuilder.boost());
finalQueryBuilder.queryName(queryBuilder.queryName());
return finalQueryBuilder;
}

private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
Expand Down Expand Up @@ -102,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ protected String getQuery(QueryBuilder queryBuilder) {

@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
return new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
SemanticQueryBuilder semanticQueryBuilder = new SemanticQueryBuilder(indexInformation.fieldName(), getQuery(queryBuilder), false);
semanticQueryBuilder.boost(queryBuilder.boost());
semanticQueryBuilder.queryName(queryBuilder.queryName());
return semanticQueryBuilder;
}

@Override
Expand All @@ -45,7 +48,10 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
InferenceIndexInformationForField indexInformation
) {
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
MatchQueryBuilder originalMatchQueryBuilder = (MatchQueryBuilder) queryBuilder;
// Create a copy for non-inference fields without boost and _name
MatchQueryBuilder matchQueryBuilder = copyMatchQueryBuilder(originalMatchQueryBuilder);

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
Expand All @@ -55,11 +61,33 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), matchQueryBuilder));
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}

private MatchQueryBuilder copyMatchQueryBuilder(MatchQueryBuilder queryBuilder) {
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(queryBuilder.fieldName(), queryBuilder.value());
matchQueryBuilder.operator(queryBuilder.operator());
matchQueryBuilder.prefixLength(queryBuilder.prefixLength());
matchQueryBuilder.maxExpansions(queryBuilder.maxExpansions());
matchQueryBuilder.fuzzyTranspositions(queryBuilder.fuzzyTranspositions());
matchQueryBuilder.lenient(queryBuilder.lenient());
matchQueryBuilder.zeroTermsQuery(queryBuilder.zeroTermsQuery());
matchQueryBuilder.analyzer(queryBuilder.analyzer());
matchQueryBuilder.minimumShouldMatch(queryBuilder.minimumShouldMatch());
matchQueryBuilder.fuzzyRewrite(queryBuilder.fuzzyRewrite());

if (queryBuilder.fuzziness() != null) {
matchQueryBuilder.fuzziness(queryBuilder.fuzziness());
}

matchQueryBuilder.autoGenerateSynonymsPhraseQuery(queryBuilder.autoGenerateSynonymsPhraseQuery());
return matchQueryBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,18 @@ protected String getQuery(QueryBuilder queryBuilder) {
@Override
protected QueryBuilder buildInferenceQuery(QueryBuilder queryBuilder, InferenceIndexInformationForField indexInformation) {
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();
QueryBuilder finalQueryBuilder;
if (inferenceIdsIndices.size() == 1) {
// Simple case, everything uses the same inference ID
String searchInferenceId = inferenceIdsIndices.keySet().iterator().next();
return buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
finalQueryBuilder = buildNestedQueryFromSparseVectorQuery(queryBuilder, searchInferenceId);
} else {
// Multiple inference IDs, construct a boolean query
return buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
finalQueryBuilder = buildInferenceQueryWithMultipleInferenceIds(queryBuilder, inferenceIdsIndices);
}
finalQueryBuilder.queryName(queryBuilder.queryName());
finalQueryBuilder.boost(queryBuilder.boost());
return finalQueryBuilder;
}

private QueryBuilder buildInferenceQueryWithMultipleInferenceIds(
Expand Down Expand Up @@ -79,7 +83,19 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
Map<String, List<String>> inferenceIdsIndices = indexInformation.getInferenceIdsIndices();

BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(createSubQueryForIndices(indexInformation.nonInferenceIndices(), sparseVectorQueryBuilder));
boolQueryBuilder.should(
createSubQueryForIndices(
indexInformation.nonInferenceIndices(),
new SparseVectorQueryBuilder(
sparseVectorQueryBuilder.getFieldName(),
sparseVectorQueryBuilder.getQueryVectors(),
sparseVectorQueryBuilder.getInferenceId(),
sparseVectorQueryBuilder.getQuery(),
sparseVectorQueryBuilder.shouldPruneTokens(),
sparseVectorQueryBuilder.getTokenPruningConfig()
)
)
);
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors.
for (String inferenceId : inferenceIdsIndices.keySet()) {
Expand All @@ -90,6 +106,8 @@ protected QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
)
);
}
boolQueryBuilder.boost(queryBuilder.boost());
boolQueryBuilder.queryName(queryBuilder.queryName());
return boolQueryBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY);
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

Expand All @@ -72,6 +80,14 @@ public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten()
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY);
KnnVectorQueryBuilder original = new KnnVectorQueryBuilder(FIELD_NAME, queryVectorBuilder, 10, 100, null);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

Expand All @@ -82,14 +98,23 @@ private void testRewrittenInferenceQuery(QueryRewriteContext context, KnnVectorQ
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertEquals(original.boost(), intercepted.boost(), 0.0f);
assertEquals(original.queryName(), intercepted.queryName());
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);

NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
assertEquals(original.queryName(), nestedQueryBuilder.queryName());
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());

QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof KnnVectorQueryBuilder);
KnnVectorQueryBuilder knnVectorQueryBuilder = (KnnVectorQueryBuilder) innerQuery;
assertEquals(1.0f, knnVectorQueryBuilder.boost(), 0.0f);
assertNull(knnVectorQueryBuilder.queryName());
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), knnVectorQueryBuilder.getFieldName());
assertTrue(knnVectorQueryBuilder.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder);

TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder = (TextEmbeddingQueryVectorBuilder) knnVectorQueryBuilder
.queryVectorBuilder();
assertEquals(QUERY, textEmbeddingQueryVectorBuilder.getModelText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ public class SemanticMatchQueryRewriteInterceptorTests extends ESTestCase {

private static final String FIELD_NAME = "fieldName";
private static final String VALUE = "value";
private static final String QUERY_NAME = "match_query";
private static final float BOOST = 5.0f;

@Before
public void setup() {
Expand Down Expand Up @@ -79,6 +81,29 @@ public void testMatchQueryOnNonInferenceFieldRemainsMatchQuery() throws IOExcept
assertEquals(original, rewritten);
}

public void testBoostAndQueryNameInMatchQueryRewrite() throws IOException {
Map<String, InferenceFieldMetadata> inferenceFields = Map.of(
FIELD_NAME,
new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null)
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = createTestQueryBuilder();
original.boost(BOOST);
original.queryName(QUERY_NAME);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertEquals(BOOST, intercepted.boost(), 0.0f);
assertEquals(QUERY_NAME, intercepted.queryName());
assertTrue(intercepted.queryBuilder instanceof SemanticQueryBuilder);
SemanticQueryBuilder semanticQueryBuilder = (SemanticQueryBuilder) intercepted.queryBuilder;
assertEquals(FIELD_NAME, semanticQueryBuilder.getFieldName());
assertEquals(VALUE, semanticQueryBuilder.getQuery());
}

private MatchQueryBuilder createTestQueryBuilder() {
return new MatchQueryBuilder(FIELD_NAME, VALUE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,15 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException {
Expand All @@ -82,32 +76,52 @@ public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsIntercepted
);
QueryRewriteContext context = createQueryRewriteContext(inferenceFields);
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY);
if (randomBoolean()) {
float boost = randomFloatBetween(1, 10, randomBoolean());
original.boost(boost);
}
if (randomBoolean()) {
String queryName = randomAlphaOfLength(5);
original.queryName(queryName);
}
testRewrittenInferenceQuery(context, original);
}

public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
);
assertEquals(original, rewritten);
}

private void testRewrittenInferenceQuery(QueryRewriteContext context, QueryBuilder original) throws IOException {
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to be intercepted, but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof InterceptedQueryBuilderWrapper
);
InterceptedQueryBuilderWrapper intercepted = (InterceptedQueryBuilderWrapper) rewritten;
assertEquals(original.boost(), intercepted.boost(), 0.0f);
assertEquals(original.queryName(), intercepted.queryName());

assertTrue(intercepted.queryBuilder instanceof NestedQueryBuilder);
NestedQueryBuilder nestedQueryBuilder = (NestedQueryBuilder) intercepted.queryBuilder;
assertEquals(SemanticTextField.getChunksFieldName(FIELD_NAME), nestedQueryBuilder.path());
assertEquals(original.boost(), nestedQueryBuilder.boost(), 0.0f);
assertEquals(original.queryName(), nestedQueryBuilder.queryName());

QueryBuilder innerQuery = nestedQueryBuilder.query();
assertTrue(innerQuery instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) innerQuery;
assertEquals(SemanticTextField.getEmbeddingsFieldName(FIELD_NAME), sparseVectorQueryBuilder.getFieldName());
assertEquals(INFERENCE_ID, sparseVectorQueryBuilder.getInferenceId());
assertEquals(QUERY, sparseVectorQueryBuilder.getQuery());
}

public void testSparseVectorQueryOnNonInferenceFieldRemainsUnchanged() throws IOException {
QueryRewriteContext context = createQueryRewriteContext(Map.of()); // No inference fields
QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY);
QueryBuilder rewritten = original.rewrite(context);
assertTrue(
"Expected query to remain sparse_vector but was [" + rewritten.getClass().getName() + "]",
rewritten instanceof SparseVectorQueryBuilder
);
assertEquals(original, rewritten);
assertEquals(1.0f, sparseVectorQueryBuilder.boost(), 0.0f);
assertNull(sparseVectorQueryBuilder.queryName());
}

private QueryRewriteContext createQueryRewriteContext(Map<String, InferenceFieldMetadata> inferenceFields) {
Expand Down
Loading
Loading