Skip to content

Commit 3c29593

Browse files
authored
[ES|QL] COMPLETION command analysis. (#126677)
* [ES|QL] COMPLETION command analysis. * Moving prompt type test in postAnalysisVerification * Test lint.
1 parent 55a6624 commit 3c29593

File tree

4 files changed

+165
-24
lines changed

4 files changed

+165
-24
lines changed

x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,10 @@ public static <T> T singleValue(Collection<T> collection) {
884884
return collection.iterator().next();
885885
}
886886

887+
public static Attribute getAttributeByName(Collection<Attribute> attributes, String name) {
888+
return attributes.stream().filter(attr -> attr.name().equals(name)).findAny().orElse(null);
889+
}
890+
887891
public static Map<String, Object> jsonEntityToMap(HttpEntity entity) throws IOException {
888892
return entityToMap(entity, XContentType.JSON);
889893
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
import org.elasticsearch.xpack.esql.plan.logical.Rename;
9292
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
9393
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
94+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
9495
import org.elasticsearch.xpack.esql.plan.logical.inference.InferencePlan;
9596
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
9697
import org.elasticsearch.xpack.esql.plan.logical.join.Join;
@@ -490,6 +491,10 @@ protected LogicalPlan rule(LogicalPlan plan, AnalyzerContext context) {
490491
return resolveAggregate(aggregate, childrenOutput);
491492
}
492493

494+
if (plan instanceof Completion c) {
495+
return resolveCompletion(c, childrenOutput);
496+
}
497+
493498
if (plan instanceof Drop d) {
494499
return resolveDrop(d, childrenOutput);
495500
}
@@ -600,6 +605,21 @@ private Aggregate resolveAggregate(Aggregate aggregate, List<Attribute> children
600605
return aggregate;
601606
}
602607

608+
private LogicalPlan resolveCompletion(Completion p, List<Attribute> childrenOutput) {
609+
Attribute targetField = p.targetField();
610+
Expression prompt = p.prompt();
611+
612+
if (targetField instanceof UnresolvedAttribute ua) {
613+
targetField = new ReferenceAttribute(ua.source(), ua.name(), TEXT);
614+
}
615+
616+
if (prompt.resolved() == false) {
617+
prompt = prompt.transformUp(UnresolvedAttribute.class, ua -> maybeResolveAttribute(ua, childrenOutput));
618+
}
619+
620+
return new Completion(p.source(), p.child(), p.inferenceId(), prompt, targetField);
621+
}
622+
603623
private LogicalPlan resolveMvExpand(MvExpand p, List<Attribute> childrenOutput) {
604624
if (p.target() instanceof UnresolvedAttribute ua) {
605625
Attribute resolved = maybeResolveAttribute(ua, childrenOutput);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/inference/Completion.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@
1111
import org.elasticsearch.common.io.stream.StreamInput;
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisVerificationAware;
15+
import org.elasticsearch.xpack.esql.common.Failures;
1416
import org.elasticsearch.xpack.esql.core.expression.Attribute;
1517
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
1618
import org.elasticsearch.xpack.esql.core.expression.Expression;
1719
import org.elasticsearch.xpack.esql.core.expression.NameId;
1820
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
1921
import org.elasticsearch.xpack.esql.core.tree.Source;
22+
import org.elasticsearch.xpack.esql.core.type.DataType;
2023
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
2124
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
2225
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
@@ -26,9 +29,15 @@
2629
import java.util.List;
2730
import java.util.Objects;
2831

32+
import static org.elasticsearch.xpack.esql.common.Failure.fail;
33+
import static org.elasticsearch.xpack.esql.core.type.DataType.TEXT;
2934
import static org.elasticsearch.xpack.esql.expression.NamedExpressions.mergeOutputAttributes;
3035

31-
public class Completion extends InferencePlan<Completion> implements GeneratingPlan<Completion>, SortAgnostic {
36+
public class Completion extends InferencePlan<Completion>
37+
implements
38+
GeneratingPlan<Completion>,
39+
SortAgnostic,
40+
PostAnalysisVerificationAware {
3241

3342
public static final String DEFAULT_OUTPUT_FIELD_NAME = "completion";
3443

@@ -130,6 +139,13 @@ public boolean expressionsResolved() {
130139
return super.expressionsResolved() && prompt.resolved();
131140
}
132141

142+
@Override
143+
public void postAnalysisVerification(Failures failures) {
144+
if (prompt.resolved() && DataType.isString(prompt.dataType()) == false) {
145+
failures.add(fail(prompt, "prompt must be of type [{}] but is [{}]", TEXT.typeName(), prompt.dataType().typeName()));
146+
}
147+
}
148+
133149
@Override
134150
protected NodeInfo<? extends LogicalPlan> info() {
135151
return NodeInfo.create(this, Completion::new, child(), inferenceId(), prompt, targetField);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

Lines changed: 124 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.elasticsearch.xpack.esql.expression.function.fulltext.Match;
4747
import org.elasticsearch.xpack.esql.expression.function.fulltext.MatchOperator;
4848
import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString;
49+
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
4950
import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
5051
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
5152
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@@ -71,6 +72,7 @@
7172
import org.elasticsearch.xpack.esql.plan.logical.Row;
7273
import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval;
7374
import org.elasticsearch.xpack.esql.plan.logical.UnresolvedRelation;
75+
import org.elasticsearch.xpack.esql.plan.logical.inference.Completion;
7476
import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank;
7577
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
7678
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
@@ -92,9 +94,11 @@
9294
import static org.elasticsearch.xpack.esql.EsqlTestUtils.as;
9395
import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration;
9496
import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyInferenceResolution;
97+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.getAttributeByName;
9598
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsConstant;
9699
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsIdentifier;
97100
import static org.elasticsearch.xpack.esql.EsqlTestUtils.paramAsPattern;
101+
import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute;
98102
import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning;
99103
import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS;
100104
import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze;
@@ -3460,7 +3464,7 @@ public void testResolveRerankInferenceId() {
34603464

34613465
{
34623466
LogicalPlan plan = analyze(
3463-
" FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
3467+
"FROM books METADATA _score | RERANK \"italian food recipe\" ON title WITH `reranking-inference-id`",
34643468
"mapping-books.json"
34653469
);
34663470
Rerank rerank = as(as(plan, Limit.class).child(), Rerank.class);
@@ -3530,16 +3534,13 @@ public void testResolveRerankFields() {
35303534
Filter filter = as(drop.child(), Filter.class);
35313535
EsRelation relation = as(filter.child(), EsRelation.class);
35323536

3533-
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
3534-
assertThat(titleAttribute, notNullValue());
3537+
Attribute titleAttribute = getAttributeByName(relation.output(), "title");
3538+
assertThat(getAttributeByName(relation.output(), "title"), notNullValue());
35353539

35363540
assertThat(rerank.queryText(), equalTo(string("italian food recipe")));
35373541
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
35383542
assertThat(rerank.rerankFields(), equalTo(List.of(alias("title", titleAttribute))));
3539-
assertThat(
3540-
rerank.scoreAttribute(),
3541-
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3542-
);
3543+
assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE)));
35433544
}
35443545

35453546
{
@@ -3559,15 +3560,11 @@ public void testResolveRerankFields() {
35593560
assertThat(rerank.inferenceId(), equalTo(string("reranking-inference-id")));
35603561

35613562
assertThat(rerank.rerankFields(), hasSize(3));
3562-
Attribute titleAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("title")).findFirst().get();
3563+
Attribute titleAttribute = getAttributeByName(relation.output(), "title");
35633564
assertThat(titleAttribute, notNullValue());
35643565
assertThat(rerank.rerankFields().get(0), equalTo(alias("title", titleAttribute)));
35653566

3566-
Attribute descriptionAttribute = relation.output()
3567-
.stream()
3568-
.filter(attribute -> attribute.name().equals("description"))
3569-
.findFirst()
3570-
.get();
3567+
Attribute descriptionAttribute = getAttributeByName(relation.output(), "description");
35713568
assertThat(descriptionAttribute, notNullValue());
35723569
Alias descriptionAlias = rerank.rerankFields().get(1);
35733570
assertThat(descriptionAlias.name(), equalTo("description"));
@@ -3576,13 +3573,11 @@ public void testResolveRerankFields() {
35763573
equalTo(List.of(descriptionAttribute, literal(0), literal(100)))
35773574
);
35783575

3579-
Attribute yearAttribute = relation.output().stream().filter(attribute -> attribute.name().equals("year")).findFirst().get();
3576+
Attribute yearAttribute = getAttributeByName(relation.output(), "year");
35803577
assertThat(yearAttribute, notNullValue());
35813578
assertThat(rerank.rerankFields().get(2), equalTo(alias("yearRenamed", yearAttribute)));
3582-
assertThat(
3583-
rerank.scoreAttribute(),
3584-
equalTo(relation.output().stream().filter(attr -> attr.name().equals(MetadataAttribute.SCORE)).findFirst().get())
3585-
);
3579+
3580+
assertThat(rerank.scoreAttribute(), equalTo(getAttributeByName(relation.output(), MetadataAttribute.SCORE)));
35863581
}
35873582

35883583
{
@@ -3614,11 +3609,7 @@ public void testResolveRerankScoreField() {
36143609
Filter filter = as(rerank.child(), Filter.class);
36153610
EsRelation relation = as(filter.child(), EsRelation.class);
36163611

3617-
Attribute metadataScoreAttribute = relation.output()
3618-
.stream()
3619-
.filter(attr -> attr.name().equals(MetadataAttribute.SCORE))
3620-
.findFirst()
3621-
.get();
3612+
Attribute metadataScoreAttribute = getAttributeByName(relation.output(), MetadataAttribute.SCORE);
36223613
assertThat(rerank.scoreAttribute(), equalTo(metadataScoreAttribute));
36233614
assertThat(rerank.output(), hasItem(metadataScoreAttribute));
36243615
}
@@ -3642,6 +3633,116 @@ public void testResolveRerankScoreField() {
36423633
}
36433634
}
36443635

3636+
public void testResolveCompletionInferenceId() {
3637+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3638+
3639+
LogicalPlan plan = analyze("""
3640+
FROM books METADATA _score
3641+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3642+
""", "mapping-books.json");
3643+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3644+
assertThat(completion.inferenceId(), equalTo(string("completion-inference-id")));
3645+
}
3646+
3647+
public void testResolveCompletionInferenceIdInvalidTaskType() {
3648+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3649+
3650+
assertError(
3651+
"""
3652+
FROM books METADATA _score
3653+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `reranking-inference-id`
3654+
""",
3655+
"mapping-books.json",
3656+
new QueryParams(),
3657+
"cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3658+
+ " Only inference endpoints with the task type [completion] are supported"
3659+
);
3660+
}
3661+
3662+
public void testResolveCompletionInferenceMissingInferenceId() {
3663+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3664+
3665+
assertError("""
3666+
FROM books METADATA _score
3667+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `unknown-inference-id`
3668+
""", "mapping-books.json", new QueryParams(), "unresolved inference [unknown-inference-id]");
3669+
}
3670+
3671+
public void testResolveCompletionInferenceIdResolutionError() {
3672+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3673+
3674+
assertError("""
3675+
FROM books METADATA _score
3676+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `error-inference-id`
3677+
""", "mapping-books.json", new QueryParams(), "error with inference resolution");
3678+
}
3679+
3680+
public void testResolveCompletionTargetField() {
3681+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3682+
3683+
LogicalPlan plan = analyze("""
3684+
FROM books METADATA _score
3685+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS translation
3686+
""", "mapping-books.json");
3687+
3688+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3689+
assertThat(completion.targetField(), equalTo(referenceAttribute("translation", DataType.TEXT)));
3690+
}
3691+
3692+
public void testResolveCompletionDefaultTargetField() {
3693+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3694+
3695+
LogicalPlan plan = analyze("""
3696+
FROM books METADATA _score
3697+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3698+
""", "mapping-books.json");
3699+
3700+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3701+
assertThat(completion.targetField(), equalTo(referenceAttribute("completion", DataType.TEXT)));
3702+
}
3703+
3704+
public void testResolveCompletionPrompt() {
3705+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3706+
3707+
LogicalPlan plan = analyze("""
3708+
FROM books METADATA _score
3709+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id`
3710+
""", "mapping-books.json");
3711+
3712+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3713+
EsRelation esRelation = as(completion.child(), EsRelation.class);
3714+
3715+
assertThat(
3716+
as(completion.prompt(), Concat.class).children(),
3717+
equalTo(List.of(string("Translate the following text in French\n"), getAttributeByName(esRelation.output(), "description")))
3718+
);
3719+
}
3720+
3721+
public void testResolveCompletionPromptInvalidType() {
3722+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3723+
3724+
assertError("""
3725+
FROM books METADATA _score
3726+
| COMPLETION LENGTH(description) WITH `completion-inference-id`
3727+
""", "mapping-books.json", new QueryParams(), "prompt must be of type [text] but is [integer]");
3728+
}
3729+
3730+
public void testResolveCompletionOutputField() {
3731+
assumeTrue("Requires COMPLETION command", EsqlCapabilities.Cap.COMPLETION.isEnabled());
3732+
3733+
LogicalPlan plan = analyze("""
3734+
FROM books METADATA _score
3735+
| COMPLETION CONCAT("Translate the following text in French\\n", description) WITH `completion-inference-id` AS description
3736+
""", "mapping-books.json");
3737+
3738+
Completion completion = as(as(plan, Limit.class).child(), Completion.class);
3739+
assertThat(completion.targetField(), equalTo(referenceAttribute("description", DataType.TEXT)));
3740+
3741+
EsRelation esRelation = as(completion.child(), EsRelation.class);
3742+
assertThat(getAttributeByName(completion.output(), "description"), equalTo(completion.targetField()));
3743+
assertThat(getAttributeByName(esRelation.output(), "description"), not(equalTo(completion.targetField())));
3744+
}
3745+
36453746
@Override
36463747
protected IndexAnalyzers createDefaultIndexAnalyzers() {
36473748
return super.createDefaultIndexAnalyzers();

0 commit comments

Comments
 (0)