Merged
docs/changelog/124313.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []
(file header not rendered — an internal cluster test for the shard bulk inference filter)
@@ -20,6 +20,7 @@
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@
 import java.util.Map;
 import java.util.Set;
 
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public void setup() throws Exception {
 
     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        return Settings.builder()
+            .put(otherSettings)
+            .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+            .build();
     }
 
     @Override
(file header not rendered — InferencePlugin.java, per the class declaration below)
@@ -142,6 +142,7 @@
 import java.util.function.Supplier;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 
 public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public List<Setting<?>> getSettings() {
         settings.addAll(Truncator.getSettingsDefinitions());
         settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
         settings.add(SKIP_VALIDATE_AND_START);
+        settings.add(INDICES_INFERENCE_BATCH_SIZE);
         settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
 
         return settings;

Large diffs are not rendered by default.
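The file omitted above is presumably ShardBulkInferenceActionFilter itself, which is where INDICES_INFERENCE_BATCH_SIZE and the batching logic live. Since the declaration is not rendered, the following is only a sketch of how a ByteSizeValue-typed, dynamically updatable cluster setting of this shape is typically declared and consumed; the setting key, default value, and property flags are assumptions, not code from this PR:

import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;

// Hypothetical sketch of the declaration in ShardBulkInferenceActionFilter.
// The key, default, and properties below are assumed, not confirmed by the diff.
public static final Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
    "indices.inference.batch_size",    // assumed setting key
    ByteSizeValue.ofMb(1),             // assumed default batch-size budget
    Setting.Property.NodeScope,
    Setting.Property.Dynamic           // implied by the ClusterSettings stubbing in the unit tests below
);

// Equally a sketch: read the initial value from node settings, then follow dynamic
// updates. The tests stub both getSettings() and getClusterSettings(), which is
// consistent with this shape.
private volatile long batchSizeInBytes;

void initBatchSize(ClusterService clusterService) {
    this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
    clusterService.getClusterSettings()
        .addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, value -> this.batchSizeInBytes = value.getBytes());
}

A byte-based batch budget lets the filter cap the memory held by in-flight inference inputs instead of assuming a fixed item count, which matches the PR's stated goal.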

(file header not rendered — presumably SemanticTextField.java)
@@ -267,37 +267,38 @@ private static List<Chunk> parseChunksArrayLegacy(XContentParser parser, ParserC
     /**
      * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
      */
-    public static List<Chunk> toSemanticTextFieldChunks(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference results,
-        XContentType contentType,
-        boolean useLegacyFormat
-    ) throws IOException {
+    public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+        throws IOException {
         List<Chunk> chunks = new ArrayList<>();
         Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
         while (it.hasNext()) {
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+            chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
         }
         return chunks;
     }
 
-    public static Chunk toSemanticTextFieldChunk(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference.Chunk chunk,
-        boolean useLegacyFormat
-    ) {
+    /**
+     * Converts the provided {@link ChunkedInference.Chunk} into a {@link Chunk}.
+     */
+    public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
         String text = null;
-        int startOffset = -1;
-        int endOffset = -1;
-        if (useLegacyFormat) {
-            text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-        } else {
-            startOffset = chunk.textOffset().start() + offsetAdjustment;
-            endOffset = chunk.textOffset().end() + offsetAdjustment;
+        int startOffset = chunk.textOffset().start() + offsetAdjustment;
+        int endOffset = chunk.textOffset().end() + offsetAdjustment;
+        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
     }
 
+    public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+        throws IOException {
+        List<Chunk> chunks = new ArrayList<>();
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+        while (it.hasNext()) {
+            chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
+        }
+        return chunks;
+    }
+
-        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+    public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+        var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+        return new Chunk(text, -1, -1, chunk.bytesReference());
     }
 }
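The refactor above replaces the boolean-switched conversion with two explicit paths: the current format keeps only offsets (shifted into the coordinate space of the concatenated inputs), while the legacy path copies the chunk text out of the input string. A minimal sketch of the difference, assuming Chunk and ChunkedInference.Chunk are records with the accessors used in the diff (text(), startOffset(), textOffset(), bytesReference()) and using placeholder embedding bytes:

// Illustration only; record shapes are inferred from the accessors in the diff.
String input = "a quick brown fox";
BytesReference embedding = new BytesArray(new byte[] { 1, 2, 3 }); // placeholder bytes
var inferenceChunk = new ChunkedInference.Chunk(new ChunkedInference.TextOffset(2, 7), embedding);

// Current format: no text is copied; only adjusted offsets are stored.
Chunk current = toSemanticTextFieldChunk(0, inferenceChunk);
assert current.text() == null && current.startOffset() == 2 && current.endOffset() == 7;

// Legacy format: the chunk text is materialized as a substring; offsets are unused.
Chunk legacy = toSemanticTextFieldChunkLegacy(input, inferenceChunk);
assert legacy.text().equals("quick") && legacy.startOffset() == -1 && legacy.endOffset() == -1;

The memory saving this PR targets comes from the non-legacy path: it no longer retains a per-chunk substring of the input alongside each embedding.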
(file header not rendered — presumably ShardBulkInferenceActionFilterTests.java)
@@ -28,7 +28,9 @@
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public void tearDownThreadPool() throws Exception {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -144,7 +147,7 @@
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -185,7 +188,6 @@ public void testInferenceNotFound() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -232,7 +234,6 @@ public void testItemFailures() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -303,7 +304,6 @@ public void testExplicitNull() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -374,7 +374,6 @@ public void testHandleEmptyInput() throws Exception {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -447,13 +446,7 @@ public void testManyRandomDocs() throws Exception {
             modifiedRequests[id] = res[1];
         }
 
-        ShardBulkInferenceActionFilter filter = createFilter(
-            threadPool,
-            inferenceModelMap,
-            randomIntBetween(10, 30),
-            useLegacyFormat,
-            true
-        );
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -487,7 +480,6 @@
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
-        int batchSize,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -554,18 +546,17 @@ private static ShardBulkInferenceActionFilter createFilter(
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState,
-            batchSize
+            licenseState
         );
     }
 
     private static ClusterService createClusterService(boolean useLegacyFormat) {
         IndexMetadata indexMetadata = mock(IndexMetadata.class);
-        var settings = Settings.builder()
+        var indexSettings = Settings.builder()
             .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
             .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
             .build();
-        when(indexMetadata.getSettings()).thenReturn(settings);
+        when(indexMetadata.getSettings()).thenReturn(indexSettings);
 
         ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
         when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(clusterState);
-
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
+        when(clusterService.getSettings()).thenReturn(settings);
+        when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
         return clusterService;
     }
 
@@ -587,7 +581,8 @@ private static BulkItemRequest[] randomBulkItemRequest(
     ) throws IOException {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-        XContentType requestContentType = randomFrom(XContentType.values());
+        // force JSON to avoid double/float conversions
+        XContentType requestContentType = XContentType.JSON;
 
         Map<String, Object> inferenceMetadataFields = new HashMap<>();
         for (var entry : fieldInferenceMap.values()) {
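On the forced-JSON change above: expected documents are built as Java maps and compared against maps parsed back from the serialized request, and numeric types do not always survive that round trip. JSON parses every floating-point number back as a Double, whereas binary formats such as CBOR can preserve Float, so randomizing the content type makes map equality flaky. A minimal sketch of the failure mode, a standalone illustration using the XContentHelper utilities already imported in this file, not code from the PR:

// Sketch: a Float survives serialization, but parsing it back from JSON yields a Double,
// so Map.equals against the original map fails even though the values are numerically equal.
Map<String, Object> original = Map.of("score", 0.5f);
BytesReference bytes = XContentHelper.toXContent(
    (builder, params) -> builder.map(original),
    XContentType.JSON,
    false
);
Object parsed = XContentHelper.convertToMap(bytes, false, XContentType.JSON).v2().get("score");
// parsed is the Double 0.5, while original.get("score") is the Float 0.5f -> not equal as objects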
(file header not rendered — presumably SemanticTextFieldTests.java)
@@ -41,6 +41,7 @@
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -274,7 +275,7 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
         while (inputsIt.hasNext() && chunkIt.hasNext()) {
             String input = inputsIt.next();
             var chunk = chunkIt.next();
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+            chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
 
             // When using the inference metadata fields format, all the input values are concatenated so that the
             // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
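The comment above (truncated by the unexpanded hunk) describes the offset bookkeeping for the non-legacy format: inputs are conceptually concatenated into one string, so each subsequent input's chunk offsets must shift by the length of everything before it. A sketch of that accumulation; the "+ 1" separator length is an assumption based on the comment's single-concatenated-string framing:

// Sketch: accumulate the adjustment as inputs are consumed so that each chunk's
// offsets land in the coordinate space of the concatenated inputs.
int offsetAdjustment = 0;
for (String in : inputs) {
    // ... convert this input's chunks with toSemanticTextFieldChunk(offsetAdjustment, chunk) ...
    offsetAdjustment += in.length() + 1; // assumed: +1 for a separator between inputs
}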