[ML] Add default Elastic Inference Service chat completion endpoint #120847
jonathan-buttner merged 7 commits into elastic:main
Conversation
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends ESRestTestCase {
Moved this to BaseMockEISAuthServerTest
private static final Logger logger = LogManager.getLogger(ElasticInferenceService.class);
private static final EnumSet&lt;TaskType&gt; IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
Model name
private static final EnumSet&lt;TaskType&gt; IMPLEMENTED_TASK_TYPES = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION);
private static final String SERVICE_NAME = "Elastic";
static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles";
static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = Strings.format(".%s-elastic", DEFAULT_CHAT_COMPLETION_MODEL_ID_V1);
Inference endpoint ID
);
}

private record AuthorizedContent(
Just an aggregation of all the different pieces we need to expose (enabled task types, DefaultConfigId objects, and a list of models).
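To illustrate the shape of that aggregation, here is a minimal, hypothetical sketch. The field types are simplified stand-ins for the real Elasticsearch classes (e.g. plain strings instead of DefaultConfigId and model objects); only the structure mirrors the description above.

```java
import java.util.List;

public class AuthorizedContentSketch {

    enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION }

    // Bundles the pieces derived from one authorization response:
    // enabled task types, default config IDs, and authorized model IDs.
    record AuthorizedContent(
        List<TaskType> enabledTaskTypes,
        List<String> defaultConfigIds,   // stand-in for DefaultConfigId objects
        List<String> models              // stand-in for model objects
    ) {}

    public static void main(String[] args) {
        var content = new AuthorizedContent(
            List.of(TaskType.CHAT_COMPLETION),
            List.of(".rainbow-sprinkles-elastic"),
            List.of("rainbow-sprinkles")
        );
        System.out.println(content.models());
    }
}
```

A record fits here because the aggregation is pure data with no behavior of its own.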
if (auth.getEnabledTaskTypes().contains(model.getTaskType()) == false) {
    logger.warn(
        Strings.format(
            "The authorization response included the default model: %s, "
In the unlikely event that the gateway and our definition of the default model disagree on task types, we'll enable the model anyway, because that's what the gateway said to do.
This would only happen if the authorization response returned a task type different from the one we've set for the model here.
private Set&lt;String&gt; getEnabledDefaultModelIds(ElasticInferenceServiceAuthorization auth) {
    var enabledModels = auth.getEnabledModels();
    var enabledDefaultModelIds = new HashSet&lt;&gt;(defaultModels.keySet());
    enabledDefaultModelIds.retainAll(enabledModels);
Return the model ids where there is overlap between what the gateway authorized and the default ones we've defined.
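The intersection logic can be sketched in isolation. This is a simplified, self-contained version (illustrative names, not the actual classes): copy the predefined default model IDs, then retain only those the gateway's authorization response also enabled.

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class EnabledDefaults {

    // Stand-in for the service's map of predefined default models.
    static final Map<String, String> DEFAULT_MODELS = Map.of("rainbow-sprinkles", "chat_completion");

    static Set<String> getEnabledDefaultModelIds(Set<String> gatewayEnabledModels) {
        // Start from our defaults; retainAll keeps only the IDs present in both sets.
        var enabledDefaultModelIds = new HashSet<>(DEFAULT_MODELS.keySet());
        enabledDefaultModelIds.retainAll(gatewayEnabledModels);
        return enabledDefaultModelIds;
    }

    public static void main(String[] args) {
        // Only "rainbow-sprinkles" overlaps with the defaults.
        System.out.println(getEnabledDefaultModelIds(Set.of("rainbow-sprinkles", "some-other-model")));
    }
}
```

The copy into a fresh `HashSet` matters: `retainAll` mutates the receiver, and the defaults map's key set must stay untouched.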
 * This is a helper class for managing the response from {@link ElasticInferenceServiceAuthorizationHandler}.
 */
public record ElasticInferenceServiceAuthorization(Map&lt;String, EnumSet&lt;TaskType&gt;&gt; enabledModels) {
public class ElasticInferenceServiceAuthorization {
I refactored this because we need both the authorized task types and the authorized models.
public record ElasticInferenceServiceAuthorization(Map&lt;String, EnumSet&lt;TaskType&gt;&gt; enabledModels) {
public class ElasticInferenceServiceAuthorization {

    private final Map&lt;TaskType, Set&lt;String&gt;&gt; taskTypeToModels;
This mapping helps when we need to create a new object that's limited to what the service actually supports, so we can easily grab the models that were authorized for a particular task type.
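Building that index amounts to inverting the model-to-task-types map. A minimal sketch (hypothetical standalone class, with a local TaskType enum standing in for the real one):

```java
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class TaskTypeIndex {

    enum TaskType { SPARSE_EMBEDDING, CHAT_COMPLETION }

    // Invert model -> task types into task type -> models, so authorized
    // models can be looked up by the task type a caller cares about.
    static Map<TaskType, Set<String>> buildTaskTypeToModels(Map<String, EnumSet<TaskType>> enabledModels) {
        var taskTypeToModels = new HashMap<TaskType, Set<String>>();
        for (var entry : enabledModels.entrySet()) {
            for (var taskType : entry.getValue()) {
                // Group each model ID under every task type it is authorized for.
                taskTypeToModels.computeIfAbsent(taskType, k -> new HashSet<>()).add(entry.getKey());
            }
        }
        return taskTypeToModels;
    }

    public static void main(String[] args) {
        var index = buildTaskTypeToModels(
            Map.of("rainbow-sprinkles", EnumSet.of(TaskType.CHAT_COMPLETION))
        );
        System.out.println(index.get(TaskType.CHAT_COMPLETION));
    }
}
```

With the index in place, restricting an authorization to one task type is a single map lookup instead of a scan over all models.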
|
Pinging @elastic/ml-core (Team:ML)
joshdevins left a comment:
Had a quick scan only. Minor comments.
.setting("xpack.security.enabled", "true")
// Adding both settings unless one feature flag is disabled in a particular environment
.setting("xpack.inference.elastic.url", mockEISServer::getUrl)
// TODO remove this once we've removed DEPRECATED_ELASTIC_INFERENCE_SERVICE_FEATURE_FLAG and EIS_GATEWAY_URL
"task_types": ["chat"]
},
{
"model_name": ".elser_model_2",
EIS will expose elser-v2. Not sure it matters for this test though.
See: #120981
I'll update it 👍. The model ID here isn't actually being used, but we might as well align it for the future when we do use it.
💔 Backport failed
You can use sqren/backport to manually backport by running
💚 All backports created successfully
Questions? Please refer to the Backport tool documentation
…lastic#120847)

* Starting new auth class implementation
* Fixing some tests
* Working tests
* Refactoring
* Addressing feedback and pull main

(cherry picked from commit 1fa1ba7)
This PR adds the first iteration of the default model ID for the Elastic Inference Service.
Model id: rainbow-sprinkles
The default endpoint id: .rainbow-sprinkles-elastic

Testing

Without EIS

GET _inference/_all: elastic should not be listed in the response
GET _inference/_services: .rainbow-sprinkles-elastic should not be listed in the response

With EIS
Get the right certs directory.
Run the gateway:
make TLS_VERIFY_CLIENT_CERTS=false run
Run ES:
Retrieve all the default inference endpoints
Retrieve all the available services for sparse embedding