Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/138457.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 138457
summary: Intercept filters to knn queries
area: Vector Search
type: bug
issues:
- 138410
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;

import java.io.IOException;
import java.util.Map;

/**
Expand All @@ -27,7 +28,7 @@ public interface QueryRewriteInterceptor {
* @param queryBuilder the original {@link QueryBuilder} to potentially rewrite
* @return the rewritten {@link QueryBuilder}, or the original instance if no rewrite was needed
*/
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder);
QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException;

/**
* Name of the query to be intercepted and rewritten.
Expand All @@ -52,7 +53,7 @@ public String getQueryName() {
}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
QueryRewriteInterceptor interceptor = interceptors.get(queryBuilder.getName());
if (interceptor != null) {
return interceptor.interceptAndRewrite(context, queryBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,13 @@ public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries)
return this;
}

public KnnVectorQueryBuilder setFilterQueries(List<QueryBuilder> filterQueries) {
Objects.requireNonNull(filterQueries);
this.filterQueries.clear();
this.filterQueries.addAll(filterQueries);
return this;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (queryVectorSupplier != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.QueryShardException;
import org.elasticsearch.index.query.RandomQueryBuilder;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.query.TermQueryBuilder;
Expand Down Expand Up @@ -565,4 +566,11 @@ public void testRewriteWithQueryVectorBuilder() throws Exception {
assertThat(rewritten.filterQueries(), hasSize(numFilters));
assertThat(rewritten.filterQueries(), equalTo(filters));
}

public void testSetFilterQueries() {
KnnVectorQueryBuilder knnQueryBuilder = doCreateTestQueryBuilder();
List<QueryBuilder> newFilters = randomList(5, () -> RandomQueryBuilder.createQuery(random()));
knnQueryBuilder.setFilterQueries(newFilters);
assertThat(knnQueryBuilder.filterQueries(), equalTo(newFilters));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

Expand Down Expand Up @@ -109,6 +110,7 @@ public Set<NodeFeature> getTestFeatures() {
SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS,
SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX,
InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS,
SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_FILTERS_REWRITE_INTERCEPTION_SUPPORTED,
TEXT_SIMILARITY_RERANKER_SNIPPETS,
ModelStats.SEMANTIC_TEXT_USAGE,
SEARCH_USAGE_EXTENDED_DATA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import static org.elasticsearch.transport.RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY;
Expand Down Expand Up @@ -91,6 +93,33 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
return modelId != null ? new FullyQualifiedInferenceId(LOCAL_CLUSTER_GROUP_KEY, modelId) : null;
}

@Override
protected InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> customDoRewriteGetInferenceResults(
QueryRewriteContext queryRewriteContext
) throws IOException {
// knn query may contain filters that are also intercepted.
// We need to rewrite those here so that we can get inference results for them too.
return rewriteFilterQueries(queryRewriteContext);
}

private InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> rewriteFilterQueries(QueryRewriteContext queryRewriteContext)
throws IOException {
boolean filtersChanged = false;
List<QueryBuilder> rewrittenFilters = new ArrayList<>(originalQuery.filterQueries().size());
for (QueryBuilder filter : originalQuery.filterQueries()) {
QueryBuilder rewrittenFilter = filter.rewrite(queryRewriteContext);
if (rewrittenFilter != filter) {
filtersChanged = true;
}
rewrittenFilters.add(rewrittenFilter);
}
if (filtersChanged) {
originalQuery.setFilterQueries(rewrittenFilters);
return copy(inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
}
return this;
}

@Override
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
if (originalQuery.queryVector() == null && originalQuery.queryVectorBuilder() instanceof TextEmbeddingQueryVectorBuilder == false) {
Expand Down Expand Up @@ -119,7 +148,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand All @@ -129,7 +158,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}

@Override
protected QueryBuilder copy(
protected InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> copy(
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
boolean ccsRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ protected String getQuery() {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand All @@ -75,7 +75,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}

@Override
protected QueryBuilder copy(
protected InterceptedInferenceQueryBuilder<MatchQueryBuilder> copy(
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
boolean ccsRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ protected InterceptedInferenceQueryBuilder(
* @param queryRewriteContext The query rewrite context
* @return The query builder rewritten to a backwards-compatible form
*/
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext);
protected abstract QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException;

