Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
849147e
Add RerankRequestChunker
dan-rubinstein Jun 10, 2025
c41d54c
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 3, 2025
da4c939
Add chunking strategy generation
dan-rubinstein Jul 4, 2025
004ca8f
Merge branch 'main' into rerank-request-chunker
davidkyle Jul 18, 2025
5ec620a
Merge branch 'main' into rerank-request-chunker
elasticmachine Jul 30, 2025
4ff8eb0
Adding unit tests and fixing token/word ratio
dan-rubinstein Jul 23, 2025
ec78b87
Merge branch 'main' into rerank-request-chunker
elasticmachine Aug 13, 2025
9ef8917
Add configurable values for long document handling strategy and maxim…
dan-rubinstein Sep 8, 2025
24497ae
Adding back sentence overlap for rerank chunking strategy
dan-rubinstein Sep 11, 2025
1fea365
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 11, 2025
8396214
Merge branch 'main' into rerank-request-chunker
elasticmachine Sep 22, 2025
8b97711
Adding unit tests, transport version, and feature flag
dan-rubinstein Sep 18, 2025
833ef02
Update docs/changelog/130485.yaml
dan-rubinstein Sep 22, 2025
77701e1
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 25, 2025
344e121
Adding unit tests and refactoring code with clearer naming conventions
dan-rubinstein Sep 25, 2025
02c9d0a
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
d68bf09
Merge branch 'main' of github.com:elastic/elasticsearch into rerank-r…
dan-rubinstein Sep 29, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/130485.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 130485
summary: Add `RerankRequestChunker`
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9180000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.2.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ml_inference_google_model_garden_added,9179000
elastic_reranker_chunking_configuration,9180000
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public enum FeatureFlag {
"es.index_dimensions_tsid_optimization_feature_flag_enabled=true",
Version.fromString("9.2.0"),
null
);
),
ELASTIC_RERANKER_CHUNKING("es.elastic_reranker_chunking_long_documents=true", Version.fromString("9.2.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ public class ChunkingSettingsBuilder {
public static final SentenceBoundaryChunkingSettings DEFAULT_SETTINGS = new SentenceBoundaryChunkingSettings(250, 1);
// Old settings used for backward compatibility for endpoints created before 8.16 when default was changed
public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100);
// Total token budget that buildChunkingSettingsForElasticRerank splits between the query, the extra tokens and a chunk
public static final int ELASTIC_RERANKER_TOKEN_LIMIT = 512;
// Tokens subtracted from the limit in addition to the query tokens
// NOTE(review): presumably the model's special/separator tokens - confirm
public static final int ELASTIC_RERANKER_EXTRA_TOKEN_COUNT = 3;
// Ratio used to convert between word counts and token counts (words = tokens * ratio, tokens = words / ratio)
public static final float WORDS_PER_TOKEN = 0.75f;

public static ChunkingSettings fromMap(Map<String, Object> settings) {
return fromMap(settings, true);
Expand Down Expand Up @@ -51,4 +54,17 @@ public static ChunkingSettings fromMap(Map<String, Object> settings, boolean ret
case RECURSIVE -> RecursiveChunkingSettings.fromMap(new HashMap<>(settings));
};
}

/**
 * Builds the chunking settings used when reranking with the Elastic reranker, sized from
 * the number of words in the query.
 * <p>
 * The query word count is converted to an estimated token count. The chunk receives the
 * tokens that remain once the query and the extra tokens are subtracted from the token
 * limit, but never fewer than half of the token limit. That token budget is converted
 * back to words and used as the maximum chunk size.
 *
 * @param queryWordCount number of words in the rerank query
 * @return sentence-boundary chunking settings with the computed maximum chunk size and
 *         {@code 1} as the second argument (presumably the sentence overlap)
 */
public static ChunkingSettings buildChunkingSettingsForElasticRerank(int queryWordCount) {
    double estimatedQueryTokens = Math.ceil(queryWordCount / WORDS_PER_TOKEN);
    double tokensLeftAfterQuery = ELASTIC_RERANKER_TOKEN_LIMIT - ELASTIC_RERANKER_EXTRA_TOKEN_COUNT - estimatedQueryTokens;
    double halfOfTokenLimit = Math.floor((float) ELASTIC_RERANKER_TOKEN_LIMIT / 2);

    // A long query may leave less than half of the limit; the chunk budget never drops below that half
    double chunkTokenBudget = Math.max(halfOfTokenLimit, tokensLeftAfterQuery);

    int chunkWordBudget = (int) (chunkTokenBudget * WORDS_PER_TOKEN);
    return new SentenceBoundaryChunkingSettings(chunkWordBudget, 1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;

import com.ibm.icu.text.BreakIterator;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

/**
 * Splits rerank inputs into chunks and maps the scores returned for those chunks back
 * onto the documents they were cut from.
 * <p>
 * All inputs are chunked in the constructor. {@link #getChunkedInputs()} returns the flat
 * list of chunk texts, and {@link #parseChunkedRerankResultsListener(ActionListener)}
 * converts a per-chunk response into one result per document.
 */
public class RerankRequestChunker {
    /** Original documents in request order; used to restore the full text in the final results. */
    private final List<String> inputs;
    /** Chunks of all documents in document order; a ranked result's index is looked up in this list. */
    private final List<RerankChunks> rerankChunks;

    /**
     * Chunks every input using settings derived from the length of the query.
     *
     * @param query           rerank query; its word count determines the chunk size
     * @param inputs          documents to rerank
     * @param maxChunksPerDoc maximum number of leading chunks kept per document, or {@code null} for no limit
     */
    public RerankRequestChunker(String query, List<String> inputs, Integer maxChunksPerDoc) {
        this.inputs = inputs;
        this.rerankChunks = chunk(inputs, buildChunkingSettingsForElasticRerank(query), maxChunksPerDoc);
    }

    /**
     * Chunks each input and records, for every chunk, the index of the document it belongs to.
     * When a limit is given, only the first {@code maxChunksPerDoc} chunks of a document are kept.
     */
    private List<RerankChunks> chunk(List<String> inputs, ChunkingSettings chunkingSettings, Integer maxChunksPerDoc) {
        var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy());
        var chunks = new ArrayList<RerankChunks>();
        for (int docIndex = 0; docIndex < inputs.size(); docIndex++) {
            String document = inputs.get(docIndex);
            var offsets = chunker.chunk(document, chunkingSettings);
            if (maxChunksPerDoc != null && offsets.size() > maxChunksPerDoc) {
                offsets = offsets.subList(0, maxChunksPerDoc);
            }

            for (var offset : offsets) {
                chunks.add(new RerankChunks(docIndex, document.substring(offset.start(), offset.end())));
            }
        }
        return chunks;
    }

    /**
     * Returns the text of every chunk, in the same order as the internal chunk list.
     *
     * @return a new mutable list of chunk strings
     */
    public List<String> getChunkedInputs() {
        List<String> chunkedInputs = new ArrayList<>(rerankChunks.size());
        for (RerankChunks chunk : rerankChunks) {
            chunkedInputs.add(chunk.chunkString());
        }

        return chunkedInputs;
    }

    /**
     * Wraps {@code listener} so that a per-chunk rerank response is converted into a
     * per-document response before it is passed on.
     * <p>
     * A response that is not a {@link RankedDocsResults} is reported to the listener as an
     * {@link IllegalArgumentException}; failures are forwarded unchanged.
     *
     * @param listener receiver of the per-document results
     * @return listener to hand to the rerank call made with the chunked inputs
     */
    public ActionListener<InferenceServiceResults> parseChunkedRerankResultsListener(ActionListener<InferenceServiceResults> listener) {
        return ActionListener.wrap(results -> {
            if (results instanceof RankedDocsResults rankedDocsResults) {
                listener.onResponse(parseRankedDocResultsForChunks(rankedDocsResults));

            } else {
                listener.onFailure(new IllegalArgumentException("Expected RankedDocsResults but got: " + results.getClass()));
            }

        }, listener::onFailure);
    }

    /**
     * Collapses per-chunk results into one result per document: chunks are visited from the
     * highest to the lowest relevance score and only the first chunk seen for each document
     * is kept, so every document carries the score of its best chunk.
     */
    private RankedDocsResults parseRankedDocResultsForChunks(RankedDocsResults rankedDocsResults) {
        List<RankedDocsResults.RankedDoc> topRankedDocs = new ArrayList<>();
        Set<Integer> docIndicesSeen = new HashSet<>();

        // Sort a copy so the list owned by the response is left untouched
        List<RankedDocsResults.RankedDoc> rankedDocs = new ArrayList<>(rankedDocsResults.getRankedDocs());
        rankedDocs.sort((r1, r2) -> Float.compare(r2.relevanceScore(), r1.relevanceScore()));
        for (RankedDocsResults.RankedDoc rankedDoc : rankedDocs) {
            int docIndex = rerankChunks.get(rankedDoc.index()).docIndex();

            // Set.add returns false when the document already has a (higher scoring) chunk recorded
            if (docIndicesSeen.add(docIndex)) {
                // Create a ranked doc with the full input string and the index for the document instead of the chunk
                topRankedDocs.add(new RankedDocsResults.RankedDoc(docIndex, rankedDoc.relevanceScore(), inputs.get(docIndex)));
            }
        }

        return new RankedDocsResults(topRankedDocs);
    }

    /**
     * A single chunk together with the index of the input document it was taken from.
     *
     * @param docIndex    position of the source document in the original inputs
     * @param chunkString text of the chunk
     */
    public record RerankChunks(int docIndex, String chunkString) {}

    /**
     * Counts the words in the query and derives the chunking settings from that count.
     */
    private ChunkingSettings buildChunkingSettingsForElasticRerank(String query) {
        // Pin the locale: the no-argument factory follows the JVM default locale, which would let
        // the word count, and therefore the chunk size, vary with the node's locale settings
        var wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
        wordIterator.setText(query);
        var queryWordCount = ChunkerUtils.countWords(0, query.length(), wordIterator);
        return ChunkingSettingsBuilder.buildChunkingSettingsForElasticRerank(queryWordCount);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,51 @@

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;

import java.io.IOException;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.ELASTIC_RERANKER_CHUNKING;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.RERANKER_ID;

public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {

// Name returned by getWriteableName()
public static final String NAME = "elastic_reranker_service_settings";

// Service-settings keys read by fromMap and written by toXContent
public static final String LONG_DOCUMENT_STRATEGY = "long_document_strategy";
public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc";

// Transport version that gates reading and writing the two chunking fields on the wire
private static final TransportVersion ELASTIC_RERANKER_CHUNKING_CONFIGURATION = TransportVersion.fromName(
"elastic_reranker_chunking_configuration"
);

// Both values are optional and may be null
private final LongDocumentStrategy longDocumentStrategy;
private final Integer maxChunksPerDoc;

/**
 * Creates the settings used for the default reranker endpoint: no explicit allocation
 * count, one thread, the {@code RERANKER_ID} model and adaptive allocations built from
 * {@code (TRUE, 0, 32)}.
 *
 * @return settings for the default endpoint, without any chunking configuration
 */
public static ElasticRerankerServiceSettings defaultEndpointSettings() {
    var adaptiveAllocations = new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32);
    return new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, adaptiveAllocations);
}

public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
/**
 * Creates reranker settings by copying existing internal service settings and attaching
 * the optional long-document options.
 *
 * @param other                settings to copy the base values from
 * @param longDocumentStrategy how long documents are handled, may be {@code null}
 * @param maxChunksPerDoc      chunk limit per document, may be {@code null}
 */
public ElasticRerankerServiceSettings(
    ElasticsearchInternalServiceSettings other,
    LongDocumentStrategy longDocumentStrategy,
    Integer maxChunksPerDoc
) {
    super(other);
    this.maxChunksPerDoc = maxChunksPerDoc;
    this.longDocumentStrategy = longDocumentStrategy;
}

private ElasticRerankerServiceSettings(
Expand All @@ -35,10 +61,32 @@ private ElasticRerankerServiceSettings(
AdaptiveAllocationsSettings adaptiveAllocationsSettings
) {
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
this.longDocumentStrategy = null;
this.maxChunksPerDoc = null;
}

/**
 * Creates reranker settings from individual base values plus the optional long-document
 * options. The last argument of the parent constructor is always passed as {@code null}.
 *
 * @param numAllocations              allocation count, may be {@code null}
 * @param numThreads                  thread count
 * @param modelId                     model identifier
 * @param adaptiveAllocationsSettings adaptive allocation settings
 * @param longDocumentStrategy        how long documents are handled, may be {@code null}
 * @param maxChunksPerDoc             chunk limit per document, may be {@code null}
 */
protected ElasticRerankerServiceSettings(
    Integer numAllocations,
    int numThreads,
    String modelId,
    AdaptiveAllocationsSettings adaptiveAllocationsSettings,
    LongDocumentStrategy longDocumentStrategy,
    Integer maxChunksPerDoc
) {
    super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings, null);
    this.maxChunksPerDoc = maxChunksPerDoc;
    this.longDocumentStrategy = longDocumentStrategy;
}

public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
super(in);
if (in.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION)) {
this.longDocumentStrategy = in.readOptionalEnum(LongDocumentStrategy.class);
this.maxChunksPerDoc = in.readOptionalInt();
} else {
this.longDocumentStrategy = null;
this.maxChunksPerDoc = null;
}
}

/**
Expand All @@ -48,21 +96,93 @@ public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
* {@link ValidationException} is thrown.
*
* @param map Source map containing the config
* @return The builder
* @return Parsed and validated service settings
*/
public static Builder fromRequestMap(Map<String, Object> map) {
public static ElasticRerankerServiceSettings fromMap(Map<String, Object> map) {
// Validation problems from all parsing steps are accumulated here and thrown together below
ValidationException validationException = new ValidationException();
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);

// The chunking options default to null and are only parsed when the feature flag is enabled
LongDocumentStrategy longDocumentStrategy = null;
Integer maxChunksPerDoc = null;
if (ELASTIC_RERANKER_CHUNKING.isEnabled()) {
longDocumentStrategy = extractOptionalEnum(
map,
LONG_DOCUMENT_STRATEGY,
ModelConfigurations.SERVICE_SETTINGS,
LongDocumentStrategy::fromString,
EnumSet.allOf(LongDocumentStrategy.class),
validationException
);

maxChunksPerDoc = extractOptionalPositiveInteger(
map,
MAX_CHUNKS_PER_DOC,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);

// A chunk limit is rejected when the strategy is absent or set to TRUNCATE
if (maxChunksPerDoc != null && (longDocumentStrategy == null || longDocumentStrategy == LongDocumentStrategy.TRUNCATE)) {
validationException.addValidationError(
"The [" + MAX_CHUNKS_PER_DOC + "] setting requires [" + LONG_DOCUMENT_STRATEGY + "] to be set to [chunk]"
);
}
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

// NOTE(review): the first return below looks like a leftover from the previous builder-returning version;
// only the second return matches the declared return type - confirm and drop the stale line
return baseSettings;
return new ElasticRerankerServiceSettings(baseSettings.build(), longDocumentStrategy, maxChunksPerDoc);
}

/**
 * @return the configured long-document strategy, or {@code null} when none was set
 */
public LongDocumentStrategy getLongDocumentStrategy() {
    return this.longDocumentStrategy;
}

/**
 * @return the configured chunk limit per document, or {@code null} when none was set
 */
public Integer getMaxChunksPerDoc() {
    return this.maxChunksPerDoc;
}

/**
 * Writes the base settings, followed by the two chunking fields when the stream's
 * transport version supports them.
 *
 * @param out stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    if (out.getTransportVersion().supports(ELASTIC_RERANKER_CHUNKING_CONFIGURATION) == false) {
        // Receivers on this version do not read the chunking fields
        return;
    }
    out.writeOptionalEnum(longDocumentStrategy);
    out.writeOptionalInt(maxChunksPerDoc);
}

/**
 * Renders the settings as an object: the internal settings first, then each chunking
 * option that is not {@code null}.
 *
 * @param builder builder to write to
 * @param params  rendering parameters passed on to the internal settings
 * @return the builder that was passed in
 * @throws IOException if writing to the builder fails
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();
    addInternalSettingsToXContent(builder, params);

    final boolean hasStrategy = longDocumentStrategy != null;
    if (hasStrategy) {
        builder.field(LONG_DOCUMENT_STRATEGY, longDocumentStrategy.strategyName);
    }

    final boolean hasChunkLimit = maxChunksPerDoc != null;
    if (hasChunkLimit) {
        builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc);
    }

    builder.endObject();
    return builder;
}

/**
 * @return the writeable name of these settings, {@link #NAME}
 */
@Override
public String getWriteableName() {
    return NAME;
}

/**
 * Values accepted for the long-document strategy setting. Each constant carries the
 * lower-case name that is written when the settings are rendered.
 */
public enum LongDocumentStrategy {
    CHUNK("chunk"),
    TRUNCATE("truncate");

    /** Lower-case name of the strategy. */
    public final String strategyName;

    LongDocumentStrategy(String strategyName) {
        this.strategyName = strategyName;
    }

    /**
     * Resolves a strategy from its name, ignoring case and surrounding whitespace.
     *
     * @param name strategy name, for example {@code "chunk"}
     * @return the matching constant
     * @throws IllegalArgumentException if no constant matches the normalized name
     */
    public static LongDocumentStrategy fromString(String name) {
        final String constantName = name.trim().toUpperCase(Locale.ROOT);
        return LongDocumentStrategy.valueOf(constantName);
    }
}
}
Loading