Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/135342.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135342
summary: Add 'profile' support for knn query on HNSW with early termination
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -401,14 +398,13 @@ TopDocs doVectorQuery(byte[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromByteQuery((KnnByteVectorQuery) knnQuery);
}
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
Expand All @@ -432,24 +428,20 @@ TopDocs doVectorQuery(float[] vector, IndexSearcher searcher, Query filterQuery,
topK,
efSearch,
filterQuery,
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy()
DenseVectorFieldMapper.FilterHeuristic.ACORN.getKnnSearchStrategy(),
indexType == KnnIndexTester.IndexType.HNSW && earlyTermination
);
if (indexType == KnnIndexTester.IndexType.HNSW && earlyTermination) {
knnQuery = PatienceKnnVectorQuery.fromFloatQuery((KnnFloatVectorQuery) knnQuery);
}
}
if (overSamplingFactor > 1f) {
// oversample the topK results to get more candidates for the final result
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(VECTOR_FIELD, vector, similarityFunction, this.topK, topK, knnQuery);
}
QueryProfiler profiler = new QueryProfiler();
TopDocs docs = searcher.search(knnQuery, this.topK);
if (knnQuery instanceof QueryProfilerProvider queryProfilerProvider) {
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
} else {
return docs;
}
assert knnQuery instanceof QueryProfilerProvider : "this knnQuery doesn't support profiling";
QueryProfilerProvider queryProfilerProvider = (QueryProfilerProvider) knnQuery;
queryProfilerProvider.profile(profiler);
return new TopDocs(new TotalHits(profiler.getVectorOpsCount(), docs.totalHits.relation()), docs.scoreDocs);
}