/**
* Generate a copy of {@code this}.
*
* @param inferenceResultsMap The inference results map
* @param inferenceResultsMap The inference results map
* @param inferenceResultsMapSupplier The inference results map supplier
* @param ccsRequest Flag indicating if this is a CCS request
* @param ccsRequest Flag indicating if this is a CCS request
* @return A copy of {@code this} with the provided inference results map
*/
protected abstract QueryBuilder copy(
Expand Down Expand Up @@ -209,6 +209,15 @@ protected FullyQualifiedInferenceId getInferenceIdOverride() {
*/
protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {}

/**
* A hook for subclasses to do additional rewriting and inference result fetching while we are on the coordinator node.
* An example usage is {@link InterceptedInferenceKnnVectorQueryBuilder} which needs to rewrite the knn queries filters.
*/
protected InterceptedInferenceQueryBuilder<T> customDoRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext)
throws IOException {
return this;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (inferenceResultsMapSupplier != null) {
Expand Down Expand Up @@ -304,7 +313,7 @@ private QueryBuilder doRewriteBuildQuery(QueryRewriteContext indexMetadataContex
return queryFields(inferenceFieldsToQuery, nonInferenceFieldsToQuery, indexMetadataContext);
}

private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewrittenBwC = doRewriteBwC(queryRewriteContext);
if (rewrittenBwC != this) {
return rewrittenBwC;
Expand Down Expand Up @@ -344,6 +353,15 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
);
}

InterceptedInferenceQueryBuilder<T> rewritten = customDoRewriteGetInferenceResults(queryRewriteContext);
return rewritten.doRewriteWaitForInferenceResults(queryRewriteContext, inferenceIds, ccsRequest);
}

private QueryBuilder doRewriteWaitForInferenceResults(
QueryRewriteContext queryRewriteContext,
Set<FullyQualifiedInferenceId> inferenceIds,
boolean ccsRequest
) {
if (inferenceResultsMapSupplier != null) {
// Additional inference results have already been requested, and we are waiting for them to continue the rewrite process
return getNewInferenceResultsFromSupplier(inferenceResultsMapSupplier, this, m -> copy(m, null, ccsRequest));
Expand Down Expand Up @@ -376,7 +394,6 @@ private QueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewri
} else {
rewritten = copy(inferenceResultsMap, newInferenceResultsMapSupplier, ccsRequest);
}

return rewritten;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
}

@Override
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) throws IOException {
QueryBuilder rewritten = this;
if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
rewritten = BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
Expand All @@ -116,7 +116,7 @@ protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
}

@Override
protected QueryBuilder copy(
protected InterceptedInferenceQueryBuilder<SparseVectorQueryBuilder> copy(
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
boolean ccsRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,50 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class SemanticKnnVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_KNN_VECTOR_QUERY_FILTERS_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_knn_vector_query_filters_rewrite_interception_supported"
);

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) throws IOException {
if (queryBuilder instanceof KnnVectorQueryBuilder knnVectorQueryBuilder) {
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
return interceptKnnQuery(context, knnVectorQueryBuilder);
} else {
throw new IllegalStateException("Unexpected query builder type: " + queryBuilder.getClass());
}
}

private static InterceptedInferenceKnnVectorQueryBuilder interceptKnnQuery(
QueryRewriteContext context,
KnnVectorQueryBuilder knnVectorQueryBuilder
) throws IOException {
boolean changed = false;
List<QueryBuilder> rewrittenFilters = new ArrayList<>(knnVectorQueryBuilder.filterQueries().size());
for (QueryBuilder filter : knnVectorQueryBuilder.filterQueries()) {
QueryBuilder rewritten = filter.rewrite(context);
if (rewritten != filter) {
changed = true;
}
rewrittenFilters.add(rewritten);
}
if (changed) {
knnVectorQueryBuilder.setFilterQueries(rewrittenFilters);
}
return new InterceptedInferenceKnnVectorQueryBuilder(knnVectorQueryBuilder);
}

@Override
public String getQueryName() {
return KnnVectorQueryBuilder.NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.TransportVersions.V_8_15_0;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;
Expand Down Expand Up @@ -335,7 +337,7 @@ protected abstract InterceptedInferenceQueryBuilder<T> createInterceptedQueryBui
Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
);

protected abstract QueryRewriteInterceptor createQueryRewriteInterceptor();
protected abstract List<QueryRewriteInterceptor> createQueryRewriteInterceptors();

protected abstract TransportVersion getMinimalSupportedVersion();

Expand Down Expand Up @@ -427,8 +429,9 @@ protected QueryRewriteContext createQueryRewriteContext(
indexMetadata
);

QueryRewriteInterceptor interceptor = createQueryRewriteInterceptor();
Map<String, QueryRewriteInterceptor> interceptorMap = Map.of(interceptor.getQueryName(), interceptor);
QueryRewriteInterceptor interceptor = QueryRewriteInterceptor.multi(
createQueryRewriteInterceptors().stream().collect(Collectors.toMap(QueryRewriteInterceptor::getQueryName, Function.identity()))
);

return new QueryRewriteContext(
null,
Expand All @@ -438,7 +441,7 @@ protected QueryRewriteContext createQueryRewriteContext(
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
resolvedIndices,
null,
QueryRewriteInterceptor.multi(interceptorMap),
interceptor,
ccsMinimizeRoundTrips
);
}
Expand Down
Loading