Add NVIDIA support to Inference Plugin #132388
Conversation
Changes proposed in #132388 (comment) are done, with improvements and reuse of existing code. Thanks.
Hi @DonalEvans
The AzureOpenAI, Llama and OpenShiftAI integrations all also extend OpenAiResponseHandler for handling text embedding responses, so I wonder if there is a similar issue with those integrations, or with the OpenAI integration itself, for text embeddings. I don't have accounts with any of those providers, so if you do have access to them, would you be able to check, please?
DonalEvans left a comment
Sorry for the late request for this change, but would it be possible to convert the unit tests added in this PR that are extending AbstractWireSerializingTestCase to extend AbstractBWCWireSerializationTestCase instead? While there are no backwards compatibility concerns with the classes now since they're brand new, it's good to have the tests set up to catch any that might be introduced in future. The change should be simple, just implementing the mutateInstanceForVersion() method to return the instance unchanged in each test class.
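A minimal, self-contained sketch of what that override looks like. The class and field names below are illustrative stand-ins, not the real Elasticsearch types; only the shape of `mutateInstanceForVersion()` matters:

```java
// Illustrative stand-ins for the real settings class and transport version;
// not the actual Elasticsearch test framework classes.
public class MutateForVersionSketch {
    record ServiceSettings(String modelId, String uri) {}

    // AbstractBWCWireSerializationTestCase asks each test class how an instance
    // should look after a round-trip through an older transport version. Since
    // these classes are brand new and have no version-gated fields, the
    // instance is simply returned unchanged.
    static ServiceSettings mutateInstanceForVersion(ServiceSettings instance, long transportVersion) {
        return instance;
    }
}
```

If a later change gates a field behind a transport version, this method becomes the single place where the test describes how older nodes see the instance.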
```java
protected void checkForFailureStatusCode(Request request, HttpResult result) throws RetryException {
    super.checkForFailureStatusCode(request, result);
}
```
Does this method need to be overridden? It just calls the super method without any additional logic. I think it would be better to make the super method public instead of overriding the method to give access to it. It's also an option to move the OpenAiResponseHandler class into the org.elasticsearch.xpack.inference.common package, similar to what was done with the Truncation enum, since it's used by 9 different integrations at this point.
OpenAiResponseHandler is far too widely used for me to be comfortable moving it in this integration; I'd do that in a separate PR. Made the change to the accessibility of checkForFailureStatusCode.
```json
    "param": null,
    "code": null
}
"error": "Input length 18432 exceeds maximum allowed token size 8192"
```
For this specific test case, where we're checking that a 413 status leads to truncating the input, it might be better to use an error message that doesn't match the one we check for, to confirm that it's the status code that causes the truncation rather than the error message.
Fixed.
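The review point above can be sketched as a status-code-first check. The method name here is a hypothetical stand-in, not the actual handler code:

```java
public class ContentTooLargeCheck {
    // Sketch of the review point: a 413 should trigger truncation on the
    // status code alone, so a test should pair 413 with a message that does
    // NOT match the known error text to prove the code path.
    static boolean shouldTruncate(int statusCode, String errorMessage) {
        if (statusCode == 413) {
            return true; // payload too large, regardless of what the body says
        }
        // Some providers signal the same condition via 400 plus a known message.
        return statusCode == 400
            && errorMessage != null
            && errorMessage.contains("exceeds maximum allowed token size");
    }
}
```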
```json
    "param": null,
    "code": null
}
"error": "Input length 18432 exceeds maximum allowed token size 8192"
```
We should be testing both that text embedding requests get truncated when they see this message and that completion requests get truncated when they get a 413 status or see the "Please reduce your prompt; or completion length." message. Could you add tests for the latter two cases please?
The issue is that we don't perform truncation logic for any chat completion requests, in any integration, including the NVIDIA one.
Meaning that if a 413 error, or a 400 with the appropriate message, is received, the request is retried 3 times without changing the input, and if the same error is received each time, the error is returned to the customer.
Truncating completion requests seems off to me, because it would change the meaning of the input, but retrying them when we know the model is not capable of returning a positive response is not good either.
I added logic that throws an error right away, without retries, for completions when a ContentTooLarge error is received. That makes more sense to me.
Let me know your opinion on that.
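The fail-fast behaviour described above can be sketched as follows. This is a self-contained illustration, not the actual retry framework in the inference plugin:

```java
import java.util.function.IntSupplier;

public class CompletionRetrySketch {
    static final int MAX_ATTEMPTS = 3;

    // Calls the supplier until it succeeds or retries are exhausted, but bails
    // out immediately on 413 (ContentTooLarge): retrying the same oversized
    // input can never succeed, so the error is surfaced to the caller at once
    // instead of after three pointless attempts.
    static int execute(IntSupplier call) {
        for (int attempt = 1; attempt <= MAX_ATTEMPTS; attempt++) {
            int status = call.getAsInt();
            if (status == 200) {
                return status;
            }
            if (status == 413) {
                throw new IllegalStateException("content too large; not retrying");
            }
            // Any other failure falls through to the next attempt.
        }
        throw new IllegalStateException("retries exhausted");
    }
}
```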
```diff
 new NvidiaEmbeddingsServiceSettings(
     MODEL_VALUE,
-    ServiceUtils.createOptionalUri(null),
+    createOptionalUri(null),
```
Rather than using null as the URI here, it might be better to pass in the expected default URI. We end up asserting the same thing in either case, but it makes the test clearer in terms of what the expected value actually is instead of having it hidden inside the logic in the constructor.
Good thinking. Changed to the default value.
```java
private static final InputType INPUT_TYPE_EXPEDIA_VALUE = InputType.INGEST;
private static final Truncation TRUNCATE_EXPEDIA_VALUE = Truncation.START;
```
Some more "EXPEDIA" instead of "ELASTIC" here.
Thanks. Fixed now.
Done. The embeddings tests now extend AbstractBWCWireSerializationTestCase instead of AbstractWireSerializingTestCase.
Unfortunately, for OpenShift AI we won't be able to test because the environment was taken away from us and is now used for other purposes. Will check the other providers.
Hello @DonalEvans
The failures in the serverless checks are due to being out of date with the base branch; merging main should resolve them. Once all the checks are passing, I'll merge this PR.
@DonalEvans I added the specification and merged the master branch.
Creation of a new NVIDIA inference provider integration, allowing tasks to be executed as part of the Inference API with the nvidia provider. Changes were tested locally against the following models:
Useful doc links:
https://docs.api.nvidia.com/nim/reference/llm-apis
https://docs.api.nvidia.com/nim/reference/retrieval-apis
https://docs.api.nvidia.com/nim/reference/nvidia-llama-3_2-nemoretriever-500m-rerank-v2-infer
Model ID is mandatory because it is used to determine which model to use.
URL is optional because there are default values for the endpoints of the different task types.
Most embeddings models require the input_type parameter. It can be provided in task_settings, along with the truncate parameter.
During chat_completion testing, the response for a function call didn't return any tool_call usage info. If one of the models does return it, it will be handled by the OpenAI logic.
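To illustrate the settings above, creating an embeddings endpoint might look like the following. The field names follow the general Inference API conventions, and the model ID and input values are placeholders assumed for illustration, not copied from the PR:

```json
PUT _inference/text_embedding/nvidia-embeddings
{
  "service": "nvidia",
  "service_settings": {
    "api_key": "<NVIDIA_API_KEY>",
    "model_id": "nvidia/nv-embedqa-e5-v5"
  },
  "task_settings": {
    "input_type": "ingest",
    "truncate": "START"
  }
}
```

Omitting `url` here relies on the default endpoint for the text_embedding task type described above.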
EMBEDDINGS
- Create embeddings endpoint (mandatory input_type)
- Create embeddings endpoint (Default URL)
- Perform embeddings (input_type is taken from task_settings on endpoint creation)
- Perform embeddings (with task_settings)
- Create embeddings endpoint (without task_settings with input_type)
- Perform embeddings (no task_settings)
- Create embeddings endpoint (Not Found error)

COMPLETION
- Create completion endpoint
- Perform non-streaming completion
- Perform streaming completion
- Create completion endpoint (Not Found error)
- Create completion endpoint (Default URL)

CHAT COMPLETION
- Create chat completion endpoint
- Perform basic chat completion
- Create chat completion endpoint (Not Found error)
- Create chat completion endpoint (Default URL)

RERANK
- Create rerank endpoint
- Perform rerank
- Create rerank endpoint (Not Found error)
- Create rerank endpoint (Default URL)
- Perform rerank (With default URL)
gradle check?