Skip to content
6 changes: 6 additions & 0 deletions docs/changelog/138047.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 138047
summary: Add configurable `max_batch_size` for `GoogleVertexAI` embedding service
settings
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.xcontent.ToXContentObject;

import java.util.Map;

public interface ServiceSettings extends ToXContentObject, VersionedNamedWriteable, FilteredXContent {

/**
Expand Down Expand Up @@ -61,4 +63,8 @@ default DenseVectorFieldMapper.ElementType elementType() {
*/
@Nullable
String modelId();

/**
 * Updates these service settings with values taken from the given raw settings map.
 * <p>
 * The default implementation ignores the map and returns {@code this} unchanged;
 * services whose settings can be updated in place (e.g. configurable batch sizes)
 * should override this and return a new, validated settings instance.
 *
 * @param serviceSettings the raw service-settings map from the update request
 * @return the updated settings; this default returns {@code this}
 */
default ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
return this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9241000
2 changes: 1 addition & 1 deletion server/src/main/resources/transport/upper_bounds/9.3.csv
Original file line number Diff line number Diff line change
@@ -1 +1 @@
esql_exponential_histogram_supported_version,9240000
google_vertex_ai_configurable_max_batch_size,9241000
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ private Model combineExistingModelWithNewSettings(
newSecretSettings = existingSecretSettings.newSecretSettings(settingsToUpdate.serviceSettings());
}
if (settingsToUpdate.serviceSettings() != null) {
// In cluster services can have their deployment settings updated, so this is a special case
if (newServiceSettings instanceof ElasticsearchInternalServiceSettings elasticServiceSettings) {
newServiceSettings = elasticServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
}
newServiceSettings = newServiceSettings.updateServiceSettings(settingsToUpdate.serviceSettings());
}
if (settingsToUpdate.taskSettings() != null && existingTaskSettings != null) {
newTaskSettings = existingTaskSettings.updatedTaskSettings(settingsToUpdate.taskSettings());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,22 @@ public static Integer extractOptionalPositiveInteger(
return extractOptionalInteger(map, settingName, scope, validationException, true);
}

/**
 * Extracts an optional positive integer setting and additionally validates that it
 * does not exceed the supplied maximum.
 * <p>
 * Any violation (non-positive, or greater than {@code maxValue}) is recorded on
 * {@code validationException} rather than thrown here.
 *
 * @param map                 the settings map to read from
 * @param settingName         the name of the setting being extracted
 * @param maxValue            the inclusive upper bound allowed for the value
 * @param scope               the scope used when building validation error messages
 * @param validationException the collector that accumulates validation errors
 * @return the extracted value, or {@code null} when the setting is absent
 */
public static Integer extractOptionalPositiveIntegerLessThanOrEqualToMax(
    Map<String, Object> map,
    String settingName,
    int maxValue,
    String scope,
    ValidationException validationException
) {
    var value = extractOptionalPositiveInteger(map, settingName, scope, validationException);

    // Only range-check when the optional setting was actually provided.
    if (value != null && value > maxValue) {
        validationException.addValidationError(mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, value, maxValue));
    }

    return value;
}

public static Integer extractOptionalInteger(
Map<String, Object> map,
String settingName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.minimumCompatible();
}

public ElasticsearchInternalServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
@Override
public ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
var validationException = new ValidationException();
var mutableServiceSettings = new HashMap<>(serviceSettings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,13 @@ protected void doChunkedInfer(
ActionListener<List<ChunkedInference>> listener
) {
GoogleVertexAiModel googleVertexAiModel = (GoogleVertexAiModel) model;
GoogleVertexAiEmbeddingsServiceSettings serviceSettings = (GoogleVertexAiEmbeddingsServiceSettings) googleVertexAiModel
.getServiceSettings();
var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());

List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
inputs,
EMBEDDING_MAX_BATCH_SIZE,
serviceSettings.maxBatchSize() == null ? EMBEDDING_MAX_BATCH_SIZE : serviceSettings.maxBatchSize(),
googleVertexAiModel.getConfigurations().getChunkingSettings()
).batchRequestsWithListeners(listener);

Expand All @@ -306,6 +308,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
serviceSettings.dimensionsSetByUser(),
serviceSettings.maxInputTokens(),
embeddingSize,
serviceSettings.maxBatchSize(),
serviceSettings.similarity(),
serviceSettings.rateLimitSettings()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ public class GoogleVertexAiServiceFields {
public static final String URL_SETTING_NAME = "url";
public static final String STREAMING_URL_SETTING_NAME = "streaming_url";
public static final String PROVIDER_SETTING_NAME = "provider";
public static final String MAX_BATCH_SIZE = "max_batch_size";

/**
* According to https://cloud.google.com/vertex-ai/docs/quotas#text-embedding-limits the limit is `250`.
*/
static final int EMBEDDING_MAX_BATCH_SIZE = 250;
public static final int EMBEDDING_MAX_BATCH_SIZE = 250;

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

Expand All @@ -36,9 +37,12 @@
import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveIntegerLessThanOrEqualToMax;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.MAX_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID;

public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObject
Expand All @@ -53,6 +57,10 @@ public class GoogleVertexAiEmbeddingsServiceSettings extends FilteredXContentObj
// See online prediction requests per minute: https://cloud.google.com/vertex-ai/docs/quotas.
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(30_000);

protected static final TransportVersion GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE = TransportVersion.fromName(
"google_vertex_ai_configurable_max_batch_size"
);

public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();

Expand All @@ -67,6 +75,13 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
);
SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax(
map,
MAX_BATCH_SIZE,
EMBEDDING_MAX_BATCH_SIZE,
ModelConfigurations.SERVICE_SETTINGS,
validationException
);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
Expand Down Expand Up @@ -106,11 +121,32 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object
dimensionsSetByUser,
maxInputTokens,
dims,
maxBatchSize,
similarityMeasure,
rateLimitSettings
);
}

