Add late chunking configuration for JinaAI embedding task settings #137263

dan-rubinstein merged 7 commits into elastic:main
Conversation
Pinging @elastic/ml-core (Team:ML)

Hi @dan-rubinstein, I've created a changelog YAML for you.
jonathan-buttner left a comment:

Looks good, just a few comments.
```java
int expectedNumberOfBatches = batchChunksAcrossInputs ? 1 : 3;
```
Could you leave a comment in the code as to why it will either be 1 or 3 here?
Sure, I'll add the following comment:

"There are 3 inputs that generate 8 chunks. If we are allowing batching of chunks across inputs, they will be placed into 1 batch. Otherwise, they will be split into 3 batches (1 per input)."
It might be even clearer to use inputs.size() instead of 3, so that it's obvious where the value is coming from.
Not strictly related to your changes, but how about we fix this since we're touching the class? Basically we need to do something like this: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java#L44-L58
Good catch, I'll update this while I'm making changes to this file anyway.
```java
if (Boolean.TRUE.equals(model.getTaskSettings().getLateChunking())) {
```
nit: Maybe move the functionality of queuing/choosing the response to a function to reduce the indentation.
I'll move this logic to a helper function.
```java
public void testBatchChunksAcrossInputsIsFalseAndBatchesLessThanMaxChunkLimit_ThrowsAssertionError() {
    int batchSize = randomIntBetween(1, 511);
```
Would it be worth defining this upper bound on batch size based on the batchSize value currently defined in testBatchChunksAcrossInputs()? If the batch size in that method is changed, the test might start failing without it being immediately obvious why.
Sure, I'll move the batch size into a variable and reuse it across this test and testBatchChunksAcrossInputs().
```java
int expectedNumberOfBatches = batchChunksAcrossInputs ? 1 : 3;
assertThat(batches, hasSize(expectedNumberOfBatches));
if (batchChunksAcrossInputs) {
    assertThat(batches.get(0).batch().inputs().get(), hasSize(8));
```
It would be nice if we could tie this value of 8 directly to the input text somehow, since it's not immediately obvious where it's coming from and would become incorrect if the input was changed. Similarly, the "3, 1, 4" in the other branch of this if statement is a little disconnected from the input text. Maybe we could do something like this:

```java
int maxChunkSize = 10;
var testSentence = IntStream.range(0, maxChunkSize).mapToObj(i -> "word" + i).collect(Collectors.joining(" ")) + ".";
var chunkingSettings = new SentenceBoundaryChunkingSettings(maxChunkSize, 0);
var batchSizes = List.of(3, 1, 4);
var totalBatchSizes = batchSizes.stream().mapToInt(Integer::intValue).sum();
List<ChunkInferenceInput> inputs = batchSizes.stream()
    .map(i -> new ChunkInferenceInput(String.join(" ", Collections.nCopies(i, testSentence))))
    .toList();
```

and use totalBatchSizes instead of 8.
Sure, I'll update the test to use this proposed approach.
Resolved review threads on ...e/src/main/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunker.java
…nputs exceeding max word count
jonathan-buttner left a comment:

Thanks for the changes. Left a few suggestions.
Resolved review thread on ...org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java
```java
public class ChunkerUtils {

    public static int countWords(String text) {
        BreakIterator wordIterator = BreakIterator.getWordInstance();
```
A question came up previously about the languages this supports. Are we mainly targeting English? Or is there anything additional you are aware of that we can do to make countWords handle more languages?

I'm mainly just checking to see if there's a configuration we can pass to BreakIterator to make it support more languages.
From the documentation it seems that it uses the default locale's language. We usually pass Locale.ROOT into this function in other use cases, so I'll make an update to be consistent here, but I still think this will have the same behavior. If we need the user to be able to control which language it uses, we can consider that in a future change.
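For reference, here's a standalone sketch of what a countWords helper using BreakIterator with an explicit Locale.ROOT might look like. This is not the actual ChunkerUtils implementation; the class name and the boundary-filtering details are assumptions for illustration:

```java
import java.text.BreakIterator;
import java.util.Locale;

public class CountWordsSketch {

    // Sketch only: count word segments using the locale-independent root locale,
    // mirroring the Locale.ROOT change discussed above.
    public static int countWords(String text) {
        BreakIterator wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
        wordIterator.setText(text);
        int count = 0;
        int start = wordIterator.first();
        for (int end = wordIterator.next(); end != BreakIterator.DONE; start = end, end = wordIterator.next()) {
            // BreakIterator also reports boundaries around spaces and punctuation,
            // so only count segments that begin with a letter or digit.
            if (Character.isLetterOrDigit(text.charAt(start))) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countWords("Late chunking keeps chunk context.")); // prints 5
    }
}
```

As the comment thread notes, Locale.ROOT keeps behavior consistent across JVMs but does not by itself add language-specific word segmentation.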
```java
MatcherAssert.assertThat(
    xContentResult,
    is(
        Strings.format(
```
nit: If you want to make the expected value prettier (newline etc), I've used XContentHelper.stripWhitespace to clean it up before the comparison.
DonalEvans left a comment:

Just a couple of small things, nothing mandatory.
Resolved review threads on .../test/java/org/elasticsearch/xpack/core/inference/chunking/EmbeddingRequestChunkerTests.java and ...org/elasticsearch/xpack/inference/services/jinaai/request/JinaAIEmbeddingsRequestEntity.java
… chunking word count limit
jonathan-buttner left a comment:

Thanks for the changes ✅
Squashed commits (elastic#137263):

* Add late chunking configuration for JinaAI embedding task settings
* Update docs/changelog/137263.yaml
* Clean up tests and fix mutateInstance for JinaAIEmbeddingsTaskSettingsTests
* Cleanup EmbeddingRequestChunker tests and disable late chunking for inputs exceeding max word count
* Fixing test sentence generation
* Adding test for generating multiple batches and clarification on late chunking word count limit
Description

This change adds the ability to pass in the late_chunking flag as part of the task settings for a JinaAI embeddings endpoint, which controls whether JinaAI will late chunk for us. As part of our existing chunking process, we batch chunks across inputs into a single request. When late chunking, we need to avoid doing this because JinaAI assumes that all of the chunks in a single request are part of a single document. This change adds logic to avoid batching chunks across inputs when we are late chunking.

Testing

* Creating the endpoint with late_chunking set to null, true, and false.
* Overriding the late_chunking flag during an inference call with true and false works.
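The batching rule described in this PR can be sketched roughly as follows. This is a hypothetical illustration, not the actual EmbeddingRequestChunker code: the class, method, and parameter names are made up, and it ignores batch-size limits entirely:

```java
import java.util.ArrayList;
import java.util.List;

public class LateChunkingBatchingSketch {

    // Sketch: given the chunks already produced for each input, either merge all
    // chunks into one batch (normal path) or keep one batch per input (late
    // chunking), since JinaAI treats all chunks in one request as one document.
    public static List<List<String>> batch(List<List<String>> chunksPerInput, boolean lateChunking) {
        if (!lateChunking) {
            // Normal path: chunks from different inputs may share a request.
            List<String> merged = new ArrayList<>();
            chunksPerInput.forEach(merged::addAll);
            return List.of(merged);
        }
        // Late chunking: each input's chunks must go in their own request so that
        // JinaAI only sees chunks belonging to a single document per request.
        return new ArrayList<>(chunksPerInput);
    }

    public static void main(String[] args) {
        // Mirrors the test case discussed in review: 3 inputs producing 8 chunks.
        List<List<String>> chunks = List.of(
            List.of("a1", "a2", "a3"),
            List.of("b1"),
            List.of("c1", "c2", "c3", "c4")
        );
        System.out.println(batch(chunks, false).size()); // prints 1 (one batch of 8 chunks)
        System.out.println(batch(chunks, true).size());  // prints 3 (one batch per input)
    }
}
```

This matches the "1 or 3 batches" discussion in the review thread: with batching across inputs the 8 chunks form a single batch, and with late chunking they split into one batch per input.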