Skip to content
5 changes: 5 additions & 0 deletions docs/changelog/137263.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 137263
summary: Add late chunking configuration for JinaAI embedding task settings
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9222000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
index_created_transport_version,9221000
jina_ai_configurable_late_chunking,9222000
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@

import com.ibm.icu.text.BreakIterator;

import java.util.Locale;

public class ChunkerUtils {

/**
 * Counts the words in the whole of {@code text} using an ICU word
 * {@link BreakIterator} in the ROOT locale.
 */
public static int countWords(String text) {
    BreakIterator iterator = BreakIterator.getWordInstance(Locale.ROOT);
    iterator.setText(text);
    // Delegate to the range-based overload over the full string.
    return countWords(0, text.length(), iterator);
}

// setText() should be applied before using this function.
static int countWords(int start, int end, BreakIterator wordIterator) {
assert start < end;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,26 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener<Inferen
private ActionListener<List<ChunkedInference>> finalListener;

/**
 * Creates a chunker with no explicit chunking settings; chunks are batched
 * across inputs (the historical default).
 */
public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch) {
    // Fix: a stale pre-change delegation line left two this(...) calls in the
    // constructor, which is invalid Java; keep only the current 4-arg form.
    this(inputs, maxNumberOfInputsPerBatch, true, null);
}

/**
 * Creates a chunker using word-boundary chunking with the given window and
 * overlap; chunks are batched across inputs.
 */
public EmbeddingRequestChunker(List<ChunkInferenceInput> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) {
    // Fix: a stale pre-change delegation line left two this(...) calls in the
    // constructor, which is invalid Java; keep only the current 4-arg form.
    this(inputs, maxNumberOfInputsPerBatch, true, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap));
}

// Convenience overload: defaults batchChunksAcrossInputs to true (the
// historical behavior) and forwards to the 4-arg constructor.
public EmbeddingRequestChunker(
List<ChunkInferenceInput> inputs,
int maxNumberOfInputsPerBatch,
ChunkingSettings defaultChunkingSettings
) {
this(inputs, maxNumberOfInputsPerBatch, true, defaultChunkingSettings);
}

public EmbeddingRequestChunker(
List<ChunkInferenceInput> inputs,
int maxNumberOfInputsPerBatch,
boolean batchChunksAcrossInputs,
ChunkingSettings defaultChunkingSettings
) {
this.resultEmbeddings = new ArrayList<>(inputs.size());
this.resultOffsetStarts = new ArrayList<>(inputs.size());
Expand Down Expand Up @@ -147,13 +156,23 @@ public EmbeddingRequestChunker(
}
}

AtomicInteger counter = new AtomicInteger();
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
.values()
.stream()
.map(BatchRequest::new)
.toList();
if (batchChunksAcrossInputs) {
AtomicInteger counter = new AtomicInteger();
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(it -> counter.getAndIncrement() / maxNumberOfInputsPerBatch))
.values()
.stream()
.map(BatchRequest::new)
.toList();
} else {
assert (maxNumberOfInputsPerBatch >= MAX_CHUNKS);
this.batchRequests = allRequests.stream()
.collect(Collectors.groupingBy(Request::inputIndex))
.values()
.stream()
.map(BatchRequest::new)
.toList();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.test.ESTestCase;

import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.xpack.core.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT;

Expand Down Expand Up @@ -85,6 +87,12 @@ public void testCountWords_WithSymbols() {
}
}

// Builds a sentence of N distinct words ("word0 word1 ... wordN-1.") and
// checks that countWords reports exactly N.
public void testCountWords_GivenStringCountsAllWords() {
    int expectedWordCount = randomIntBetween(1, 100);
    StringBuilder text = new StringBuilder();
    for (int i = 0; i < expectedWordCount; i++) {
        if (i > 0) {
            text.append(' ');
        }
        text.append("word").append(i);
    }
    text.append('.');
    assertEquals(expectedWordCount, ChunkerUtils.countWords(text.toString()));
}

