-
Notifications
You must be signed in to change notification settings - Fork 25.8k
Implement comprehensive top N parameter handling for text similarity reranker #142039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
73cb2cf
a61fe7b
4422cab
e3fef43
e75f55f
fabafe5
422caf5
8a97f42
ab9023c
b01d8e4
e62a551
e3a7a09
5e2ca36
3944ed1
8bcdf0e
47bdb69
1a5e013
202cffb
ad28cab
49e1b67
d9be5aa
de90b23
978062b
e724bc7
206959a
5116dd9
e62445f
1f661b7
25b4367
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| area: Ranking | ||
| issues: [] | ||
| pr: 142039 | ||
| summary: Implement comprehensive top N parameter handling for text similarity reranker | ||
| type: bug |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the "Elastic License | ||
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | ||
| * Public License v 1"; you may not use this file except in compliance with, at | ||
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | ||
| * License v3.0 only", or the "Server Side Public License, v 1". | ||
| */ | ||
|
|
||
| package org.elasticsearch.inference; | ||
|
|
||
| public interface TopNProvider { | ||
| Integer getTopN(); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| import org.elasticsearch.inference.SettingsConfiguration; | ||
| import org.elasticsearch.inference.TaskSettings; | ||
| import org.elasticsearch.inference.TaskType; | ||
| import org.elasticsearch.inference.TopNProvider; | ||
| import org.elasticsearch.inference.UnifiedCompletionRequest; | ||
| import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; | ||
| import org.elasticsearch.rest.RestStatus; | ||
|
|
@@ -40,7 +41,6 @@ | |
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayList; | ||
| import java.util.Comparator; | ||
| import java.util.EnumSet; | ||
| import java.util.HashMap; | ||
| import java.util.List; | ||
|
|
@@ -194,7 +194,14 @@ private RankedDocsResults makeResults(List<String> input, TestRerankingServiceEx | |
| for (int i = 0; i < totalResults; i++) { | ||
| results.add(new RankedDocsResults.RankedDoc(i, Float.parseFloat(input.get(i)), input.get(i))); | ||
| } | ||
| return new RankedDocsResults(results.stream().sorted(Comparator.reverseOrder()).toList()); | ||
|
|
||
| // RankedDoc's compareTo implementation already sorts by score descending, so we don't need to reverse the sort order | ||
| var sortedResultsStream = results.stream().sorted(); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So it turns out that our test reranker service has been sorting docs in reverse relevance order this whole time 🫠. We didn't catch it as an issue because ES sorts the results again (correctly) here. This all happens to work, as long as the reranker reranks every doc sent to it. However, if the reranker truncates results (like if a |
||
| if (taskSettings.topN != null) { | ||
| sortedResultsStream = sortedResultsStream.limit(taskSettings.topN); | ||
| } | ||
|
|
||
| return new RankedDocsResults(sortedResultsStream.toList()); | ||
| } catch (NumberFormatException ex) { | ||
| return makeResultFromTextInput(input, taskSettings); | ||
| } | ||
|
|
@@ -216,6 +223,10 @@ private RankedDocsResults makeResultFromTextInput(List<String> input, TestRerank | |
| } | ||
| // Ensure results are sorted by descending score | ||
| results.sort((a, b) -> -Float.compare(a.relevanceScore(), b.relevanceScore())); | ||
| if (taskSettings.topN != null && taskSettings.topN < results.size()) { | ||
| results = results.subList(0, taskSettings.topN); | ||
| } | ||
|
|
||
| return new RankedDocsResults(results); | ||
| } | ||
|
|
||
|
|
@@ -257,9 +268,14 @@ public static InferenceServiceConfiguration get() { | |
| } | ||
| } | ||
|
|
||
| public record TestTaskSettings(boolean shouldFailValidation, boolean useTextLength, float minScore, float resultDiff) | ||
| implements | ||
| TaskSettings { | ||
| public record TestTaskSettings( | ||
| boolean shouldFailValidation, | ||
| boolean useTextLength, | ||
| float minScore, | ||
| float resultDiff, | ||
| Integer topN, | ||
| boolean hideTopN | ||
| ) implements TaskSettings, TopNProvider { | ||
|
|
||
| static final String NAME = "test_reranking_task_settings"; | ||
|
|
||
|
|
@@ -268,6 +284,8 @@ public static TestTaskSettings fromMap(Map<String, Object> map) { | |
| boolean useTextLength = false; | ||
| float minScore = random.nextFloat(-1f, 1f); | ||
| float resultDiff = 0.2f; | ||
| Integer topN = null; | ||
| boolean hideTopN = false; | ||
|
|
||
| if (map.containsKey("should_fail_validation")) { | ||
| shouldFailValidation = Boolean.parseBoolean(map.remove("should_fail_validation").toString()); | ||
|
|
@@ -285,11 +303,19 @@ public static TestTaskSettings fromMap(Map<String, Object> map) { | |
| resultDiff = Float.parseFloat(map.remove("result_diff").toString()); | ||
| } | ||
|
|
||
| return new TestTaskSettings(shouldFailValidation, useTextLength, minScore, resultDiff); | ||
| if (map.containsKey("top_n")) { | ||
Mikep86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| topN = Integer.parseInt(map.remove("top_n").toString()); | ||
| } | ||
|
|
||
| if (map.containsKey("hide_top_n")) { | ||
Mikep86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| hideTopN = Boolean.parseBoolean(map.remove("hide_top_n").toString()); | ||
| } | ||
|
|
||
| return new TestTaskSettings(shouldFailValidation, useTextLength, minScore, resultDiff, topN, hideTopN); | ||
| } | ||
|
|
||
| public TestTaskSettings(StreamInput in) throws IOException { | ||
| this(in.readBoolean(), in.readBoolean(), in.readFloat(), in.readFloat()); | ||
| this(in.readBoolean(), in.readBoolean(), in.readFloat(), in.readFloat(), in.readOptionalInt(), in.readBoolean()); | ||
| } | ||
|
|
||
| @Override | ||
|
|
@@ -303,18 +329,30 @@ public void writeTo(StreamOutput out) throws IOException { | |
| out.writeBoolean(useTextLength); | ||
| out.writeFloat(minScore); | ||
| out.writeFloat(resultDiff); | ||
| out.writeOptionalInt(topN); | ||
Mikep86 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| out.writeBoolean(hideTopN); | ||
| } | ||
|
|
||
| @Override | ||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| builder.startObject(); | ||
| builder.field("should_fail_validation", shouldFailValidation); | ||
davidkyle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| builder.field("use_text_length", useTextLength); | ||
| builder.field("min_score", minScore); | ||
| builder.field("result_diff", resultDiff); | ||
| if (topN != null) { | ||
| builder.field("top_n", topN); | ||
| } | ||
| builder.field("hide_top_n", hideTopN); | ||
| builder.endObject(); | ||
| return builder; | ||
| } | ||
|
|
||
| @Override | ||
| public Integer getTopN() { | ||
| return hideTopN ? null : topN; | ||
davidkyle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| @Override | ||
| public String getWriteableName() { | ||
| return NAME; | ||
|
|
@@ -332,7 +370,9 @@ public TaskSettings updatedTaskSettings(Map<String, Object> newSettingsMap) { | |
| newSettingsMap.containsKey("should_fail_validation") ? newSettingsObject.shouldFailValidation() : shouldFailValidation, | ||
| newSettingsMap.containsKey("use_text_length") ? newSettingsObject.useTextLength() : useTextLength, | ||
| newSettingsMap.containsKey("min_score") ? newSettingsObject.minScore() : minScore, | ||
| newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff | ||
| newSettingsMap.containsKey("result_diff") ? newSettingsObject.resultDiff() : resultDiff, | ||
| newSettingsMap.containsKey("top_n") ? newSettingsObject.topN() : topN, | ||
| newSettingsMap.containsKey("hide_top_n") ? newSettingsObject.hideTopN() : hideTopN | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,10 +196,7 @@ public RankFeaturePhaseRankCoordinatorContext buildRankFeaturePhaseCoordinatorCo | |
| inferenceId, | ||
| inferenceText, | ||
| minScore, | ||
| failuresAllowed, | ||
| chunkScorerConfig != null | ||
| ? new ChunkScorerConfig(chunkScorerConfig.size, inferenceText, chunkScorerConfig.chunkingSettings()) | ||
| : null | ||
| failuresAllowed | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice cleanup here! |
||
| ); | ||
| } | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.