Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/139463.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 139463
summary: Add bfloat16 support to `rank_vectors`
area: Vector Search
type: feature
issues: []
25 changes: 24 additions & 1 deletion docs/reference/elasticsearch/mapping-reference/rank-vectors.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,27 @@ PUT my-rank-vectors-float/_doc/1
```
% TESTSETUP

In addition to the `float` element type, `byte` and `bit` element types are also supported.
In addition to the `float` element type, `bfloat16`, `byte`, and `bit` element types are also supported.

Here is an example of using this field with `bfloat16` elements.
```console
PUT my-rank-vectors-bfloat16
{
"mappings": {
"properties": {
"my_vector": {
"type": "rank_vectors",
"element_type": "bfloat16"
}
}
}
}

PUT my-rank-vectors-bfloat16/_doc/1
{
"my_vector" : [[0.5, 10, 6], [-0.5, 10, 10]]
}
```

Here is an example of using this field with `byte` elements.

Expand Down Expand Up @@ -92,6 +112,9 @@ $$$rank-vectors-element-type$$$
`float`
: indexes a 4-byte floating-point value per dimension. This is the default value.

`bfloat16` {applies_to}`stack: ga 9.3`
: indexes a 2-byte floating-point value per dimension.

`byte`
: indexes a 1-byte integer value per dimension.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.script.field.vectors;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.codec.vectors.BFloat16;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.ShortBuffer;
import java.util.Iterator;

public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField {

private final BinaryDocValues input;
private final BinaryDocValues magnitudes;
private boolean decoded;
private final int dims;
private BytesRef value;
private BytesRef magnitudesValue;
private BFloat16VectorIterator vectorValues;
private int numVectors;
private float[] buffer;

public BFloat16RankVectorsDocValuesField(
BinaryDocValues input,
BinaryDocValues magnitudes,
String name,
DenseVectorFieldMapper.ElementType elementType,
int dims
) {
super(name, elementType);
this.input = input;
this.magnitudes = magnitudes;
this.dims = dims;
this.buffer = new float[dims];
}

@Override
public void setNextDocId(int docId) throws IOException {
decoded = false;
if (input.advanceExact(docId)) {
boolean magnitudesFound = magnitudes.advanceExact(docId);
assert magnitudesFound;

value = input.binaryValue();
assert value.length % (BFloat16.BYTES * dims) == 0;
numVectors = value.length / (BFloat16.BYTES * dims);
magnitudesValue = magnitudes.binaryValue();
assert magnitudesValue.length == (Float.BYTES * numVectors);
} else {
value = null;
magnitudesValue = null;
numVectors = 0;
}
}

@Override
public RankVectorsScriptDocValues toScriptDocValues() {
return new RankVectorsScriptDocValues(this, dims);
}

@Override
public boolean isEmpty() {
return value == null;
}

@Override
public RankVectors get() {
if (isEmpty()) {
return RankVectors.EMPTY;
}
decodeVectorIfNecessary();
return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims);
}

@Override
public RankVectors get(RankVectors defaultValue) {
if (isEmpty()) {
return defaultValue;
}
decodeVectorIfNecessary();
return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims);
}

@Override
public RankVectors getInternal() {
return get(null);
}

@Override
public int size() {
return value == null ? 0 : value.length / (BFloat16.BYTES * dims);
}

private void decodeVectorIfNecessary() {
if (decoded == false && value != null) {
vectorValues = new BFloat16VectorIterator(value, buffer, numVectors);
decoded = true;
}
}

public static class BFloat16VectorIterator implements VectorIterator<float[]> {
private final float[] buffer;
private final ShortBuffer vectorValues;
private final BytesRef vectorValueBytesRef;
private final int size;
private int idx = 0;

public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) {
assert vectorValues.length == (buffer.length * BFloat16.BYTES * size);
this.vectorValueBytesRef = vectorValues;
this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length)
.order(ByteOrder.LITTLE_ENDIAN)
.asShortBuffer();
this.size = size;
this.buffer = buffer;
}

@Override
public boolean hasNext() {
return idx < size;
}

@Override
public float[] next() {
if (hasNext() == false) {
throw new IllegalArgumentException("No more elements in the iterator");
}
BFloat16.bFloat16ToFloat(vectorValues, buffer);
idx++;
return buffer;
}

@Override
public Iterator<float[]> copy() {
return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size);
}

@Override
public void reset() {
idx = 0;
vectorValues.rewind();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.index.codec.vectors.BFloat16;
import org.elasticsearch.index.fielddata.FormattedDocValues;
import org.elasticsearch.index.fielddata.LeafFieldData;
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField;
import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField;
import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField;
import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField;
Expand Down Expand Up @@ -128,7 +130,46 @@ public Object nextValue() {
return vectors;
}
};
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
case BFLOAT16 -> new FormattedDocValues() {
private final float[] vector = new float[dims];
private BytesRef ref = null;
private int numVecs = -1;
private final BinaryDocValues binary;
{
try {
binary = DocValues.getBinary(reader, field);
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values", e);
}
}

@Override
public boolean advanceExact(int docId) throws IOException {
if (binary == null || binary.advanceExact(docId) == false) {
return false;
}
ref = binary.binaryValue();
assert ref.length % (BFloat16.BYTES * dims) == 0;
numVecs = ref.length / (BFloat16.BYTES * dims);
return true;
}

@Override
public int docValueCount() {
return 1;
}

@Override
public Object nextValue() {
List<float[]> vectors = new ArrayList<>(numVecs);
VectorIterator<float[]> iterator = new BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs);
while (iterator.hasNext()) {
float[] v = iterator.next();
vectors.add(Arrays.copyOf(v, v.length));
}
return vectors;
}
};
};
}

Expand All @@ -140,8 +181,8 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
return switch (elementType) {
case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
};
} catch (IOException e) {
throw new IllegalStateException("Cannot load doc values for multi-vector field!", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.codec.vectors.BFloat16;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
Expand Down Expand Up @@ -77,9 +78,6 @@ public static class Builder extends FieldMapper.Builder {
"invalid element_type [" + o + "]; available types are " + namesToElementType.keySet()
);
}
if (elementType == ElementType.BFLOAT16) {
throw new MapperParsingException("Rank vectors does not support bfloat16");
}
return elementType;
},
m -> toType(m).fieldType().element.elementType(),
Expand Down Expand Up @@ -497,6 +495,13 @@ private List<List<?>> copyVectorsAsList() throws IOException {
}
vectors.add(vec);
}
case BFLOAT16 -> {
List<Float> vec = new ArrayList<>(dims);
for (int dim = 0; dim < dims; dim++) {
vec.add(BFloat16.bFloat16ToFloat(byteBuffer.getShort()));
}
vectors.add(vec);
}
case BYTE, BIT -> {
List<Byte> vec = new ArrayList<>(dims);
for (int dim = 0; dim < dims; dim++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ public static final class MaxSimInvHamming {

public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName);
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT
|| field.getElementType() == DenseVectorFieldMapper.ElementType.BFLOAT16) {
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
}
BytesOrList bytesOrList = parseBytes(queryVector);
Expand Down Expand Up @@ -351,13 +352,12 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list);
}
}
case FLOAT -> {
case FLOAT, BFLOAT16 -> {
if (queryVector instanceof List) {
yield new MaxSimFloatDotProduct(scoreScript, field, (List<List<Number>>) queryVector);
}
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
}
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.stream.Stream;

import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomNormalizedVector;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToBFloat16List;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToList;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -61,9 +62,12 @@ public class RankVectorsFieldMapperTests extends SyntheticVectorsMapperTestCase
private final int dims;

public RankVectorsFieldMapperTests() {
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT);
int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1;
int randomMultiplier = switch (elementType) {
case FLOAT, BFLOAT16 -> randomIntBetween(1, 64);
case BYTE, BIT -> 1;
};
this.dims = baseDims * randomMultiplier;
}

Expand Down Expand Up @@ -97,11 +101,12 @@ protected Object getSampleValueForDocument(boolean binaryFormat) {
@Override
protected Object getSampleValueForDocument() {
int numVectors = randomIntBetween(1, 16);
return Stream.generate(
() -> elementType == ElementType.FLOAT
? convertToList(randomNormalizedVector(this.dims))
: convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? this.dims / Byte.SIZE : dims))
).limit(numVectors).toList();
return Stream.generate(switch (elementType) {
case FLOAT -> () -> convertToList(randomNormalizedVector(this.dims));
case BFLOAT16 -> () -> convertToBFloat16List(randomNormalizedVector(this.dims));
case BYTE -> () -> convertToList(randomByteArrayOfLength(dims));
case BIT -> () -> convertToList(randomByteArrayOfLength(dims / Byte.SIZE));
}).limit(numVectors).toList();
}

@Override
Expand All @@ -119,6 +124,21 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
checker.registerConflictCheck(
"element_type",
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "float")),
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16"))
);
checker.registerConflictCheck(
"element_type",
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "byte")),
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16"))
);
checker.registerConflictCheck(
"element_type",
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "float")),
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims * 8).field("element_type", "bit"))
);
checker.registerConflictCheck(
"element_type",
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16")),
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims * 8).field("element_type", "bit"))
);
checker.registerConflictCheck(
Expand Down
Loading
Loading