Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1319a40
Updated TestDenseInferenceServiceExtension to support bit embeddings
Mikep86 Feb 20, 2025
d0378c1
Updated semantic query builder tests to support bit embeddings
Mikep86 Feb 20, 2025
04719e2
Fixed how TestDenseInferenceServiceExtension generates bit embeddings
Mikep86 Feb 20, 2025
716200d
Update ShardBulkInferenceActionFilterIT to test using bit embeddings
Mikep86 Feb 20, 2025
48b5d0c
Updated SemanticTextFieldTests to generate well-formed bit embeddings
Mikep86 Feb 20, 2025
2bfae5e
Added helper/util methods
Mikep86 Feb 21, 2025
aefcefd
Added more helper/util methods
Mikep86 Feb 21, 2025
6cb4345
Spotless
Mikep86 Feb 21, 2025
e4c3921
Use helper/util methods
Mikep86 Feb 21, 2025
90fc4fc
Updated ShardBulkInferenceActionFilterTests to test using all element…
Mikep86 Feb 21, 2025
4aa6fa6
Updated SemanticInferenceMetadataFieldsRecoveryTests to test all elem…
Mikep86 Feb 21, 2025
1296b4d
Merge branch 'main' into semantic-text_bit-vector-tests
Mikep86 Feb 21, 2025
399faba
Update docs/changelog/123187.yaml
Mikep86 Feb 21, 2025
ab83cf7
Added YAML test
Mikep86 Feb 24, 2025
ca711fb
Updated changelog
Mikep86 Feb 24, 2025
a438ad8
Merge branch 'main' into semantic-text_bit-vector-tests
elasticmachine Feb 24, 2025
0815989
Merge branch 'main' into semantic-text_bit-vector-tests
Mikep86 Feb 25, 2025
dc3e387
Refactored to remove getEmbeddingLength and getDimensions methods
Mikep86 Feb 26, 2025
214c663
Remove TODO
Mikep86 Feb 26, 2025
693e1e6
Merge branch 'main' into semantic-text_bit-vector-tests
elasticmachine Feb 27, 2025
7fa409c
Merge branch 'main' into semantic-text_bit-vector-tests
elasticmachine Mar 3, 2025
e85ea6e
Added comment
Mikep86 Mar 6, 2025
3e94ab2
Merge branch 'main' into semantic-text_bit-vector-tests
Mikep86 Mar 6, 2025
9441a74
Updated SemanticInferenceMetadataFieldsRecoveryTests to not use cosin…
Mikep86 Mar 6, 2025
b24835a
Merge branch 'main' into semantic-text_bit-vector-tests
Mikep86 Mar 7, 2025
1614fc9
Don't use dot product
Mikep86 Mar 7, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/123187.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 123187
summary: Add bit vector support to semantic text
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.mapper.vectors;

import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;

import org.elasticsearch.inference.SimilarityMeasure;

import java.util.List;
import java.util.Random;

public class DenseVectorFieldMapperTestUtils {
private DenseVectorFieldMapperTestUtils() {}

public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
return switch (elementType) {
case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
case BIT -> List.of(SimilarityMeasure.L2_NORM);
};
}

public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
return switch (elementType) {
case FLOAT, BYTE -> dimensions;
case BIT -> {
assert dimensions % Byte.SIZE == 0;
yield dimensions / Byte.SIZE;
}
};
}

public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType elementType, int max) {
if (max < 1) {
throw new IllegalArgumentException("max must be at least 1");
}

return switch (elementType) {
case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
case BIT -> {
if (max < 8) {
throw new IllegalArgumentException("max must be at least 8 for bit vectors");
}

// Generate a random dimension count that is a multiple of 8
int maxEmbeddingLength = max / 8;
yield RandomNumbers.randomIntBetween(random(), 1, maxEmbeddingLength) * 8;
}
};
}

private static Random random() {
return RandomizedContext.current().getRandom();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void infer(
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeResults(input, modelServiceSettings.dimensions()));
listener.onResponse(makeResults(input, modelServiceSettings));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
Expand Down Expand Up @@ -153,7 +153,7 @@ public void chunkedInfer(
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions()));
listener.onResponse(makeChunkedResults(input, modelServiceSettings));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
Expand All @@ -164,17 +164,17 @@ public void chunkedInfer(
}
}

