[ML] Adding custom headers support openai text embeddings#134960
[ML] Adding custom headers support openai text embeddings#134960jonathan-buttner merged 18 commits intoelastic:mainfrom
Conversation
|
Hi @jonathan-buttner, I've created a changelog YAML for you. |
…i-headers-embedding
…tner/elasticsearch into ml-openai-headers-embedding
|
Pinging @elastic/ml-core (Team:ML) |
| String user; | ||
|
|
||
| public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { | ||
| if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { |
There was a problem hiding this comment.
It might be overcomplicating things, but it should be possible to move the readTaskSettingsFromStream() and writeTo() implementations into the base class as well, by introducing an abstract method like abstract boolean shouldReadAdditionalString(TransportVersion version) which always returns false for OpenAiChatCompletionTaskSettings and checks the transport version for OpenAiEmbeddingsTaskSettings, and another abstract method that returns the transport version in which the headers were introduced for each class and which is used to determine whether to read/write the headers. The final result would look something like this for the reading side:
public OpenAiTaskSettings(StreamInput in) throws IOException {
String user;
if (shouldReadAdditionalString(in.getTransportVersion())) {
var discard = in.readString();
user = in.readOptionalString();
} else {
user = in.readOptionalString();
}
Map<String, String> headers;
if (in.getTransportVersion().onOrAfter(headersIntroducedVersion())) {
headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
} else {
headers = null;
}
taskSettings = user == null && headers == null ? EMPTY_SETTINGS : new Settings(user, headers);
}
There was a problem hiding this comment.
Thanks for the suggestion. I agree that these changes would help with reducing the duplication. I think for this one I'd rather leave the transport version logic in the individual classes that need it. I think it might be a bit clearer when looking at the individual class as to what has changed between versions. If we end up making more transport version dependent changes to these two classes we can revisit pulling the logic into the base class.
There was a problem hiding this comment.
I'll add the changes for the empty settings though 👍
| throw validationException; | ||
| } | ||
|
|
||
| return new Settings(user, stringHeaders); |
There was a problem hiding this comment.
Using a method like the one below, we can ensure that any time the settings would be empty, we use EMPTY_SETTINGS instead of constructing a new, empty Settings object. This would allow us to make the isEmpty() method a simple == comparison with EMPTY_SETTINGS and also prevent some unnecessary object allocations:
private static Settings getSettingsCheckingForEmpty(String user, Map<String, String> stringHeaders) {
if (user == null && (stringHeaders == null || stringHeaders.isEmpty())) {
return EMPTY_SETTINGS;
} else {
return new Settings(user, stringHeaders);
}
}
This method could also be used in the constructor below this.
I think this might lead to a subtle inconsistency with how we treat empty maps though, since passing a null user and an empty map to the constructor currently results in a Settings object that is not equal to EMPTY_SETTINGS, but still returns true from isEmpty(). Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings? If so, then I don't think we can use EMPTY_SETTINGS when the user is null and the map is non-null but empty, but we can still use it when both are null.
There was a problem hiding this comment.
Do we need to be able to differentiate between the "empty map" and "null map" versions of Settings?
I think we're going to have a serialization issue if we internally convert an empty map to a null map because the testing logic will try to write the empty map and ensure that it will be read as an empty map. I suppose we could only do the conversion in the fromMap() function 🤔 . I suspect that it'll be pretty rare that folks include an empty headers field. And if they did include an empty headers field in the PUT request, when we create the inference endpoint it wouldn't write headers in the toXContent(). So I think to make things easier with the serialization tests I'm going to treat a null map and an empty map as separate things.
| var randomSettings = create(randomBoolean() ? null : "username", randomBoolean() ? null : Map.of("key", "value")); | ||
| var stringRep = Strings.toString(randomSettings); | ||
|
|
||
| assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); |
There was a problem hiding this comment.
Strictly speaking, this test is not testing isEmpty() but rather whether isEmpty() and toXContent() agree. If there were bugs in both methods that just so happened to agree with each other, the test would still pass despite isEmpty() being incorrect.
Since there aren't many permutations of possible arguments to the create() method for this test, it might be worth explicitly testing all the combinations and making sure that isEmpty() returns the value expected:
public void testIsEmpty() {
var bothNull = create(null, null);
assertThat(bothNull.isEmpty(), is(true));
var nullUserEmptyHeaders = create(null, Map.of());
assertThat(nullUserEmptyHeaders.isEmpty(), is(true));
var nullHeaders = create("user", null);
assertThat(nullHeaders.isEmpty(), is(false));
var nullUser = create(null, Map.of("K", "v"));
assertThat(nullUser.isEmpty(), is(false));
var neitherNull = create("user", Map.of("K", "v"));
assertThat(neitherNull.isEmpty(), is(false));
}
It's a shame that the parameterized testing framework we use doesn't support per-method parameterization, because that would make this a bit cleaner.
| newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user()); | ||
| } | ||
|
|
||
| if (newSettings.headers() != null && newSettings.headers().isEmpty() == false) { |
There was a problem hiding this comment.
Due to the implementation of createRandom(), the headers map is never empty, so we don't have coverage of the case where the headers in newSettings are an empty map. If I force that case by modifying the test, it fails due to expecting the updated settings to be an empty map, when they are in fact the original value from initialSettings. I'm not sure whether we expect an empty map in newSettings to overwrite the original value or not, but it would be good to explicitly test that behaviour.
There was a problem hiding this comment.
Good point, I added a few tests for this and updated the createRandom to also generate empty headers.
|
|
||
| protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers); | ||
|
|
||
| protected abstract T create(Map<String, Object> map); |
There was a problem hiding this comment.
To avoid confusion, it might be better to rename this createFromMap()
| assertThat(settings.headers(), is(Map.of("key", "value"))); | ||
| } | ||
|
|
||
| public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() { |
There was a problem hiding this comment.
Could we also have a test for the cases where the object stored at the HEADERS key in the map is an empty map, a map with only null values, and for when it's a map that doesn't contain Strings?
| public void testOf_KeepsOriginalValuesWithOverridesAreNull() { | ||
| var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); | ||
|
|
||
| assertThat(taskSettings.updatedTaskSettings(Map.of()), is(taskSettings)); | ||
| } | ||
|
|
||
| public void testOf_UsesOverriddenSettings() { | ||
| var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); | ||
|
|
||
| assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.USER, "user2")), is(create("user2", null))); | ||
| } | ||
|
|
||
| public void testOf_UsesOverriddenSettings_ForHeaders() { | ||
| var user = "user"; | ||
| var taskSettings = create(new HashMap<>(Map.of(OpenAiServiceFields.USER, user))); | ||
|
|
||
| var headers = Map.of("key", "value"); | ||
| assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.HEADERS, headers)), is(create(user, headers))); | ||
| } |
There was a problem hiding this comment.
There is no of() method on OpenAiTaskSettings, so these tests should either be renamed, or just removed, since I think they cover the same behaviour as the testUpdatedTaskSettings() test, just in a slightly different way.
There was a problem hiding this comment.
I'll rename them 👍
| var stringRep = Strings.toString(randomSettings); | ||
| assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); | ||
| } | ||
| public static Map<String, Object> getTaskSettingsMap(@Nullable String user) { |
There was a problem hiding this comment.
This method seems to only be called in OpenAiServiceTests, so could it be moved there and made non-static?
There was a problem hiding this comment.
I think the reason it lives in OpenAiEmbeddingsTaskSettingsTests is because it's creating a map that is valid for the embeddings task settings. IMHO if I were writing additional tests that needed to construct a valid map for embedding task settings, I'd probably look in the embeddings task settings tests rather than in the service tests. Since it's only used in the service tests I can move it if you'd like though.
made non-static?
Just curious why you're advocating for this to be non-static since it doesn't reference any member variables?
There was a problem hiding this comment.
I see, the explanation for it being in OpenAiEmbeddingsTaskSettingsTests makes sense, although it is a little strange that it's not actually used there, even though most of the tests in that class could definitely be using it.
For making the method non-static, that was a brain fart, I meant private, although that would also defeat the purpose of this being a helper/utility method used by other classes.
As an aside, I just noticed that in OpenAiServiceTests.testParseRequestConfig_CreatesAnOpenAiChatCompletionsModel() we call this same static method to create task settings for a completion task, which is incorrect. Right now it happens to work because both the embeddings task settings and the completion task settings have a user field, but the logic for creating a completion-specific map should belong in a completion-specific class since their implementations now differ with the addition of the headers field.
There was a problem hiding this comment.
now differ with the addition of the headers field.
Yeah, with this PR the completions and embeddings logic will be the same again since they'll both have headers. I'll move the map creation into the base test class 👍
There was a problem hiding this comment.
Actually there already is one there: getOpenAiTaskSettingsMap, I'll remove this one.
| * <a href="https://platform.openai.com/docs/api-reference/embeddings/create">see the openai docs for more details</a> | ||
| */ | ||
| public class OpenAiEmbeddingsTaskSettings implements TaskSettings { | ||
| public class OpenAiEmbeddingsTaskSettings extends OpenAiTaskSettings<OpenAiEmbeddingsTaskSettings> implements TaskSettings { |
There was a problem hiding this comment.
Since OpenAiTaskSettings implements TaskSettings this class doesn't need to also implement it.
| assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); | ||
| assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); | ||
| assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); | ||
| assertThat(httpPost.getLastHeader("key").getValue(), is("value")); |
There was a problem hiding this comment.
Could the key and value Strings be extracted to variables?
|
|
||
| public abstract class OpenAiTaskSettingsTests<T extends OpenAiTaskSettings<T>> extends AbstractBWCWireSerializationTestCase<T> { | ||
|
|
||
| private enum HeadersDefinition { |
There was a problem hiding this comment.
Nice, I like this solution.
…tner/elasticsearch into ml-openai-headers-embedding
…i-headers-embedding
…tner/elasticsearch into ml-openai-headers-embedding
…-dls * upstream/main: Bump FLEET_AGENTS_MAPPINGS_VERSION so the new mapping applies on upgrades (elastic#134957) [ML] Adding custom headers support openai text embeddings (elastic#134960) Fix systemd notify to use a shared arena (elastic#135235)
This PR adds custom headers support for text embedding for openai. This is the counter part to this PR: #134504
Example request