private int[] sentenceSizes(String text) {
var sentences = text.split("\\.\\s+");
var lengths = new int[sentences.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.inference.InferenceString.DataType.TEXT;
import static org.elasticsearch.inference.InferenceString.toStringList;
Expand All @@ -38,6 +41,8 @@

public class EmbeddingRequestChunkerTests extends ESTestCase {

private static final int MAX_BATCH_SIZE = 512;

public void testEmptyInput_WordChunker() {
var batches = new EmbeddingRequestChunker<>(List.of(), 100, 100, 10).batchRequestsWithListeners(testListener());
assertThat(batches, empty());
Expand Down Expand Up @@ -943,6 +948,78 @@ public void testMergingListener_Sparse() {
}
}

// With cross-input batching disabled, each input's chunks must land in their
// own batch (one batch per input).
public void testBatchChunksAcrossInputsIsFalse_DoesNotBatchChunksFromSeparateInputs() {
testBatchChunksAcrossInputs(false, List.of(3, 1, 4));
}

// With cross-input batching enabled, chunks from all inputs fit into a single
// batch because their total (3 + 1 + 4) is below the max batch size.
public void testBatchChunksAcrossInputsIsTrue_DoesBatchChunksFromSeparateInputs() {
testBatchChunksAcrossInputs(true, List.of(3, 1, 4));
}

// 600 total chunks exceed MAX_BATCH_SIZE (512), so cross-input batching must
// spill into a second batch.
public void testBatchChunksAcrossInputsIsTrue_GeneratesMultipleBatches() {
testBatchChunksAcrossInputs(true, List.of(200, 200, 200));
}

// The non-batching constructor path asserts maxNumberOfInputsPerBatch >= MAX_CHUNKS,
// so any smaller batch size must trip an AssertionError (tests run with -ea).
// NOTE(review): assumes MAX_CHUNKS equals MAX_BATCH_SIZE (512) — confirm against
// EmbeddingRequestChunker.
public void testBatchChunksAcrossInputsIsFalseAndBatchesLessThanMaxChunkLimit_ThrowsAssertionError() {
int batchSize = randomIntBetween(1, MAX_BATCH_SIZE - 1);
List<ChunkInferenceInput> inputs = List.of(new ChunkInferenceInput("This is a test sentence with ten words in total. "));
var chunkingSettings = new SentenceBoundaryChunkingSettings(10, 0);
expectThrows(
AssertionError.class,
() -> new EmbeddingRequestChunker<>(inputs, batchSize, false, chunkingSettings).batchRequestsWithListeners(testListener())
);
}

/**
 * Shared driver: builds one input per entry of {@code batchSizes}, each containing
 * that many single-chunk sentences, runs the chunker with the given
 * {@code batchChunksAcrossInputs} flag, and verifies batch count and sizes.
 * Responds to every batch so the final listener completes with one result per input.
 */
private void testBatchChunksAcrossInputs(boolean batchChunksAcrossInputs, List<Integer> batchSizes) {
    int maxChunkSize = 10;
    // A sentence of exactly maxChunkSize words, so each sentence becomes one chunk.
    var testSentence = IntStream.range(0, maxChunkSize).mapToObj(i -> "Word" + i).collect(Collectors.joining(" ")) + ".";
    var chunkingSettings = new SentenceBoundaryChunkingSettings(maxChunkSize, 0);
    var totalChunks = batchSizes.stream().mapToInt(Integer::intValue).sum();
    List<ChunkInferenceInput> inputs = batchSizes.stream()
        .map(i -> new ChunkInferenceInput(String.join(" ", Collections.nCopies(i, testSentence))))
        .toList();

    var finalListener = testListener();
    List<EmbeddingRequestChunker.BatchRequestAndListener> batches = new EmbeddingRequestChunker<>(
        inputs,
        MAX_BATCH_SIZE,
        batchChunksAcrossInputs,
        chunkingSettings
    ).batchRequestsWithListeners(finalListener);

    // If we are batching chunks across inputs, we expect the batches to be filled up to the max batch size.
    // Otherwise, we expect one batch per input.
    int expectedNumberOfBatches = batchChunksAcrossInputs ? (int) Math.ceil((double) totalChunks / MAX_BATCH_SIZE) : inputs.size();
    assertThat(batches, hasSize(expectedNumberOfBatches));

    // The two modes only differ in the expected size of each batch; the response
    // handling is identical, so use a single loop instead of duplicated branches.
    for (int i = 0; i < batches.size(); i++) {
        int expectedBatchSize;
        if (batchChunksAcrossInputs) {
            // Every batch except the last is full.
            expectedBatchSize = i < batches.size() - 1 ? MAX_BATCH_SIZE : totalChunks - (MAX_BATCH_SIZE * (batches.size() - 1));
        } else {
            expectedBatchSize = batchSizes.get(i);
        }
        assertThat(batches.get(i).batch().inputs().get(), hasSize(expectedBatchSize));
        batches.get(i)
            .listener()
            .onResponse(
                new DenseEmbeddingFloatResults(
                    List.of(new DenseEmbeddingFloatResults.Embedding(new float[] { randomFloatBetween(0, 1, true) }))
                )
            );
    }

    assertNotNull(finalListener.results);
    // One merged result per input (was hard-coded to 3, the length of every
    // caller's list — derive it instead so new callers can vary the input count).
    assertThat(finalListener.results, hasSize(batchSizes.size()));
}

public void testListenerErrorsWithWrongNumberOfResponses() {
List<ChunkInferenceInput> inputs = List.of(
new ChunkInferenceInput("1st small"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingType;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
Expand Down Expand Up @@ -276,9 +277,16 @@ protected void doChunkedInfer(
JinaAIModel jinaaiModel = (JinaAIModel) model;
var actionCreator = new JinaAIActionCreator(getSender(), getServiceComponents());

boolean batchChunksAcrossInputs = true;
if (jinaaiModel.getTaskSettings() instanceof JinaAIEmbeddingsTaskSettings jinaAIEmbeddingsTaskSettings) {
batchChunksAcrossInputs = jinaAIEmbeddingsTaskSettings.getLateChunking() == null
|| jinaAIEmbeddingsTaskSettings.getLateChunking() == false;
}

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs,
EMBEDDING_MAX_BATCH_SIZE,
batchChunksAcrossInputs,
jinaaiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES;

Expand All @@ -36,6 +37,11 @@ public class JinaAIEmbeddingsTaskSettings implements TaskSettings {
public static final String NAME = "jinaai_embeddings_task_settings";
public static final JinaAIEmbeddingsTaskSettings EMPTY_SETTINGS = new JinaAIEmbeddingsTaskSettings((InputType) null);
static final String INPUT_TYPE = "input_type";
static final String LATE_CHUNKING = "late_chunking";

protected static final TransportVersion JINA_AI_CONFIGURABLE_LATE_CHUNKING = TransportVersion.fromName(
"jina_ai_configurable_late_chunking"
);

public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
if (map == null || map.isEmpty()) {
Expand All @@ -53,11 +59,13 @@ public static JinaAIEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
validationException
);

Boolean lateChunking = extractOptionalBoolean(map, LATE_CHUNKING, validationException);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new JinaAIEmbeddingsTaskSettings(inputType);
return new JinaAIEmbeddingsTaskSettings(inputType, lateChunking);
}

/**
Expand All @@ -76,8 +84,9 @@ public static JinaAIEmbeddingsTaskSettings of(
JinaAIEmbeddingsTaskSettings requestTaskSettings
) {
var inputTypeToUse = getValidInputType(originalSettings, requestTaskSettings);
var lateChunkingToUse = requestTaskSettings.lateChunking != null ? requestTaskSettings.lateChunking : originalSettings.lateChunking;

return new JinaAIEmbeddingsTaskSettings(inputTypeToUse);
return new JinaAIEmbeddingsTaskSettings(inputTypeToUse, lateChunkingToUse);
}

private static InputType getValidInputType(
Expand All @@ -94,14 +103,28 @@ private static InputType getValidInputType(
}

private final InputType inputType;
private final Boolean lateChunking;

/**
 * Reads task settings from the wire.
 *
 * @throws IOException on stream read failure
 */
public JinaAIEmbeddingsTaskSettings(StreamInput in) throws IOException {
    // Fix: a stale pre-change line `this(in.readOptionalEnum(InputType.class));`
    // remained before these statements — an illegal delegating-constructor call
    // that would also have read the enum from the stream twice. Removed.
    this.inputType = in.readOptionalEnum(InputType.class);

    // lateChunking only exists on the wire from the
    // jina_ai_configurable_late_chunking transport version onward; older nodes
    // never send it, so default to null (unset).
    if (in.getTransportVersion().supports(JINA_AI_CONFIGURABLE_LATE_CHUNKING)) {
        this.lateChunking = in.readOptionalBoolean();
    } else {
        this.lateChunking = null;
    }
}

/**
 * Creates task settings with an optional input type and an optional late
 * chunking flag; {@code null} means "not configured" for both.
 */
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType, @Nullable Boolean lateChunking) {
    // lateChunking is nullable by design (see isEmpty/of); annotate it like
    // inputType for consistency.
    validateInputType(inputType);
    this.inputType = inputType;
    this.lateChunking = lateChunking;
}

/**
 * Creates task settings with only an input type; late chunking is left unset.
 */
public JinaAIEmbeddingsTaskSettings(@Nullable InputType inputType) {
    // Delegate to the canonical constructor instead of duplicating the
    // validation and field assignments.
    this(inputType, null);
}

private static void validateInputType(InputType inputType) {
Expand All @@ -114,7 +137,7 @@ private static void validateInputType(InputType inputType) {

@Override
public boolean isEmpty() {
    // Fix: the stale pre-change `return inputType == null;` line remained above
    // this one, making it unreachable (a compile error). Settings are empty only
    // when neither optional field was configured.
    return inputType == null && lateChunking == null;
}

@Override
Expand All @@ -124,6 +147,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_TYPE, inputType);
}

if (lateChunking != null) {
builder.field(LATE_CHUNKING, lateChunking);
}

builder.endObject();
return builder;
}
Expand All @@ -132,6 +159,10 @@ public InputType getInputType() {
return inputType;
}

/**
 * Returns the configured late chunking flag, or {@code null} when it was not
 * set by the user.
 */
public Boolean getLateChunking() {
return lateChunking;
}

@Override
public String getWriteableName() {
return NAME;
Expand All @@ -145,19 +176,23 @@ public TransportVersion getMinimalSupportedVersion() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(inputType);

// Only write lateChunking to nodes on a transport version that knows the
// field — writing it to an older node would corrupt the stream. On older
// versions the value is silently dropped.
if (out.getTransportVersion().supports(JINA_AI_CONFIGURABLE_LATE_CHUNKING)) {
out.writeOptionalBoolean(lateChunking);
}
}

@Override
public boolean equals(Object o) {
    // Fix: the stale pre-change `return Objects.equals(inputType, that.inputType);`
    // line remained above the new return, making it unreachable (a compile error).
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    JinaAIEmbeddingsTaskSettings that = (JinaAIEmbeddingsTaskSettings) o;
    return Objects.equals(inputType, that.inputType) && Objects.equals(lateChunking, that.lateChunking);
}

@Override
public int hashCode() {
    // Fix: the stale pre-change `return Objects.hash(inputType);` line remained
    // above the new return, making it unreachable (a compile error). Hash must
    // include every field compared in equals().
    return Objects.hash(inputType, lateChunking);
}

@Override
Expand Down
Loading