46
46
import org .elasticsearch .xpack .esql .expression .function .fulltext .Match ;
47
47
import org .elasticsearch .xpack .esql .expression .function .fulltext .MatchOperator ;
48
48
import org .elasticsearch .xpack .esql .expression .function .fulltext .QueryString ;
49
+ import org .elasticsearch .xpack .esql .expression .function .scalar .string .Concat ;
49
50
import org .elasticsearch .xpack .esql .expression .function .scalar .string .Substring ;
50
51
import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .Equals ;
51
52
import org .elasticsearch .xpack .esql .expression .predicate .operator .comparison .GreaterThan ;
71
72
import org .elasticsearch .xpack .esql .plan .logical .Row ;
72
73
import org .elasticsearch .xpack .esql .plan .logical .RrfScoreEval ;
73
74
import org .elasticsearch .xpack .esql .plan .logical .UnresolvedRelation ;
75
+ import org .elasticsearch .xpack .esql .plan .logical .inference .Completion ;
74
76
import org .elasticsearch .xpack .esql .plan .logical .inference .Rerank ;
75
77
import org .elasticsearch .xpack .esql .plan .logical .local .EsqlProject ;
76
78
import org .elasticsearch .xpack .esql .plugin .EsqlPlugin ;
92
94
import static org .elasticsearch .xpack .esql .EsqlTestUtils .as ;
93
95
import static org .elasticsearch .xpack .esql .EsqlTestUtils .configuration ;
94
96
import static org .elasticsearch .xpack .esql .EsqlTestUtils .emptyInferenceResolution ;
97
+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .getAttributeByName ;
95
98
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsConstant ;
96
99
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsIdentifier ;
97
100
import static org .elasticsearch .xpack .esql .EsqlTestUtils .paramAsPattern ;
101
+ import static org .elasticsearch .xpack .esql .EsqlTestUtils .referenceAttribute ;
98
102
import static org .elasticsearch .xpack .esql .EsqlTestUtils .withDefaultLimitWarning ;
99
103
import static org .elasticsearch .xpack .esql .analysis .Analyzer .NO_FIELDS ;
100
104
import static org .elasticsearch .xpack .esql .analysis .AnalyzerTestUtils .analyze ;
@@ -3460,7 +3464,7 @@ public void testResolveRerankInferenceId() {
3460
3464
3461
3465
{
3462
3466
LogicalPlan plan = analyze (
3463
- " FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3467
+ "FROM books METADATA _score | RERANK \" italian food recipe\" ON title WITH `reranking-inference-id`" ,
3464
3468
"mapping-books.json"
3465
3469
);
3466
3470
Rerank rerank = as (as (plan , Limit .class ).child (), Rerank .class );
@@ -3530,16 +3534,13 @@ public void testResolveRerankFields() {
3530
3534
Filter filter = as (drop .child (), Filter .class );
3531
3535
EsRelation relation = as (filter .child (), EsRelation .class );
3532
3536
3533
- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3534
- assertThat (titleAttribute , notNullValue ());
3537
+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3538
+ assertThat (getAttributeByName ( relation . output (), "title" ) , notNullValue ());
3535
3539
3536
3540
assertThat (rerank .queryText (), equalTo (string ("italian food recipe" )));
3537
3541
assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
3538
3542
assertThat (rerank .rerankFields (), equalTo (List .of (alias ("title" , titleAttribute ))));
3539
- assertThat (
3540
- rerank .scoreAttribute (),
3541
- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3542
- );
3543
+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
3543
3544
}
3544
3545
3545
3546
{
@@ -3559,15 +3560,11 @@ public void testResolveRerankFields() {
3559
3560
assertThat (rerank .inferenceId (), equalTo (string ("reranking-inference-id" )));
3560
3561
3561
3562
assertThat (rerank .rerankFields (), hasSize (3 ));
3562
- Attribute titleAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "title" )). findFirst (). get ( );
3563
+ Attribute titleAttribute = getAttributeByName ( relation .output (), "title" );
3563
3564
assertThat (titleAttribute , notNullValue ());
3564
3565
assertThat (rerank .rerankFields ().get (0 ), equalTo (alias ("title" , titleAttribute )));
3565
3566
3566
- Attribute descriptionAttribute = relation .output ()
3567
- .stream ()
3568
- .filter (attribute -> attribute .name ().equals ("description" ))
3569
- .findFirst ()
3570
- .get ();
3567
+ Attribute descriptionAttribute = getAttributeByName (relation .output (), "description" );
3571
3568
assertThat (descriptionAttribute , notNullValue ());
3572
3569
Alias descriptionAlias = rerank .rerankFields ().get (1 );
3573
3570
assertThat (descriptionAlias .name (), equalTo ("description" ));
@@ -3576,13 +3573,11 @@ public void testResolveRerankFields() {
3576
3573
equalTo (List .of (descriptionAttribute , literal (0 ), literal (100 )))
3577
3574
);
3578
3575
3579
- Attribute yearAttribute = relation .output (). stream (). filter ( attribute -> attribute . name (). equals ( "year" )). findFirst (). get ( );
3576
+ Attribute yearAttribute = getAttributeByName ( relation .output (), "year" );
3580
3577
assertThat (yearAttribute , notNullValue ());
3581
3578
assertThat (rerank .rerankFields ().get (2 ), equalTo (alias ("yearRenamed" , yearAttribute )));
3582
- assertThat (
3583
- rerank .scoreAttribute (),
3584
- equalTo (relation .output ().stream ().filter (attr -> attr .name ().equals (MetadataAttribute .SCORE )).findFirst ().get ())
3585
- );
3579
+
3580
+ assertThat (rerank .scoreAttribute (), equalTo (getAttributeByName (relation .output (), MetadataAttribute .SCORE )));
3586
3581
}
3587
3582
3588
3583
{
@@ -3614,11 +3609,7 @@ public void testResolveRerankScoreField() {
3614
3609
Filter filter = as (rerank .child (), Filter .class );
3615
3610
EsRelation relation = as (filter .child (), EsRelation .class );
3616
3611
3617
- Attribute metadataScoreAttribute = relation .output ()
3618
- .stream ()
3619
- .filter (attr -> attr .name ().equals (MetadataAttribute .SCORE ))
3620
- .findFirst ()
3621
- .get ();
3612
+ Attribute metadataScoreAttribute = getAttributeByName (relation .output (), MetadataAttribute .SCORE );
3622
3613
assertThat (rerank .scoreAttribute (), equalTo (metadataScoreAttribute ));
3623
3614
assertThat (rerank .output (), hasItem (metadataScoreAttribute ));
3624
3615
}
@@ -3642,6 +3633,116 @@ public void testResolveRerankScoreField() {
3642
3633
}
3643
3634
}
3644
3635
3636
+ public void testResolveCompletionInferenceId () {
3637
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3638
+
3639
+ LogicalPlan plan = analyze ("""
3640
+ FROM books METADATA _score
3641
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3642
+ """ , "mapping-books.json" );
3643
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3644
+ assertThat (completion .inferenceId (), equalTo (string ("completion-inference-id" )));
3645
+ }
3646
+
3647
+ public void testResolveCompletionInferenceIdInvalidTaskType () {
3648
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3649
+
3650
+ assertError (
3651
+ """
3652
+ FROM books METADATA _score
3653
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `reranking-inference-id`
3654
+ """ ,
3655
+ "mapping-books.json" ,
3656
+ new QueryParams (),
3657
+ "cannot use inference endpoint [reranking-inference-id] with task type [rerank] within a Completion command."
3658
+ + " Only inference endpoints with the task type [completion] are supported"
3659
+ );
3660
+ }
3661
+
3662
+ public void testResolveCompletionInferenceMissingInferenceId () {
3663
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3664
+
3665
+ assertError ("""
3666
+ FROM books METADATA _score
3667
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `unknown-inference-id`
3668
+ """ , "mapping-books.json" , new QueryParams (), "unresolved inference [unknown-inference-id]" );
3669
+ }
3670
+
3671
+ public void testResolveCompletionInferenceIdResolutionError () {
3672
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3673
+
3674
+ assertError ("""
3675
+ FROM books METADATA _score
3676
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `error-inference-id`
3677
+ """ , "mapping-books.json" , new QueryParams (), "error with inference resolution" );
3678
+ }
3679
+
3680
+ public void testResolveCompletionTargetField () {
3681
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3682
+
3683
+ LogicalPlan plan = analyze ("""
3684
+ FROM books METADATA _score
3685
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS translation
3686
+ """ , "mapping-books.json" );
3687
+
3688
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3689
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("translation" , DataType .TEXT )));
3690
+ }
3691
+
3692
+ public void testResolveCompletionDefaultTargetField () {
3693
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3694
+
3695
+ LogicalPlan plan = analyze ("""
3696
+ FROM books METADATA _score
3697
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3698
+ """ , "mapping-books.json" );
3699
+
3700
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3701
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("completion" , DataType .TEXT )));
3702
+ }
3703
+
3704
+ public void testResolveCompletionPrompt () {
3705
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3706
+
3707
+ LogicalPlan plan = analyze ("""
3708
+ FROM books METADATA _score
3709
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id`
3710
+ """ , "mapping-books.json" );
3711
+
3712
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3713
+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3714
+
3715
+ assertThat (
3716
+ as (completion .prompt (), Concat .class ).children (),
3717
+ equalTo (List .of (string ("Translate the following text in French\n " ), getAttributeByName (esRelation .output (), "description" )))
3718
+ );
3719
+ }
3720
+
3721
+ public void testResolveCompletionPromptInvalidType () {
3722
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3723
+
3724
+ assertError ("""
3725
+ FROM books METADATA _score
3726
+ | COMPLETION LENGTH(description) WITH `completion-inference-id`
3727
+ """ , "mapping-books.json" , new QueryParams (), "prompt must be of type [text] but is [integer]" );
3728
+ }
3729
+
3730
+ public void testResolveCompletionOutputField () {
3731
+ assumeTrue ("Requires COMPLETION command" , EsqlCapabilities .Cap .COMPLETION .isEnabled ());
3732
+
3733
+ LogicalPlan plan = analyze ("""
3734
+ FROM books METADATA _score
3735
+ | COMPLETION CONCAT("Translate the following text in French\\ n", description) WITH `completion-inference-id` AS description
3736
+ """ , "mapping-books.json" );
3737
+
3738
+ Completion completion = as (as (plan , Limit .class ).child (), Completion .class );
3739
+ assertThat (completion .targetField (), equalTo (referenceAttribute ("description" , DataType .TEXT )));
3740
+
3741
+ EsRelation esRelation = as (completion .child (), EsRelation .class );
3742
+ assertThat (getAttributeByName (completion .output (), "description" ), equalTo (completion .targetField ()));
3743
+ assertThat (getAttributeByName (esRelation .output (), "description" ), not (equalTo (completion .targetField ())));
3744
+ }
3745
+
3645
3746
@ Override
3646
3747
protected IndexAnalyzers createDefaultIndexAnalyzers () {
3647
3748
return super .createDefaultIndexAnalyzers ();
0 commit comments