[ML] Integrate SageMaker with OpenAI Embeddings #126856
prwhelan merged 21 commits into elastic:main
Conversation
Hi @prwhelan, I've created a changelog YAML for you.
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
jonathan-buttner left a comment:
Looking good! Just left a few thoughts.
...ence/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java
...ence/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/SageMakerClient.java
...ava/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerServiceSettings.java
    return builder.endObject();
}

private static <T> void optionalField(String name, T value, XContentBuilder builder) throws IOException {
Nice, might be helpful to have this in a utility class somewhere eventually because we have to do stuff like this a lot.
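The helper the reviewer is praising writes a field only when its value is non-null. A minimal sketch of the pattern, using a plain `Map`-backed builder for illustration rather than the actual `XContentBuilder` (which is Elasticsearch-specific):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical utility class; the PR keeps this as a private method on the
// settings class rather than in a shared location.
public class OptionalFieldUtil {
    // Emit the field only when the value is present, so optional settings
    // are omitted from the serialized output instead of written as null.
    public static <T> void optionalField(String name, T value, Map<String, Object> builder) {
        if (value != null) {
            builder.put(name, value);
        }
    }
}
```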
...c/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchema.java
...rg/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredServiceSchema.java
...rg/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStreamSchemaPayload.java
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
...ence/src/test/java/org/elasticsearch/xpack/inference/services/InferenceSettingsTestCase.java
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
...asticsearch/xpack/inference/services/sagemaker/schema/openai/OpenAiTextEmbeddingPayload.java
Pinging @elastic/ml-core (Team:ML)
jonathan-buttner left a comment:
Looks good! Just a reminder to add docs in the elasticsearch-specification repo.
    return Collections.unmodifiableMap(configurationMap);
});
new LazyInitializable<>(
    () -> configuration(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION)).collect(
nit: Would Map.of() work instead of using a stream?
Oh I see we're combining multiple streams in a separate place 👍
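The trade-off the thread lands on can be shown with a small, standalone example (names here are illustrative, not from the PR): `Map.of()` is simpler for a fixed set of entries, while collecting a stream pays off once several entry sources have to be merged into one map.

```java
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ConfigMaps {
    // When all entries are known statically, Map.of() is the simplest form.
    static Map<String, String> direct() {
        return Map.of("api", "openai", "task", "text_embedding");
    }

    // A stream earns its keep when entries come from multiple sources that
    // must be concatenated before collecting, as in the thread above.
    static Map<String, String> merged() {
        return Stream.concat(
            Stream.of(Map.entry("api", "openai")),
            Stream.of(Map.entry("task", "text_embedding"))
        ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }
}
```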
} else {
    ExceptionsHelper.maybeError(t).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
    log.atWarn().withThrowable(t).log("Unknown failure calling SageMaker.");
    listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
Suggested change:
- listener.onFailure(new RuntimeException("Unknown failure calling SageMaker."));
+ listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", t));
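The point of the suggestion is that passing the original throwable as the exception's cause preserves its type and stack trace for anyone debugging the "Unknown failure". A minimal illustration (the `wrap` helper is hypothetical; the PR inlines this at the call site):

```java
public class CauseWrapping {
    // Wrapping the caught throwable as the cause means getCause() and the
    // logged stack trace still point at the real origin of the failure.
    static RuntimeException wrap(Throwable t) {
        return new RuntimeException("Unknown failure calling SageMaker.", t);
    }
}
```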
public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
    if (holder.compareAndSet(null, Tuple.tuple(null, subscriber)) == false) {
        log.debug("Subscriber connecting to publisher.");
        var publisher = holder.getAndSet(null).v1();
Other implementations of this method call onError() if a subscriber is already set, should this do the same?
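What the reviewer is suggesting follows the Reactive Streams convention for single-subscriber publishers: a second `subscribe` call should be rejected by signalling `onError` rather than being silently ignored. A simplified sketch (not the PR's implementation, which also tracks a pending publisher):

```java
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicReference;

// Single-use publisher sketch: the first subscriber wins, and any later
// subscriber is told via onError that the publisher is already taken.
public class SingleUsePublisher implements Flow.Publisher<String> {
    private final AtomicReference<Flow.Subscriber<? super String>> holder = new AtomicReference<>();

    @Override
    public void subscribe(Flow.Subscriber<? super String> subscriber) {
        if (holder.compareAndSet(null, subscriber) == false) {
            // Already subscribed: signal the extra subscriber instead of
            // dropping it on the floor.
            subscriber.onError(new IllegalStateException("Only one subscriber is allowed"));
        }
    }
}
```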
    Map<String, Object> config,
    ActionListener<Model> parsedModelListener
) {
    ActionListener.completeWith(parsedModelListener, () -> modelBuilder.fromRequest(modelId, taskType, NAME, config));
public class SageMakerService implements InferenceService {
    public static final String NAME = "sagemaker";
    private static final int DEFAULT_BATCH_SIZE = 2048;
Seems like a big number. 2048 may be an optimal size for SageMaker, but a batch this large would use quite a lot of memory and isn't sympathetic to how the inference API works.
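For context on what the constant controls: a batch size splits the input documents into chunks of at most that many items per request, so a larger value means more documents held in memory per call. A generic partitioning sketch (hypothetical helper, not the PR's code):

```java
import java.util.ArrayList;
import java.util.List;

public class Batching {
    // Split the inputs into consecutive batches of at most batchSize items;
    // peak memory per request scales with batchSize times document size.
    static <T> List<List<T>> partition(List<T> input, int batchSize) {
        List<List<T>> batches = new ArrayList<>();
        for (int i = 0; i < input.size(); i += batchSize) {
            batches.add(input.subList(i, Math.min(i + batchSize, input.size())));
        }
        return batches;
    }
}
```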
Map.entry(
    API,
    new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The API format that your SageMaker Endpoint expects.")
        .setLabel("Api")
Suggested change:
- .setLabel("Api")
+ .setLabel("API")
public final void testXContentRoundTrip() throws IOException {
    var instance = createTestInstance();
    var instanceAsMap = toMap(instance);
    var roundTripInstance = fromMutableMap(new HashMap<>(instanceAsMap));
💔 Backport failed
You can use sqren/backport to manually backport by running
Integrating with SageMaker.

Current design:
- SageMaker accepts any byte payload, which can be text, CSV, or JSON. `api` represents the structure of the payload that we will send, for example `openai`, `elastic`, `common`, and probably `cohere` or `huggingface` as well.
- `api` implementations are extensions of `SageMakerSchemaPayload`, which supports:
  - "extra" service and task settings specific to the payload structure, so `cohere` would require `embedding_type` and `openai` would require `dimensions` in the `service_settings`
  - conversion logic from model, service settings, task settings, and input to `SdkBytes`
  - conversion logic from the response `SdkBytes` to `InferenceServiceResults`
- Everything else is tunneling: there are a number of base `service_settings` and `task_settings`, independent of the api format, that we will store and set.
- We let the SDK do the bulk of the work in terms of connection details, rate limiting, retries, etc.
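The `SageMakerSchemaPayload` abstraction described above can be sketched as a small interface: one implementation per `api` format, owning both directions of the byte conversion. Names and signatures below are illustrative (plain `byte[]` and `String` stand in for `SdkBytes` and `InferenceServiceResults`), not the actual Elasticsearch API.

```java
import java.nio.charset.StandardCharsets;

// Sketch of the per-format payload contract described in the design notes.
interface SchemaPayload {
    String api();                           // e.g. "openai", "elastic", "common"
    byte[] requestBytes(String input);      // model/settings/input -> request bytes
    String parseResponse(byte[] response);  // response bytes -> inference results
}

// Hypothetical OpenAI-style implementation: encodes the input as JSON.
class OpenAiLikePayload implements SchemaPayload {
    public String api() {
        return "openai";
    }

    public byte[] requestBytes(String input) {
        // Hand-rolled JSON purely for illustration.
        return ("{\"input\":\"" + input + "\"}").getBytes(StandardCharsets.UTF_8);
    }

    public String parseResponse(byte[] response) {
        return new String(response, StandardCharsets.UTF_8);
    }
}
```

Under this shape, everything outside the interface (endpoint name, region, batching, retries) stays in the shared base settings, which is the "tunneling" described above.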