private TextEmbeddingFloatResults makeResults(List<String> input, int dimensions) {
private TextEmbeddingFloatResults makeResults(List<String> input, ServiceSettings serviceSettings) {
List<TextEmbeddingFloatResults.Embedding> embeddings = new ArrayList<>();
for (String inputString : input) {
List<Float> floatEmbeddings = generateEmbedding(inputString, dimensions);
List<Float> floatEmbeddings = generateEmbedding(inputString, serviceSettings.dimensions(), serviceSettings.elementType());
embeddings.add(TextEmbeddingFloatResults.Embedding.of(floatEmbeddings));
}
return new TextEmbeddingFloatResults(embeddings);
}

private List<ChunkedInference> makeChunkedResults(List<String> input, int dimensions) {
TextEmbeddingFloatResults nonChunkedResults = makeResults(input, dimensions);
private List<ChunkedInference> makeChunkedResults(List<String> input, ServiceSettings serviceSettings) {
TextEmbeddingFloatResults nonChunkedResults = makeResults(input, serviceSettings);

var results = new ArrayList<ChunkedInference>();
for (int i = 0; i < input.size(); i++) {
Expand Down Expand Up @@ -204,7 +204,7 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* <ul>
* <li>Unique to the input</li>
* <li>Reproducible (i.e given the same input, the same embedding should be generated)</li>
* <li>Valid as both a float and byte embedding</li>
* <li>Valid for the provided element type</li>
* </ul>
* <p>
* The embedding is generated by:
Expand All @@ -219,32 +219,48 @@ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceS
* Since the hash code value, when interpreted as a string, is guaranteed to only contain digits and the "-" character, the UTF-8
* encoded byte array is guaranteed to only contain values in the standard ASCII table.
* </p>
* <p>
* If a bit embedding is required, the embedding length is 1/8 the dimension count because eight dimensions are encoded into each
* embedding byte.
* </p>
*
* @param input The input string
* @param dimensions The embedding dimension count
* @return An embedding
*/
private static List<Float> generateEmbedding(String input, int dimensions) {
List<Float> embedding = new ArrayList<>(dimensions);
private static List<Float> generateEmbedding(String input, int dimensions, DenseVectorFieldMapper.ElementType elementType) {
int embeddingLength = getEmbeddingLength(elementType, dimensions);
List<Float> embedding = new ArrayList<>(embeddingLength);

byte[] byteArray = Integer.toString(input.hashCode()).getBytes(StandardCharsets.UTF_8);
List<Float> embeddingValues = new ArrayList<>(byteArray.length);
for (byte value : byteArray) {
embeddingValues.add((float) value);
}

int remainingDimensions = dimensions;
while (remainingDimensions >= embeddingValues.size()) {
int remainingLength = embeddingLength;
while (remainingLength >= embeddingValues.size()) {
embedding.addAll(embeddingValues);
remainingDimensions -= embeddingValues.size();
remainingLength -= embeddingValues.size();
}
if (remainingDimensions > 0) {
embedding.addAll(embeddingValues.subList(0, remainingDimensions));
if (remainingLength > 0) {
embedding.addAll(embeddingValues.subList(0, remainingLength));
}

return embedding;
}

// Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions
private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
return switch (elementType) {
case FLOAT, BYTE -> dimensions;
case BIT -> {
assert dimensions % Byte.SIZE == 0;
yield dimensions / Byte.SIZE;
}
};
}

public static class Configuration {
public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
Expand Down Expand Up @@ -282,12 +298,6 @@ public record TestServiceSettings(

static final String NAME = "test_text_embedding_service_settings";

public TestServiceSettings {
if (elementType == DenseVectorFieldMapper.ElementType.BIT) {
throw new IllegalArgumentException("Test dense inference service does not yet support element type BIT");
}
}

public static TestServiceSettings fromMap(Map<String, Object> map) {
ValidationException validationException = new ValidationException();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.license.LicenseSettings;
import org.elasticsearch.plugins.Plugin;
Expand All @@ -35,7 +36,6 @@
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.junit.Before;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
Expand Down Expand Up @@ -71,15 +71,16 @@ public static Iterable<Object[]> parameters() throws Exception {

@Before
public void setup() throws Exception {
Utils.storeSparseModel(client());
Utils.storeDenseModel(
client(),
randomIntBetween(1, 100),
// dot product means that we need normalized vectors; it's not worth doing that in this test
randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values())),
// TODO: Allow element type BIT once TestDenseInferenceServiceExtension supports it
randomValueOtherThan(DenseVectorFieldMapper.ElementType.BIT, () -> randomFrom(DenseVectorFieldMapper.ElementType.values()))
DenseVectorFieldMapper.ElementType elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
// dot product means that we need normalized vectors; it's not worth doing that in this test
SimilarityMeasure similarity = randomValueOtherThan(
SimilarityMeasure.DOT_PRODUCT,
() -> randomFrom(DenseVectorFieldMapperTestUtils.getSupportedSimilarities(elementType))
);
int dimensions = DenseVectorFieldMapperTestUtils.randomCompatibleDimensions(elementType, 100);

Utils.storeSparseModel(client());
Utils.storeDenseModel(client(), dimensions, similarity, elementType);
}

@Override
Expand All @@ -89,7 +90,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateInferencePlugin.class);
return List.of(LocalStateInferencePlugin.class);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public Set<NodeFeature> getTestFeatures() {
SemanticInferenceMetadataFieldsMapper.INFERENCE_METADATA_FIELDS_ENABLED_BY_DEFAULT,
SEMANTIC_TEXT_HIGHLIGHTER_DEFAULT,
SEMANTIC_KNN_FILTER_FIX,
TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE
TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE,
SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
"semantic_text.always_emit_inference_id_fix"
);
public static final NodeFeature SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS = new NodeFeature("semantic_text.skip_inference_fields");
public static final NodeFeature SEMANTIC_TEXT_BIT_VECTOR_SUPPORT = new NodeFeature("semantic_text.bit_vector_support");

public static final String CONTENT_TYPE = "semantic_text";
public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
Expand Down Expand Up @@ -709,12 +710,12 @@ yield new SparseVectorQueryBuilder(

MlTextEmbeddingResults textEmbeddingResults = (MlTextEmbeddingResults) inferenceResults;
float[] inference = textEmbeddingResults.getInferenceAsFloat();
var inferenceLength = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
? inference.length * Byte.SIZE
int dimensions = modelSettings.elementType() == DenseVectorFieldMapper.ElementType.BIT
? inference.length * Byte.SIZE // Bit vectors encode 8 dimensions into each byte value
: inference.length;
if (inferenceLength != modelSettings.dimensions()) {
if (dimensions != modelSettings.dimensions()) {
throw new IllegalArgumentException(
generateDimensionCountMismatchMessage(inferenceLength, modelSettings.dimensions())
generateDimensionCountMismatchMessage(dimensions, modelSettings.dimensions())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.MockLicenseState;
Expand Down Expand Up @@ -650,7 +648,7 @@ private static class StaticModel extends TestModel {
}

public static StaticModel createRandomInstance() {
TestModel testModel = randomModel(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING));
TestModel testModel = TestModel.createRandomInstance();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test cleanup here 👍

return new StaticModel(
testModel.getInferenceEntityId(),
testModel.getTaskType(),
Expand All @@ -673,18 +671,4 @@ boolean hasResult(String text) {
return resultMap.containsKey(text);
}
}

private static TestModel randomModel(TaskType taskType) {
var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomIntBetween(2, 64) : null;
var similarity = taskType == TaskType.TEXT_EMBEDDING ? randomFrom(SimilarityMeasure.values()) : null;
var elementType = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapper.ElementType.FLOAT : null;
return new TestModel(
randomAlphaOfLength(4),
taskType,
randomAlphaOfLength(10),
new TestModel.TestServiceSettings(randomAlphaOfLength(4), dimensions, similarity, elementType),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings(randomAlphaOfLength(4))
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.SourceFieldMapper;
import org.elasticsearch.index.mapper.SourceToParse;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.translog.Translog;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.Model;
Expand All @@ -44,6 +43,7 @@

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingFloat;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
import static org.hamcrest.Matchers.equalTo;
Expand All @@ -55,8 +55,8 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase
private final boolean useIncludesExcludes;

public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) {
this.model1 = randomModel(TaskType.TEXT_EMBEDDING);
this.model2 = randomModel(TaskType.SPARSE_EMBEDDING);
this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT));
this.model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
this.useSynthetic = useSynthetic;
this.useIncludesExcludes = useIncludesExcludes;
}
Expand Down Expand Up @@ -218,22 +218,6 @@ private Translog.Snapshot newRandomSnapshot(
}
}

private static Model randomModel(TaskType taskType) {
var dimensions = taskType == TaskType.TEXT_EMBEDDING ? randomIntBetween(2, 64) : null;
var similarity = taskType == TaskType.TEXT_EMBEDDING
? randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values()))
: null;
var elementType = taskType == TaskType.TEXT_EMBEDDING ? DenseVectorFieldMapper.ElementType.BYTE : null;
return new TestModel(
randomAlphaOfLength(4),
taskType,
randomAlphaOfLength(10),
new TestModel.TestServiceSettings(randomAlphaOfLength(4), dimensions, similarity, elementType),
new TestModel.TestTaskSettings(randomInt(3)),
new TestModel.TestSecretSettings(randomAlphaOfLength(4))
);
}

private BytesReference randomSource() throws IOException {
var builder = JsonXContent.contentBuilder().startObject();
builder.field("field", randomAlphaOfLengthBetween(10, 30));
Expand Down Expand Up @@ -261,8 +245,8 @@ private static SemanticTextField randomSemanticText(
) throws IOException {
ChunkedInference results = switch (model.getTaskType()) {
case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) {
case BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs);
default -> throw new AssertionError("invalid element type: " + model.getServiceSettings().elementType().name());
case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs);
case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs);
};
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false);
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
Expand Down
Loading