/**
 * Returns a copy of these settings with any updatable fields replaced by values
 * from the given map. Currently only {@code max_batch_size} is updatable; it is
 * validated to be a positive integer no greater than {@code EMBEDDING_MAX_BATCH_SIZE}.
 *
 * @param serviceSettings the raw service-settings map from the update request
 * @return a new settings instance with the updated batch size applied
 * @throws ValidationException if the supplied value fails validation
 */
@Override
public ServiceSettings updateServiceSettings(Map<String, Object> serviceSettings) {
    // Work on a mutable copy so extraction does not modify the caller's map.
    var mutableSettings = new HashMap<>(serviceSettings);
    var validationException = new ValidationException();

    Integer maxBatchSize = extractOptionalPositiveIntegerLessThanOrEqualToMax(
        mutableSettings,
        MAX_BATCH_SIZE,
        EMBEDDING_MAX_BATCH_SIZE,
        ModelConfigurations.SERVICE_SETTINGS,
        validationException
    );

    if (validationException.validationErrors().isEmpty() == false) {
        throw validationException;
    }

    return new GoogleVertexAiEmbeddingsServiceSettings(this, maxBatchSize);
}

private final String location;

private final String projectId;
Expand All @@ -119,6 +155,8 @@ public static GoogleVertexAiEmbeddingsServiceSettings fromMap(Map<String, Object

private final Integer dims;

private final Integer maxBatchSize;

private final SimilarityMeasure similarity;
private final Integer maxInputTokens;

Expand All @@ -133,6 +171,7 @@ public GoogleVertexAiEmbeddingsServiceSettings(
Boolean dimensionsSetByUser,
@Nullable Integer maxInputTokens,
@Nullable Integer dims,
@Nullable Integer maxBatchSize,
@Nullable SimilarityMeasure similarity,
@Nullable RateLimitSettings rateLimitSettings
) {
Expand All @@ -142,17 +181,35 @@ public GoogleVertexAiEmbeddingsServiceSettings(
this.dimensionsSetByUser = dimensionsSetByUser;
this.maxInputTokens = maxInputTokens;
this.dims = dims;
this.maxBatchSize = maxBatchSize;
this.similarity = Objects.requireNonNullElse(similarity, SimilarityMeasure.DOT_PRODUCT);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}

/**
 * Copy constructor that optionally overrides the max batch size.
 * <p>
 * Every other field is copied verbatim from {@code original}; when
 * {@code maxBatchSize} is {@code null} the original's value is retained,
 * so passing {@code null} yields a settings object equal to the original.
 *
 * @param original     the settings to copy
 * @param maxBatchSize the new max batch size, or {@code null} to keep the original's
 */
public GoogleVertexAiEmbeddingsServiceSettings(GoogleVertexAiEmbeddingsServiceSettings original, @Nullable Integer maxBatchSize) {
this.location = original.location;
this.projectId = original.projectId;
this.modelId = original.modelId;
this.dimensionsSetByUser = original.dimensionsSetByUser;
this.maxInputTokens = original.maxInputTokens;
this.dims = original.dims;
this.maxBatchSize = maxBatchSize != null ? maxBatchSize : original.maxBatchSize;
this.similarity = original.similarity;
this.rateLimitSettings = original.rateLimitSettings;
}

/**
 * Reads these settings from the transport wire.
 * <p>
 * The read order here must exactly mirror the write order in {@code writeTo}.
 * {@code maxBatchSize} is only present on the wire when the remote node supports
 * the {@code google_vertex_ai_configurable_max_batch_size} transport version;
 * from older nodes it is treated as unset ({@code null}).
 *
 * @param in the stream to read from
 * @throws IOException if reading from the stream fails
 */
public GoogleVertexAiEmbeddingsServiceSettings(StreamInput in) throws IOException {
this.location = in.readString();
this.projectId = in.readString();
this.modelId = in.readString();
this.dimensionsSetByUser = in.readBoolean();
this.maxInputTokens = in.readOptionalVInt();
this.dims = in.readOptionalVInt();
if (in.getTransportVersion().supports(GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE)) {
this.maxBatchSize = in.readOptionalVInt();
} else {
this.maxBatchSize = null;
}
this.similarity = in.readOptionalEnum(SimilarityMeasure.class);
this.rateLimitSettings = new RateLimitSettings(in);
}
Expand Down Expand Up @@ -189,6 +246,10 @@ public Integer dimensions() {
return dims;
}

/**
 * @return the user-configured maximum embedding batch size, or {@code null} when unset
 *         (callers fall back to the service default {@code EMBEDDING_MAX_BATCH_SIZE})
 */
public Integer maxBatchSize() {
return maxBatchSize;
}

@Override
public SimilarityMeasure similarity() {
return similarity;
Expand Down Expand Up @@ -228,6 +289,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(dimensionsSetByUser);
out.writeOptionalVInt(maxInputTokens);
out.writeOptionalVInt(dims);
if (out.getTransportVersion().supports(GOOGLE_VERTEX_AI_CONFIGURABLE_MAX_BATCH_SIZE)) {
out.writeOptionalVInt(maxBatchSize);
}
out.writeOptionalEnum(similarity);
rateLimitSettings.writeTo(out);
}
Expand All @@ -246,6 +310,10 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil
builder.field(DIMENSIONS, dims);
}

if (maxBatchSize != null) {
builder.field(MAX_BATCH_SIZE, maxBatchSize);
}

if (similarity != null) {
builder.field(SIMILARITY, similarity);
}
Expand All @@ -264,6 +332,7 @@ public boolean equals(Object object) {
&& Objects.equals(projectId, that.projectId)
&& Objects.equals(modelId, that.modelId)
&& Objects.equals(dims, that.dims)
&& Objects.equals(maxBatchSize, that.maxBatchSize)
&& similarity == that.similarity
&& Objects.equals(maxInputTokens, that.maxInputTokens)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
Expand All @@ -272,6 +341,16 @@ public boolean equals(Object object) {

/**
 * Hash code consistent with {@code equals}: includes every field compared there,
 * including the newly added {@code maxBatchSize}.
 */
// Fix: the block contained two consecutive unconditional return statements (stale
// single-line return left above the new multi-line one), which is unreachable code
// and does not compile. Only the field-complete return is kept.
@Override
public int hashCode() {
    return Objects.hash(
        location,
        projectId,
        modelId,
        dims,
        maxBatchSize,
        similarity,
        maxInputTokens,
        rateLimitSettings,
        dimensionsSetByUser
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

Expand Down Expand Up @@ -155,11 +156,13 @@ public void testUpdateNumAllocations() {
);

assertThat("update should create a new instance", updatedInstance, not(equalTo(testInstance)));
assertThat(updatedInstance.getNumAllocations(), equalTo(expectedNumAllocations));
assertThat(updatedInstance.getAdaptiveAllocationsSettings(), nullValue());
assertThat(updatedInstance.getNumThreads(), equalTo(testInstance.getNumThreads()));
assertThat(updatedInstance.getDeploymentId(), equalTo(testInstance.getDeploymentId()));
assertThat(updatedInstance.modelId(), equalTo(testInstance.modelId()));
assertThat(updatedInstance, instanceOf(ElasticsearchInternalServiceSettings.class));
var updatedElasticSearchInternalServiceSettings = (ElasticsearchInternalServiceSettings) updatedInstance;
assertThat(updatedElasticSearchInternalServiceSettings.getNumAllocations(), equalTo(expectedNumAllocations));
assertThat(updatedElasticSearchInternalServiceSettings.getAdaptiveAllocationsSettings(), nullValue());
assertThat(updatedElasticSearchInternalServiceSettings.getNumThreads(), equalTo(testInstance.getNumThreads()));
assertThat(updatedElasticSearchInternalServiceSettings.getDeploymentId(), equalTo(testInstance.getDeploymentId()));
assertThat(updatedElasticSearchInternalServiceSettings.modelId(), equalTo(testInstance.modelId()));

}

Expand All @@ -171,11 +174,13 @@ public void testUpdateAdaptiveAllocations() throws IOException {
);

assertThat("update should create a new instance", updatedInstance, not(equalTo(testInstance)));
assertThat(updatedInstance.getNumAllocations(), nullValue());
assertThat(updatedInstance.getAdaptiveAllocationsSettings(), equalTo(expectedAdaptiveAllocations));
assertThat(updatedInstance.getNumThreads(), equalTo(testInstance.getNumThreads()));
assertThat(updatedInstance.getDeploymentId(), equalTo(testInstance.getDeploymentId()));
assertThat(updatedInstance.modelId(), equalTo(testInstance.modelId()));
assertThat(updatedInstance, instanceOf(ElasticsearchInternalServiceSettings.class));
var updatedElasticSearchInternalServiceSettings = (ElasticsearchInternalServiceSettings) updatedInstance;
assertThat(updatedElasticSearchInternalServiceSettings.getNumAllocations(), nullValue());
assertThat(updatedElasticSearchInternalServiceSettings.getAdaptiveAllocationsSettings(), equalTo(expectedAdaptiveAllocations));
assertThat(updatedElasticSearchInternalServiceSettings.getNumThreads(), equalTo(testInstance.getNumThreads()));
assertThat(updatedElasticSearchInternalServiceSettings.getDeploymentId(), equalTo(testInstance.getDeploymentId()));
assertThat(updatedElasticSearchInternalServiceSettings.modelId(), equalTo(testInstance.modelId()));
}

private static AdaptiveAllocationsSettings adaptiveAllocationSettings(AdaptiveAllocationsSettings base) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public static GoogleVertexAiEmbeddingsModel createModel(
TaskType.TEXT_EMBEDDING,
"service",
uri,
new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, false, null, null, null, null),
new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, false, null, null, null, null, null),
new GoogleVertexAiEmbeddingsTaskSettings(Boolean.FALSE, null),
new GoogleVertexAiSecretSettings(new SecureString(serviceAccountJson.toCharArray()))
);
Expand All @@ -120,6 +120,7 @@ public static GoogleVertexAiEmbeddingsModel createModel(
false,
null,
null,
null,
similarityMeasure,
null
),
Expand All @@ -141,6 +142,7 @@ public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullabl
false,
null,
null,
null,
SimilarityMeasure.DOT_PRODUCT,
null
),
Expand All @@ -166,6 +168,7 @@ public static GoogleVertexAiEmbeddingsModel createRandomizedModel(
false,
null,
null,
null,
SimilarityMeasure.DOT_PRODUCT,
null
),
Expand Down
Loading