Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5bebc1f
Wire "bulk with offset" functions, c++ implementation, ARM optimization
ldematte Nov 24, 2025
b8cbd5a
Merge remote-tracking branch 'upstream/main' into simd/arm-optimized-…
ldematte Nov 26, 2025
de51eea
Fix/reconcile/merge cpp files after merge
ldematte Nov 26, 2025
639bbf0
Another fix; add benchmarks to cover all paths
ldematte Nov 26, 2025
3369521
WIP: scoring on CPP side - fix signature to have pitch
ldematte Nov 26, 2025
451b924
Fixes to function signatures and CPP code
ldematte Nov 27, 2025
f232613
Add implementation for Java22 Int7SQVectorScorer
ldematte Nov 28, 2025
39d7f79
Remove score correction from native signature
ldematte Nov 28, 2025
03ec7f3
Update docs/changelog/138552.yaml
ldematte Nov 28, 2025
efe8d31
Better fallback from bulk trying to map the whole index + tests to co…
ldematte Nov 28, 2025
3aaf780
remove bulk fallback (unused)
ldematte Nov 28, 2025
3d24711
Add javadoc
ldematte Nov 28, 2025
47f4875
Fix x64 code
ldematte Nov 28, 2025
ada63b1
Merge branch 'simd/arm-optimized-bulk' of github.com:ldematte/elastic…
ldematte Nov 28, 2025
5458d93
Bump native libvec version
ldematte Nov 28, 2025
b485c66
Merge remote-tracking branch 'upstream/main' into simd/arm-optimized-…
ldematte Nov 28, 2025
5a3ab67
[CI] Auto commit changes from spotless
Nov 28, 2025
7fa26e8
Fix adjustment on non-aligned vectors
ldematte Nov 28, 2025
cfedef3
Merge branch 'simd/arm-optimized-bulk' of github.com:ldematte/elastic…
ldematte Nov 28, 2025
9fa7265
Merge branch 'main' into simd/arm-optimized-bulk
ldematte Nov 28, 2025
f3d14b3
Merge branch 'main' into simd/arm-optimized-bulk
ldematte Dec 3, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.elasticsearch.common.logging.LogConfigurator;
Expand Down Expand Up @@ -48,6 +49,7 @@
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.createRandomInt7VectorData;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.getScorerFactoryOrDie;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScoreSupplier;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.luceneScorer;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.readNodeCorrectionConstant;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;
import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.vectorValues;
Expand Down Expand Up @@ -80,10 +82,10 @@ public class VectorScorerInt7uBulkBenchmark {

// 128k is typically enough to not fit in L1 (core) cache for most processors;
// 1.5M is typically enough to not fit in L2 (core) cache;
// 40M is typically enough to not fit in L3 cache
@Param({ "128000", "1500000", "30000000" })
// 130M is enough to not fit in L3 cache
@Param({ "128", "1500", "130000" })
public int numVectors;
public int numVectorsToScore = 20_000;
public int numVectorsToScore;

Path path;
Directory dir;
Expand All @@ -100,8 +102,12 @@ public class VectorScorerInt7uBulkBenchmark {
UpdateableRandomVectorScorer luceneDotScorer;
UpdateableRandomVectorScorer nativeDotScorer;

RandomVectorScorer luceneDotScorerQuery;
RandomVectorScorer nativeDotScorerQuery;

@Setup(Level.Trial)
public void setup() throws IOException {
numVectorsToScore = Math.min(numVectors, 20_000);
factory = getScorerFactoryOrDie();

var random = ThreadLocalRandom.current();
Expand All @@ -127,6 +133,17 @@ public void setup() throws IOException {
.orElseThrow()
.scorer();
nativeDotScorer.setScoringOrdinal(targetOrd);

if (supportsHeapSegments()) {
// setup for getInt7SQVectorScorer / query vector scoring
float[] queryVec = new float[dims];
for (int i = 0; i < dims; i++) {
queryVec[i] = random.nextFloat();
}
luceneDotScorerQuery = luceneScorer(dotProductValues, VectorSimilarityFunction.DOT_PRODUCT, queryVec);
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, dotProductValues, queryVec)
.orElseThrow();
}
}

@TearDown
Expand All @@ -151,6 +168,14 @@ public float[] dotProductLuceneMultipleRandom() throws IOException {
return scores;
}

// Scores the first numVectorsToScore entries of ordinals against the query vector, one
// score() call per ordinal, using the Lucene query scorer (luceneDotScorerQuery). Each result
// is stored in the scores field at the same position, and that array is returned.
// NOTE(review): in the visible part of setup(), luceneDotScorerQuery is only assigned when
// supportsHeapSegments() is true; presumably this method fails with a NullPointerException
// otherwise - confirm.
@Benchmark
public float[] dotProductLuceneQueryMultipleRandom() throws IOException {
for (int v = 0; v < numVectorsToScore; v++) {
scores[v] = luceneDotScorerQuery.score(ordinals[v]);
}
return scores;
}

@Benchmark
public float[] dotProductNativeMultipleSequential() throws IOException {
for (int v = 0; v < numVectorsToScore; v++) {
Expand All @@ -167,6 +192,14 @@ public float[] dotProductNativeMultipleRandom() throws IOException {
return scores;
}

// Scores the first numVectorsToScore entries of ordinals against the query vector, one
// score() call per ordinal, using the native query scorer (nativeDotScorerQuery) obtained
// from factory.getInt7SQVectorScorer(...). Each result is stored in the scores field at the
// same position, and that array is returned.
// NOTE(review): in the visible part of setup(), nativeDotScorerQuery is only assigned when
// supportsHeapSegments() is true; presumably this method fails with a NullPointerException
// otherwise - confirm.
@Benchmark
public float[] dotProductNativeQueryMultipleRandom() throws IOException {
for (int v = 0; v < numVectorsToScore; v++) {
scores[v] = nativeDotScorerQuery.score(ordinals[v]);
}
return scores;
}

@Benchmark
public float[] dotProductNativeMultipleSequentialBulk() throws IOException {
nativeDotScorer.bulkScore(ids, scores, ordinals.length);
Expand All @@ -179,6 +212,12 @@ public float[] dotProductNativeMultipleRandomBulk() throws IOException {
return scores;
}

// Scores every entry of ordinals against the query vector with a single bulkScore() call on
// the native query scorer (nativeDotScorerQuery); results are written into the scores field,
// which is returned.
// NOTE(review): the count passed is ordinals.length, while the non-bulk variants loop up to
// numVectorsToScore; presumably the two are equal - confirm against the array allocation.
// NOTE(review): in the visible part of setup(), nativeDotScorerQuery is only assigned when
// supportsHeapSegments() is true; presumably this method fails with a NullPointerException
// otherwise - confirm.
@Benchmark
public float[] dotProductNativeQueryMultipleRandomBulk() throws IOException {
nativeDotScorerQuery.bulkScore(ordinals, scores, ordinals.length);
return scores;
}

@Benchmark
public float[] dotProductScalarMultipleSequential() throws IOException {
var queryVector = dotProductValues.vectorValue(targetOrd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.Arrays;

import static org.elasticsearch.benchmark.vector.scorer.BenchmarkUtils.supportsHeapSegments;

public class VectorScorerInt7uBulkBenchmarkTests extends ESTestCase {

final float delta = 1e-3f;
Expand Down Expand Up @@ -61,6 +63,11 @@ public void testDotProductRandom() throws Exception {
assertArrayEquals(expected, bench.dotProductLuceneMultipleRandom(), delta);
assertArrayEquals(expected, bench.dotProductNativeMultipleRandom(), delta);
assertArrayEquals(expected, bench.dotProductNativeMultipleRandomBulk(), delta);
if (supportsHeapSegments()) {
assertArrayEquals(expected, bench.dotProductLuceneQueryMultipleRandom(), delta);
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandom(), delta);
assertArrayEquals(expected, bench.dotProductNativeQueryMultipleRandomBulk(), delta);
}
} finally {
bench.teardown();
}
Expand Down
5 changes: 5 additions & 0 deletions docs/changelog/138552.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 138552
summary: "[SIMD][ARM] Optimized native bulk dot product scoring for Int7"
area: Vector Search
type: enhancement
issues: []
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.17"
var vecVersion = "1.0.18"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,32 @@ public interface VectorSimilarityFunctions {
*/
MethodHandle dotProductHandle7uBulk();

/**
* Produces a method handle which computes the dot product of several byte (unsigned
* int7) vectors. This bulk operation can be used to compute the dot product between a
* single query vector and a subset of vectors from a dataset (array of vectors). Each
* vector to include in the operation is identified by an offset inside the dataset.
*
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
*
* <p> The type of the method handle will have {@code void} as return type. The type of
* its arguments will be:
* <ol>
* <li>a {@code MemorySegment} containing the vector data bytes for several vectors;
* in other words, a contiguous array of vectors</li>
* <li>a {@code MemorySegment} containing the vector data bytes for a single ("query") vector</li>
* <li>an {@code int}, representing the dimensions of each vector</li>
* <li>an {@code int}, representing the width (in bytes) of each vector. Or, in other words,
* the distance in bytes between two vectors inside the first argument's {@code MemorySegment}</li>
* <li>a {@code MemorySegment} containing the indices of the vectors inside the first argument's array
* on which we'll compute the dot product</li>
* <li>an {@code int}, representing the number of vectors for which we'll compute the dot product
* (which is equal to the size - in number of elements - of the 5th and 7th arguments)</li>
* <li>a {@code MemorySegment}, into which the computed dot product float values will be stored</li>
* </ol>
*/
MethodHandle dotProductHandle7uBulkWithOffsets();

/**
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public final class JdkVectorLibrary implements VectorLibrary {

static final MethodHandle dot7u$mh;
static final MethodHandle dot7uBulk$mh;
static final MethodHandle dot7uBulkWithOffsets$mh;
static final MethodHandle sqr7u$mh;
static final MethodHandle cosf32$mh;
static final MethodHandle dotf32$mh;
Expand All @@ -59,6 +60,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
LinkerHelperUtil.critical()
);
dot7uBulkWithOffsets$mh = downcallHandle(
"vec_dot7u_bulk_offsets_2",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
LinkerHelperUtil.critical()
);
sqr7u$mh = downcallHandle(
"vec_sqr7u_2",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
Expand Down Expand Up @@ -90,6 +96,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
LinkerHelperUtil.critical()
);
dot7uBulkWithOffsets$mh = downcallHandle(
"vec_dot7u_bulk_offsets",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
LinkerHelperUtil.critical()
);
sqr7u$mh = downcallHandle(
"vec_sqr7u",
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
Expand Down Expand Up @@ -120,6 +131,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
}
dot7u$mh = null;
dot7uBulk$mh = null;
dot7uBulkWithOffsets$mh = null;
sqr7u$mh = null;
cosf32$mh = null;
dotf32$mh = null;
Expand Down Expand Up @@ -161,6 +173,18 @@ static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int c
dot7uBulk(a, b, length, count, result);
}

/**
* Computes the dot product between a single query vector and a selection of unsigned int7
* byte vectors taken from a contiguous array of vectors. This is the static target that
* {@code DOT_HANDLE_7U_BULK_WITH_OFFSETS} is bound to; it forwards its arguments unchanged
* to the native downcall wrapper and performs no validation of its own.
*
* @param a a segment containing the vector data bytes for several vectors (the dataset)
* @param b a segment containing the vector data bytes for the single query vector
* @param length the dimensions of each vector
* @param pitch the distance in bytes between two vectors inside {@code a}
* @param offsets a segment containing the indices of the vectors inside {@code a} to score
* @param count the number of vectors to score
* @param result a segment into which the computed dot product float values are stored
*/
static void dotProduct7uBulkWithOffsets(
MemorySegment a,
MemorySegment b,
int length,
int pitch,
MemorySegment offsets,
int count,
MemorySegment result
) {
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, result);
}

/**
* Computes the square distance of given unsigned int7 byte vectors.
*
Expand Down Expand Up @@ -237,6 +261,22 @@ private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int
}
}

/**
* Invokes the native bulk-with-offsets dot product through
* {@code JdkVectorLibrary.dot7uBulkWithOffsets$mh}, passing all arguments through unchanged.
* Any {@code Throwable} raised by the invocation is rethrown wrapped in an
* {@code AssertionError}.
*/
private static void dot7uBulkWithOffsets(
MemorySegment a,
MemorySegment b,
int length,
int pitch,
MemorySegment offsets,
int count,
MemorySegment result
) {
try {
// invokeExact requires the call-site types to match the handle's type exactly:
// (MemorySegment, MemorySegment, int, int, MemorySegment, int, MemorySegment) -> void
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, pitch, offsets, count, result);
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
try {
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
Expand Down Expand Up @@ -271,6 +311,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {

static final MethodHandle DOT_HANDLE_7U;
static final MethodHandle DOT_HANDLE_7U_BULK;
static final MethodHandle DOT_HANDLE_7U_BULK_WITH_OFFSETS;
static final MethodHandle SQR_HANDLE_7U;
static final MethodHandle COS_HANDLE_FLOAT32;
static final MethodHandle DOT_HANDLE_FLOAT32;
Expand All @@ -286,6 +327,21 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);

DOT_HANDLE_7U_BULK_WITH_OFFSETS = lookup.findStatic(
JdkVectorSimilarityFunctions.class,
"dotProduct7uBulkWithOffsets",
MethodType.methodType(
void.class,
MemorySegment.class,
MemorySegment.class,
int.class,
int.class,
MemorySegment.class,
int.class,
MemorySegment.class
)
);

mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
Expand All @@ -305,6 +361,11 @@ public MethodHandle dotProductHandle7uBulk() {
return DOT_HANDLE_7U_BULK;
}

/**
* Returns the method handle looked up for the static {@code dotProduct7uBulkWithOffsets}
* method, i.e. the bulk dot product over vectors selected from a dataset.
*/
@Override
public MethodHandle dotProductHandle7uBulkWithOffsets() {
return DOT_HANDLE_7U_BULK_WITH_OFFSETS;
}

@Override
public MethodHandle squareDistanceHandle7u() {
return SQR_HANDLE_7U;
Expand Down
Loading