private static float checkResults(int[][] results, int[][] nn, int topK) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,8 @@ public void testHnswEarlyTerminationQuery() {
)
.sum();
assertTrue(
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lt vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum < vectorOpsSum
// if both switch to brute-force due to excessive exploration, they will both equal to upperLimit
|| (earlyTerminationVectorOpsSum == vectorOpsSum && vectorOpsSum == upperLimit + 1)
"earlyTerminationVectorOps [" + earlyTerminationVectorOpsSum + "] is not lte vectorOps [" + vectorOpsSum + "]",
earlyTerminationVectorOpsSum <= vectorOpsSum
);
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PatienceKnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.knn.KnnSearchStrategy;
Expand Down Expand Up @@ -2366,6 +2363,7 @@ public Query createKnnQuery(
return new MatchNoDocsQuery("No data has been indexed for field [" + name() + "]");
}
KnnSearchStrategy knnSearchStrategy = heuristic.getKnnSearchStrategy();
hnswEarlyTermination &= canApplyPatienceQuery();
return switch (getElementType()) {
case BYTE -> createKnnByteQuery(
queryVector.asByteVector(),
Expand Down Expand Up @@ -2410,6 +2408,13 @@ private boolean isQuantized() {
return indexOptions != null && indexOptions.type != null && indexOptions.type.isQuantized();
}

private boolean canApplyPatienceQuery() {
return indexOptions instanceof HnswIndexOptions
|| indexOptions instanceof Int8HnswIndexOptions
|| indexOptions instanceof Int4HnswIndexOptions
|| indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnBitQuery(
byte[] queryVector,
int k,
Expand All @@ -2433,11 +2438,17 @@ private Query createKnnBitQuery(
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
? new ESDiversifyingChildrenByteKnnVectorQuery(
name(),
queryVector,
filter,
k,
numCands,
parentFilter,
searchStrategy,
hnswEarlyTermination
)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand Down Expand Up @@ -2477,11 +2488,17 @@ private Query createKnnByteQuery(
.build();
} else {
knnQuery = parentFilter != null
? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, k, numCands, parentFilter, searchStrategy)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
? new ESDiversifyingChildrenByteKnnVectorQuery(
name(),
queryVector,
filter,
k,
numCands,
parentFilter,
searchStrategy,
hnswEarlyTermination
)
: new ESKnnByteVectorQuery(name(), queryVector, k, numCands, filter, searchStrategy, hnswEarlyTermination);
}
if (similarityThreshold != null) {
knnQuery = new VectorSimilarityQuery(
Expand All @@ -2493,23 +2510,6 @@ private Query createKnnByteQuery(
return knnQuery;
}

private Query maybeWrapPatience(Query knnQuery) {
Query finalQuery = knnQuery;
if (knnQuery instanceof KnnByteVectorQuery knnByteVectorQuery && canApplyPatienceQuery()) {
finalQuery = PatienceKnnVectorQuery.fromByteQuery(knnByteVectorQuery);
} else if (knnQuery instanceof KnnFloatVectorQuery knnFloatVectorQuery && canApplyPatienceQuery()) {
finalQuery = PatienceKnnVectorQuery.fromFloatQuery(knnFloatVectorQuery);
}
return finalQuery;
}

private boolean canApplyPatienceQuery() {
return indexOptions instanceof HnswIndexOptions
|| indexOptions instanceof Int8HnswIndexOptions
|| indexOptions instanceof Int4HnswIndexOptions
|| indexOptions instanceof BBQHnswIndexOptions;
}

private Query createKnnFloatQuery(
float[] queryVector,
int k,
Expand Down Expand Up @@ -2586,10 +2586,7 @@ private Query createKnnFloatQuery(
parentFilter,
knnSearchStrategy
)
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy);
if (hnswEarlyTermination) {
knnQuery = maybeWrapPatience(knnQuery);
}
: new ESKnnFloatVectorQuery(name(), queryVector, adjustedK, numCands, filter, knnSearchStrategy, hnswEarlyTermination);
}
if (rescore) {
knnQuery = RescoreKnnVectorQuery.fromInnerQuery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

public ESDiversifyingChildrenByteKnnVectorQuery(
String field,
Expand All @@ -28,9 +31,23 @@ public ESDiversifyingChildrenByteKnnVectorQuery(
int numCands,
BitSetProducer parentsFilter,
KnnSearchStrategy strategy
) {
this(field, query, childFilter, k, numCands, parentsFilter, strategy, false);
}

public ESDiversifyingChildrenByteKnnVectorQuery(
String field,
byte[] query,
Query childFilter,
int k,
int numCands,
BitSetProducer parentsFilter,
KnnSearchStrategy strategy,
boolean earlyTermination
) {
super(field, query, childFilter, numCands, parentsFilter, strategy);
this.kParam = k;
this.earlyTermination = earlyTermination;
}

@Override
Expand All @@ -48,4 +65,10 @@ public void profile(QueryProfiler queryProfiler) {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnByteVectorQuery extends KnnByteVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

public ESKnnByteVectorQuery(String field, byte[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
this(field, target, k, numCands, filter, strategy, false);
}

public ESKnnByteVectorQuery(
String field,
byte[] target,
int k,
int numCands,
Query filter,
KnnSearchStrategy strategy,
boolean earlyTermination
) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
this.earlyTermination = earlyTermination;
}

@Override
Expand All @@ -44,4 +60,10 @@ public Integer kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,35 @@

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.elasticsearch.search.profile.query.QueryProfiler;

public class ESKnnFloatVectorQuery extends KnnFloatVectorQuery implements QueryProfilerProvider {
private final int kParam;
private long vectorOpsCount;
private final boolean earlyTermination;

public ESKnnFloatVectorQuery(String field, float[] target, int k, int numCands, Query filter, KnnSearchStrategy strategy) {
this(field, target, k, numCands, filter, strategy, false);
}

public ESKnnFloatVectorQuery(
String field,
float[] target,
int k,
int numCands,
Query filter,
KnnSearchStrategy strategy,
boolean earlyTermination
) {
super(field, target, numCands, filter, strategy);
this.kParam = k;
this.earlyTermination = earlyTermination;
}

@Override
Expand All @@ -44,4 +60,10 @@ public int kParam() {
public KnnSearchStrategy getStrategy() {
return searchStrategy;
}

@Override
protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
KnnCollectorManager knnCollectorManager = super.getKnnCollectorManager(k, searcher);
return earlyTermination ? PatienceCollectorManager.wrap(knnCollectorManager, k) : knnCollectorManager;
}
}
Loading