diff --git a/docs/changelog/128538.yaml b/docs/changelog/128538.yaml new file mode 100644 index 0000000000000..bd4ab34ce2dca --- /dev/null +++ b/docs/changelog/128538.yaml @@ -0,0 +1,5 @@ +pr: 128538 +summary: "Added Mistral Chat Completion support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index b35100d4b5bd8..d6f405959f2b7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -189,6 +189,7 @@ static TransportVersion def(int id) { public static final TransportVersion DATA_STREAM_OPTIONS_API_REMOVE_INCLUDE_DEFAULTS_8_19 = def(8_841_0_41); public static final TransportVersion JOIN_ON_ALIASES_8_19 = def(8_841_0_42); public static final TransportVersion ILM_ADD_SKIP_SETTING_8_19 = def(8_841_0_43); + public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_44); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -282,6 +283,7 @@ static TransportVersion def(int id) { public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES = def(9_087_0_00); public static final TransportVersion JOIN_ON_ALIASES = def(9_088_0_00); public static final TransportVersion ILM_ADD_SKIP_SETTING = def(9_089_0_00); + public static final TransportVersion ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED = def(9_090_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index eef570f1462b2..8e934486be7e1 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -78,6 +78,14 @@ public record UnifiedCompletionRequest( * {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name. */ private static final String MAX_TOKENS_PARAM = "max_tokens_field"; + /** + * Indicates whether to include the `stream_options` field in the JSON output. + * Some providers do not support this field. In such cases, this parameter should be set to "false", + * and the `stream_options` field will be excluded from the output. + * For providers that do support stream options, this parameter is left unset (default behavior), + * which implicitly includes the `stream_options` field in the output. 
+ */ + public static final String INCLUDE_STREAM_OPTIONS_PARAM = "include_stream_options"; /** * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: @@ -91,6 +99,23 @@ public static Params withMaxTokens(String modelId, Params params) { ); } + /** + * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: + * - Key: {@link #MODEL_FIELD}, Value: modelId + * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #MAX_TOKENS_FIELD} + * - Key: {@link #INCLUDE_STREAM_OPTIONS_PARAM}, Value: "false" + */ + public static Params withMaxTokensAndSkipStreamOptionsField(String modelId, Params params) { + return new DelegatingMapParams( + Map.ofEntries( + Map.entry(MODEL_ID_PARAM, modelId), + Map.entry(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD), + Map.entry(INCLUDE_STREAM_OPTIONS_PARAM, Boolean.FALSE.toString()) + ), + params + ); + } + /** * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values: * - Key: {@link #MODEL_FIELD}, Value: modelId diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 4de3c9f31d38e..c6e3a1ae73ef7 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -134,7 +134,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(13)); + assertThat(services.size(), 
equalTo(14)); var providers = providers(services); @@ -154,7 +154,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "openai", "streaming_completion_test_service", "hugging_face", - "amazon_sagemaker" + "amazon_sagemaker", + "mistral" ).toArray() ) ); @@ -162,7 +163,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(7)); + assertThat(services.size(), equalTo(8)); var providers = providers(services); @@ -176,7 +177,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { "streaming_completion_test_service", "hugging_face", "amazon_sagemaker", - "googlevertexai" + "googlevertexai", + "mistral" ).toArray() ) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 860f8039a56c7..54e8f3102aa45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -100,6 +100,7 @@ import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankServiceSettings; import org.elasticsearch.xpack.inference.services.jinaai.rerank.JinaAIRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import 
org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; @@ -266,6 +267,13 @@ private static void addMistralNamedWriteables(List MistralEmbeddingsServiceSettings::new ) ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + MistralChatCompletionServiceSettings.NAME, + MistralChatCompletionServiceSettings::new + ) + ); // note - no task settings for Mistral embeddings... } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java index 489844f4d14de..f94a9771bced0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ErrorMessageResponseEntity.java @@ -21,12 +21,13 @@ * A pattern is emerging in how external providers provide error responses. * * At a minimum, these return: + *

  * {
  *     "error: {
  *         "message": "(error message)"
  *     }
  * }
- *
+ * </pre>
* Others may return additional information such as error codes specific to the service. * * This currently covers error handling for Azure AI Studio, however this pattern diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java new file mode 100644 index 0000000000000..93e1d6388f357 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/streaming/StreamingErrorResponse.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.streaming; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; + +import java.util.Objects; +import java.util.Optional; + +/** + * Represents an error response from a streaming inference service. + * This class extends {@link ErrorResponse} and provides additional fields + * specific to streaming errors, such as code, param, and type. + * An example error response for a streaming service might look like: + *
+ * <pre>
+ *     {
+ *         "error": {
+ *             "message": "Invalid input",
+ *             "code": "400",
+ *             "param": "input",
+ *             "type": "invalid_request_error"
+ *         }
+ *     }
+ * </pre>
+ * TODO: {@link ErrorMessageResponseEntity} is nearly identical to this, but doesn't parse as many fields. We must remove the duplication. + */ +public class StreamingErrorResponse extends ErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + "streaming_error", + true, + args -> Optional.ofNullable((StreamingErrorResponse) args[0]) + ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( + "streaming_error", + true, + args -> new StreamingErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); + ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code")); + ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param")); + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + /** + * Standard error response parser. This can be overridden for those subclasses that + * have a different error response structure. + * @param response The error response as an HttpResult + */ + public static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + /** + * Standard error response parser. This can be overridden for those subclasses that + * have a different error response structure. 
+ * @param response The error response as a string + */ + public static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + @Nullable + private final String code; + @Nullable + private final String param; + private final String type; + + StreamingErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { + super(errorMessage); + this.code = code; + this.param = param; + this.type = Objects.requireNonNull(type); + } + + @Nullable + public String code() { + return code; + } + + @Nullable + public String param() { + return param; + } + + public String type() { + return type; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java index ce030c890218c..4d3e7e2736cad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -15,6 +15,12 @@ import java.io.IOException; import java.util.Objects; +import static org.elasticsearch.inference.UnifiedCompletionRequest.INCLUDE_STREAM_OPTIONS_PARAM; + +/** + * Represents a unified chat completion request entity. + * This class is used to convert the unified chat input into a format that can be serialized to XContent. 
+ */ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment { public static final String STREAM_FIELD = "stream"; @@ -42,7 +48,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); builder.field(STREAM_FIELD, stream); - if (stream) { + // If request is streamed and skip stream options parameter is not true, include stream options in the request. + if (stream && params.paramAsBoolean(INCLUDE_STREAM_OPTIONS_PARAM, true)) { builder.startObject(STREAM_OPTIONS_FIELD); builder.field(INCLUDE_USAGE_FIELD, true); builder.endObject(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralCompletionResponseHandler.java new file mode 100644 index 0000000000000..e275bff02034d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralCompletionResponseHandler.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; + +/** + * Handles non-streaming completion responses for Mistral models, extending the OpenAI completion response handler. + * This class is specifically designed to handle Mistral's error response format. 
+ */ +public class MistralCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + + /** + * Constructs a MistralCompletionResponseHandler with the specified request type and response parser. + * + * @param requestType The type of request being handled (e.g., "mistral completions"). + * @param parseFunction The function to parse the response. + */ + public MistralCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, MistralErrorResponse::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java index d059545ca1ea3..8e9f96efff421 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java @@ -9,6 +9,7 @@ public class MistralConstants { public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings"; + public static final String API_COMPLETIONS_PATH = "https://api.mistral.ai/v1/chat/completions"; // note - there is no bounds information available from Mistral, // so we'll use a sane default here which is the same as Cohere's @@ -18,4 +19,8 @@ public class MistralConstants { public static final String MODEL_FIELD = "model"; public static final String INPUT_FIELD = "input"; public static final String ENCODING_FORMAT_FIELD = "encoding_format"; + public static final String MAX_TOKENS_FIELD = "max_tokens"; + public static final String DETAIL_FIELD = "detail"; + public static final String MSG_FIELD = "msg"; + public static final String MESSAGE_FIELD = "message"; } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java index 73342edaaff15..391a549df924a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralEmbeddingsRequestManager.java @@ -22,7 +22,7 @@ import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity; import org.elasticsearch.xpack.inference.services.azureopenai.response.AzureMistralOpenAiExternalResponseHandler; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.mistral.request.MistralEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest; import org.elasticsearch.xpack.inference.services.mistral.response.MistralEmbeddingsResponseEntity; import java.util.List; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java new file mode 100644 index 0000000000000..57219a03b3bdb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralModel.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; + +/** + * Represents a Mistral model that can be used for inference tasks. + * This class extends RateLimitGroupingModel to handle rate limiting based on model and API key. + */ +public abstract class MistralModel extends RateLimitGroupingModel { + protected String model; + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + protected MistralModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected MistralModel(RateLimitGroupingModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + public String model() { + return this.model; + } + + public URI uri() { + return this.uri; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public int rateLimitGroupingHash() { + return 0; + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 558b7e255f2b4..b11feb117d761 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -30,7 +30,10 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; @@ -39,8 +42,11 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -48,6 +54,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; 
+import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -56,14 +63,26 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; +/** + * MistralService is an implementation of the SenderService that handles inference tasks + * using Mistral models. It supports text embedding, completion, and chat completion tasks. + * The service uses MistralActionCreator to create actions for executing inference requests. 
+ */ public class MistralService extends SenderService { public static final String NAME = "mistral"; private static final String SERVICE_NAME = "Mistral"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING); + private static final EnumSet supportedTaskTypes = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new MistralUnifiedChatCompletionResponseHandler( + "mistral chat completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -79,11 +98,16 @@ protected void doInfer( ) { var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); - if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { - var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(inputs, timeout, listener); - } else { - listener.onFailure(createInvalidModelException(model)); + switch (model) { + case MistralEmbeddingsModel mistralEmbeddingsModel: + mistralEmbeddingsModel.accept(actionCreator, taskSettings).execute(inputs, timeout, listener); + break; + case MistralChatCompletionModel mistralChatCompletionModel: + mistralChatCompletionModel.accept(actionCreator).execute(inputs, timeout, listener); + break; + default: + listener.onFailure(createInvalidModelException(model)); + break; } } @@ -99,7 +123,24 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof MistralChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + MistralChatCompletionModel mistralChatCompletionModel = (MistralChatCompletionModel) model; + var overriddenModel = 
MistralChatCompletionModel.of(mistralChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new MistralChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = MistralActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); } @Override @@ -162,7 +203,7 @@ public void parseRequestConfig( ); } - MistralEmbeddingsModel model = createModel( + MistralModel model = createModel( modelId, taskType, serviceSettingsMap, @@ -184,7 +225,7 @@ public void parseRequestConfig( } @Override - public Model parsePersistedConfigWithSecrets( + public MistralModel parsePersistedConfigWithSecrets( String modelId, TaskType taskType, Map config, @@ -211,7 +252,7 @@ public Model parsePersistedConfigWithSecrets( } @Override - public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + public MistralModel parsePersistedConfig(String modelId, TaskType taskType, Map config) { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); @@ -236,7 +277,12 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.V_8_15_0; } - private static MistralEmbeddingsModel createModel( + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + } + + private static MistralModel createModel( String modelId, TaskType taskType, Map serviceSettings, @@ -246,23 +292,26 @@ private static MistralEmbeddingsModel createModel( String failureMessage, ConfigurationParseContext context ) { - if (taskType == 
TaskType.TEXT_EMBEDDING) { - return new MistralEmbeddingsModel( - modelId, - taskType, - NAME, - serviceSettings, - taskSettings, - chunkingSettings, - secretSettings, - context - ); + switch (taskType) { + case TEXT_EMBEDDING: + return new MistralEmbeddingsModel( + modelId, + taskType, + NAME, + serviceSettings, + taskSettings, + chunkingSettings, + secretSettings, + context + ); + case CHAT_COMPLETION, COMPLETION: + return new MistralChatCompletionModel(modelId, taskType, NAME, serviceSettings, secretSettings, context); + default: + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); } - - throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); } - private MistralEmbeddingsModel createModelFromPersistent( + private MistralModel createModelFromPersistent( String inferenceEntityId, TaskType taskType, Map serviceSettings, @@ -284,7 +333,7 @@ private MistralEmbeddingsModel createModelFromPersistent( } @Override - public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { + public MistralEmbeddingsModel updateModelWithEmbeddingDetails(Model model, int embeddingSize) { if (model instanceof MistralEmbeddingsModel embeddingsModel) { var serviceSettings = embeddingsModel.getServiceSettings(); @@ -304,6 +353,10 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { } } + /** + * Configuration class for the Mistral inference service. + * It provides the settings and configurations required for the service. 
+ */ public static class Configuration { public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..a9d6df687fe99 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mistral.response.MistralErrorResponse; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; + +/** + * Handles streaming chat completion responses and error parsing for Mistral inference endpoints. + * Adapts the OpenAI handler to support Mistral's error schema. 
+ */ +public class MistralUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String MISTRAL_ERROR = "mistral_error"; + + public MistralUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, MistralErrorResponse::fromResponse); + } + + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof MistralErrorResponse + ? new UnifiedChatCompletionException(restStatus, errorMessage, MISTRAL_ERROR, restStatus.name().toLowerCase(Locale.ROOT)) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java index c47a47893eb80..7c96ccc2c592c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreator.java @@ -7,19 +7,41 @@ package org.elasticsearch.xpack.inference.services.mistral.action; +import org.elasticsearch.inference.TaskType; import 
org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.mistral.MistralCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.mistral.MistralEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.core.Strings.format; +/** + * MistralActionCreator is responsible for creating executable actions for Mistral models. + * It implements the MistralActionVisitor interface to provide specific implementations + * for different types of Mistral models. 
+ */ public class MistralActionCreator implements MistralActionVisitor { + + public static final String COMPLETION_ERROR_PREFIX = "Mistral completions"; + static final String USER_ROLE = "user"; + static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler( + "mistral completions", + OpenAiChatCompletionResponseEntity::fromResponse + ); private final Sender sender; private final ServiceComponents serviceComponents; @@ -35,7 +57,32 @@ public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map( + serviceComponents.threadPool(), + chatCompletionModel, + COMPLETION_HANDLER, + inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + /** + * Builds an error message for Mistral actions. + * + * @param requestType The type of request (e.g., TEXT_EMBEDDING, COMPLETION). + * @param inferenceId The ID of the inference entity. + * @return A formatted error message. 
+ */ + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format("Failed to send Mistral %s request from inference entity id [%s]", requestType.toString(), inferenceId); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java index d7618838c33c4..5f494e4d65477 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionVisitor.java @@ -8,10 +8,33 @@ package org.elasticsearch.xpack.inference.services.mistral.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; import java.util.Map; +/** + * Interface for creating {@link ExecutableAction} instances for Mistral models. + *

+ * This interface is used to create {@link ExecutableAction} instances for different types of Mistral models, such as + * {@link MistralEmbeddingsModel} and {@link MistralChatCompletionModel}. + */ public interface MistralActionVisitor { + + /** + * Creates an {@link ExecutableAction} for the given {@link MistralEmbeddingsModel}. + * + * @param embeddingsModel The model to create the action for. + * @param taskSettings The task settings to use. + * @return An {@link ExecutableAction} for the given model. + */ ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings); + + /** + * Creates an {@link ExecutableAction} for the given {@link MistralChatCompletionModel}. + * + * @param chatCompletionModel The model to create the action for. + * @return An {@link ExecutableAction} for the given model. + */ + ExecutableAction create(MistralChatCompletionModel chatCompletionModel); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java new file mode 100644 index 0000000000000..03fe502a82807 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModel.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralModel; +import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_COMPLETIONS_PATH; + +/** + * Represents a Mistral chat completion model. + * This class extends MistralModel and supports rate limit grouping based on the model ID and API key. + */ +public class MistralChatCompletionModel extends MistralModel { + + /** + * Constructor for MistralChatCompletionModel. + * + * @param inferenceEntityId The unique identifier for the inference entity. + * @param taskType The type of task this model is designed for. + * @param service The name of the service this model belongs to. + * @param serviceSettings The settings specific to the Mistral chat completion service. + * @param secrets The secrets required for accessing the service. + * @param context The context for parsing configuration settings.
+ */ + public MistralChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + MistralChatCompletionServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + /** + * Creates a new MistralChatCompletionModel with overridden service settings. + * + * @param model The original MistralChatCompletionModel. + * @param request The UnifiedCompletionRequest containing the model override. + * @return A new MistralChatCompletionModel with the overridden model ID. + */ + public static MistralChatCompletionModel of(MistralChatCompletionModel model, UnifiedCompletionRequest request) { + if (request.model() == null) { + // If no model is specified in the request, return the original model + return model; + } + + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new MistralChatCompletionServiceSettings( + request.model(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new MistralChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + public MistralChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + MistralChatCompletionServiceSettings serviceSettings, + DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + private void setPropertiesFromServiceSettings(MistralChatCompletionServiceSettings serviceSettings) { + this.model = serviceSettings.modelId(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + setEndpointUrl(); + } + + 
@Override + public int rateLimitGroupingHash() { + return Objects.hash(model, getSecretSettings().apiKey()); + } + + private void setEndpointUrl() { + try { + this.uri = new URI(API_COMPLETIONS_PATH); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public MistralChatCompletionServiceSettings getServiceSettings() { + return (MistralChatCompletionServiceSettings) super.getServiceSettings(); + } + + /** + * Accepts a visitor to create an executable action for this model. + * + * @param creator The visitor that creates the executable action. + * @return An ExecutableAction that can be executed. + */ + public ExecutableAction accept(MistralActionVisitor creator) { + return creator.create(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..676653d54a560 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettings.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; + +/** + * Represents the settings for the Mistral chat completion service. + * This class encapsulates the model ID and rate limit settings for the Mistral chat completion service. 
+ */ +public class MistralChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "mistral_completions_service_settings"; + + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + // default for Mistral is 5 requests / sec + // setting this to 240 (4 requests / sec) is a sane default for us + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240); + + public static MistralChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + MistralService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MistralChatCompletionServiceSettings(model, rateLimitSettings); + } + + public MistralChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + public MistralChatCompletionServiceSettings(String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_MISTRAL_CHAT_COMPLETION_ADDED; + } + + @Override + public String modelId() { + return this.modelId; + } + + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws 
IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + this.toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_FIELD, this.modelId); + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MistralChatCompletionServiceSettings that = (MistralChatCompletionServiceSettings) o; + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java index a72e830887d0c..48d2fecc5ce13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -10,16 +10,15 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.TaskSettings; import 
org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralModel; import org.elasticsearch.xpack.inference.services.mistral.action.MistralActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.net.URI; import java.net.URISyntaxException; @@ -27,10 +26,11 @@ import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_EMBEDDINGS_PATH; -public class MistralEmbeddingsModel extends Model { - protected String model; - protected URI uri; - protected RateLimitSettings rateLimitSettings; +/** + * Represents a Mistral embeddings model. + * This class extends MistralModel to handle embeddings-specific settings and actions. + */ +public class MistralEmbeddingsModel extends MistralModel { public MistralEmbeddingsModel( String inferenceEntityId, @@ -58,6 +58,20 @@ public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsSer setPropertiesFromServiceSettings(serviceSettings); } + private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { + this.model = serviceSettings.modelId(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + setEndpointUrl(); + } + + private void setEndpointUrl() { + try { + this.uri = new URI(API_EMBEDDINGS_PATH); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + public MistralEmbeddingsModel( String inferenceEntityId, TaskType taskType, @@ -74,51 +88,11 @@ public MistralEmbeddingsModel( setPropertiesFromServiceSettings(serviceSettings); } - private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { - this.model = serviceSettings.modelId(); - this.rateLimitSettings = 
serviceSettings.rateLimitSettings(); - setEndpointUrl(); - } - @Override public MistralEmbeddingsServiceSettings getServiceSettings() { return (MistralEmbeddingsServiceSettings) super.getServiceSettings(); } - public String model() { - return this.model; - } - - public URI uri() { - return this.uri; - } - - public RateLimitSettings rateLimitSettings() { - return this.rateLimitSettings; - } - - private void setEndpointUrl() { - try { - this.uri = new URI(API_EMBEDDINGS_PATH); - } catch (URISyntaxException e) { - throw new RuntimeException(e); - } - } - - // Needed for testing only - public void setURI(String newUri) { - try { - this.uri = new URI(newUri); - } catch (URISyntaxException e) { - // swallow any error - } - } - - @Override - public DefaultSecretSettings getSecretSettings() { - return (DefaultSecretSettings) super.getSecretSettings(); - } - public ExecutableAction accept(MistralActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java new file mode 100644 index 0000000000000..64051ee0d83b1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * Mistral Unified Chat Completion Request + * This class is responsible for creating a request to the Mistral chat completion model. + * It constructs an HTTP POST request with the necessary headers and body content. 
+ */ +public class MistralChatCompletionRequest implements Request { + + private final MistralChatCompletionModel model; + private final UnifiedChatInput chatInput; + + public MistralChatCompletionRequest(UnifiedChatInput chatInput, MistralChatCompletionModel model) { + this.chatInput = Objects.requireNonNull(chatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new MistralChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation for Mistral chat completions + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation for Mistral chat completions + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return chatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..3fe640335c47e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntity.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.request.completion; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +/** + * MistralChatCompletionRequestEntity is responsible for creating the request entity for Mistral chat completion. + * It implements ToXContentObject to allow serialization to XContent format. 
+ */ +public class MistralChatCompletionRequestEntity implements ToXContentObject { + + private final MistralChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + public MistralChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, MistralChatCompletionModel model) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent( + builder, + UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params) + ); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java index c031f87aa5e13..8b772d4b8f2ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequest.java @@ -5,7 +5,7 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.services.mistral.request; +package org.elasticsearch.xpack.inference.services.mistral.request.embeddings; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequestEntity.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequestEntity.java index 36622f9fb4be7..91837ea9d9541 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/request/embeddings/MistralEmbeddingsRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.mistral.request; +package org.elasticsearch.xpack.inference.services.mistral.request.embeddings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java new file mode 100644 index 0000000000000..02dfb746fae53 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponse.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.response; + +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; + +import java.nio.charset.StandardCharsets; + +/** + * Represents an error response entity for Mistral inference services. + * This class extends ErrorResponse and provides a method to create an instance + * from an HttpResult, attempting to read the body as a UTF-8 string. + * An example error response for Not Found error would look like: + *


+ * {
+ *     "detail": "Not Found"
+ * }
+ * 
+ * An example error response for Bad Request error would look like: + *

+ * {
+ *     "object": "error",
+ *     "message": "Invalid model: wrong-model-name",
+ *     "type": "invalid_model",
+ *     "param": null,
+ *     "code": "1500"
+ * }
+ * 
+ * An example error response for Unauthorized error would look like: + *

+ * {
+ *     "message": "Unauthorized",
+ *     "request_id": "ad95a2165083f20b490f8f78a14bb104"
+ * }
+ * 
+ * An example error response for Unprocessable Entity error would look like: + *

+ * {
+ *     "object": "error",
+ *     "message": {
+ *         "detail": [
+ *             {
+ *                 "type": "greater_than_equal",
+ *                 "loc": [
+ *                     "body",
+ *                     "max_tokens"
+ *                 ],
+ *                 "msg": "Input should be greater than or equal to 0",
+ *                 "input": -10,
+ *                 "ctx": {
+ *                     "ge": 0
+ *                 }
+ *             }
+ *         ]
+ *     },
+ *     "type": "invalid_request_error",
+ *     "param": null,
+ *     "code": null
+ * }
+ * 
+ */ +public class MistralErrorResponse extends ErrorResponse { + + public MistralErrorResponse(String message) { + super(message); + } + + /** + * Creates an ErrorResponse from the given HttpResult. + * Attempts to read the body as a UTF-8 string and constructs a MistralErrorResponse. + * If reading fails, returns a generic UNDEFINED_ERROR. + * + * @param response the HttpResult containing the error response + * @return an ErrorResponse instance + */ + public static ErrorResponse fromResponse(HttpResult response) { + try { + String errorMessage = new String(response.body(), StandardCharsets.UTF_8); + return new MistralErrorResponse(errorMessage); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java index e8d0b2096b240..c80f0713a5405 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandler.java @@ -24,6 +24,7 @@ import java.util.concurrent.Flow; import java.util.function.Function; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; public class OpenAiResponseHandler extends BaseResponseHandler { @@ -40,6 +41,7 @@ public class OpenAiResponseHandler extends BaseResponseHandler { static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens"; static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length."; + static final String VALIDATION_ERROR_MESSAGE = "Received an input validation error response"; static final String
OPENAI_SERVER_BUSY = "Received a server busy error status code"; @@ -86,11 +88,23 @@ protected void checkForFailureStatusCode(Request request, HttpResult result) thr throw new RetryException(false, buildError(AUTHENTICATION, request, result)); } else if (statusCode >= 300 && statusCode < 400) { throw new RetryException(false, buildError(REDIRECTION, request, result)); + } else if (statusCode == 422) { + // OpenAI does not return 422 at the time of writing, but Mistral does and follows most of OpenAI's format. + // TODO: Revisit this in the future to decouple OpenAI and Mistral error handling. + throw new RetryException(false, buildError(VALIDATION_ERROR_MESSAGE, request, result)); + } else if (statusCode == 400) { + throw new RetryException(false, buildError(BAD_REQUEST, request, result)); + } else if (statusCode == 404) { + throw new RetryException(false, buildError(resourceNotFoundError(request), request, result)); } else { throw new RetryException(false, buildError(UNSUCCESSFUL, request, result)); } } + private static String resourceNotFoundError(Request request) { + return format("Resource not found at [%s]", request.getURI()); + } + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index bbc83667c13e7..e1a0117c7bcca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -7,15 +7,8 @@ package 
org.elasticsearch.xpack.inference.services.openai; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ConstructingObjectParser; -import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -24,18 +17,21 @@ import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.StreamingErrorResponse; import java.util.Locale; -import java.util.Objects; -import java.util.Optional; import java.util.concurrent.Flow; import java.util.function.Function; import static org.elasticsearch.core.Strings.format; +/** + * Handles streaming chat completion responses and error parsing for OpenAI inference endpoints. + * This handler is designed to work with the unified OpenAI chat completion API. 
+ */ public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, OpenAiErrorResponse::fromResponse); + super(requestType, parseFunction, StreamingErrorResponse::fromResponse); } public OpenAiUnifiedChatCompletionResponseHandler( @@ -62,7 +58,7 @@ protected Exception buildError(String message, Request request, HttpResult resul if (request.isStreaming()) { var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof OpenAiErrorResponse oer + return errorResponse instanceof StreamingErrorResponse oer ? new UnifiedChatCompletionException(restStatus, errorMessage, oer.type(), oer.code(), oer.param()) : new UnifiedChatCompletionException( restStatus, @@ -84,8 +80,8 @@ protected Exception buildMidStreamError(Request request, String message, Excepti } public static UnifiedChatCompletionException buildMidStreamError(String inferenceEntityId, String message, Exception e) { - var errorResponse = OpenAiErrorResponse.fromString(message); - if (errorResponse instanceof OpenAiErrorResponse oer) { + var errorResponse = StreamingErrorResponse.fromString(message); + if (errorResponse instanceof StreamingErrorResponse oer) { return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, format( @@ -109,85 +105,4 @@ public static UnifiedChatCompletionException buildMidStreamError(String inferenc ); } } - - private static class OpenAiErrorResponse extends ErrorResponse { - private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( - "open_ai_error", - true, - args -> Optional.ofNullable((OpenAiErrorResponse) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( - "open_ai_error", 
- true, - args -> new OpenAiErrorResponse((String) args[0], (String) args[1], (String) args[2], (String) args[3]) - ); - - static { - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); - ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("code")); - ERROR_BODY_PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField("param")); - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("type")); - - ERROR_PARSER.declareObjectOrNull( - ConstructingObjectParser.optionalConstructorArg(), - ERROR_BODY_PARSER, - null, - new ParseField("error") - ); - } - - private static ErrorResponse fromResponse(HttpResult response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response.body()) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; - } - - private static ErrorResponse fromString(String response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; - } - - @Nullable - private final String code; - @Nullable - private final String param; - private final String type; - - OpenAiErrorResponse(String errorMessage, @Nullable String code, @Nullable String param, String type) { - super(errorMessage); - this.code = code; - this.param = param; - this.type = Objects.requireNonNull(type); - } - - @Nullable - public String code() { - return code; - } - - @Nullable - public String param() { - return param; - } - - public String type() { - 
return type; - } - } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java index ecb14c0be35a7..928ed3ff444e6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java @@ -21,15 +21,10 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { public static final String USER_FIELD = "user"; - private static final String MODEL_FIELD = "model"; - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - - private final UnifiedChatInput unifiedChatInput; private final OpenAiChatCompletionModel model; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { - this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); this.model = Objects.requireNonNull(model); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java index 2f4eed8df7812..af38ee38e1eff 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -54,6 +54,7 @@ import 
static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; @@ -273,8 +274,10 @@ public void testUnifiedCompletionError() { assertThat( e.getMessage(), equalTo( - "Received an unsuccessful status code for request from inference entity id [inference-id] status" - + " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]" + "Resource not found at [" + + getUrl(webServer) + + "] for request from inference entity id [inference-id]" + + " status [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 0d76311d81fa5..e2850910ac64a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -66,6 +66,7 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -377,15 +378,14 @@ public void testUnifiedCompletionNonStreamingError() throws Exception { } }); var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, 
builder.contentType()); - - assertThat(json, is(""" + assertThat(json, is(String.format(Locale.ROOT, """ {\ "error":{\ "code":"not_found",\ - "message":"Received an unsuccessful status code for request from inference entity id [id] status \ + "message":"Resource not found at [%s] for request from inference entity id [id] status \ [404]. Error message: [Model not found.]",\ "type":"hugging_face_error"\ - }}""")); + }}""", getUrl(webServer)))); } catch (IOException ex) { throw new RuntimeException(ex); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 04a7c38229292..4ba9b8aa24394 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; @@ -29,39 +30,53 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; import 
org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.ModelConfigurationsTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; +import org.hamcrest.Matcher; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static 
org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; @@ -72,12 +87,14 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionServiceSettingsTests.getServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettingsTests.createRequestSettingsMap; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -118,11 +135,7 @@ public void testParseRequestConfig_CreatesAMistralEmbeddingsModel() throws IOExc service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")), modelVerificationListener ); } @@ -144,8 +157,8 @@ public void 
testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSet "id", TaskType.TEXT_EMBEDDING, getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(null, null), + getTaskSettingsMap(), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ), @@ -169,16 +182,289 @@ public void testParseRequestConfig_CreatesAMistralEmbeddingsModelWhenChunkingSet service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesChatCompletionsModel() throws IOException { + var model = "model"; + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(MistralChatCompletionModel.class)); + + var completionsModel = (MistralChatCompletionModel) m; + + assertThat(completionsModel.getServiceSettings().modelId(), is(model)); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap(getServiceSettingsMap(model), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsException_WithoutModelId() throws IOException { + var secret = "secret"; + + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(MistralChatCompletionModel.class)); + + var completionsModel = (MistralChatCompletionModel) m; + + 
assertNull(completionsModel.getServiceSettings().modelId()); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat( + exception.getMessage(), + is("Validation Failed: 1: [service_settings] does not contain the required setting [model];") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap(Collections.emptyMap(), getSecretSettingsMap(secret)), modelVerificationListener ); } } + public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); + + try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id": "37d683fc0b3949b880529fb20973aca7",\ + "object": "chat.completion.chunk",\ + "created": 1749032579,\ + "model": 
"mistral-small-latest",\ + "choices": [\ + {\ + "index": 0,\ + "delta": {\ + "content": "Cho"\ + },\ + "finish_reason": "length",\ + "logprobs": null\ + }\ + ],\ + "usage": {\ + "prompt_tokens": 10,\ + "total_tokens": 11,\ + "completion_tokens": 1\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(XContentHelper.stripWhitespace(""" + { + "id": "37d683fc0b3949b880529fb20973aca7", + "choices": [{ + "delta": { + "content": "Cho" + }, + "finish_reason": "length", + "index": 0 + } + ], + "model": "mistral-small-latest", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 1, + "prompt_tokens": 10, + "total_tokens": 11 + } + } + """)); + } + } + + public void testUnifiedCompletionNonStreamingNotFoundError() throws Exception { + String responseJson = """ + { + "detail": "Not Found" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + var model = MistralChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", 
"model"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + assertThat(json, is(String.format(Locale.ROOT, XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [%s] for request from inference entity id [id] status \ + [404]. 
Error message: [{\\n \\"detail\\": \\"Not Found\\"\\n}\\n]", + "type" : "mistral_error" + } + }"""), getUrl(webServer)))); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + var model = MistralChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + public void testInfer_StreamRequest_ErrorResponse() { + String responseJson = """ + { + "message": "Unauthorized", + "request_id": "ad95a2165083f20b490f8f78a14bb104" + }"""; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED)); + assertThat(e.getMessage(), 
equalTo(""" + Received an authentication error status code for request from inference entity id [id] status [401]. Error message: [{ + "message": "Unauthorized", + "request_id": "ad95a2165083f20b490f8f78a14bb104" + }]""")); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new MistralService(mock(), createWithEmptySettings(mock()))) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createService()) { ActionListener modelVerificationListener = ActionListener.wrap( @@ -192,24 +478,37 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti service.parseRequestConfig( "id", TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")), modelVerificationListener ); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig( + getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")), + TaskType.TEXT_EMBEDDING + ); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_Completion() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig( + getRequestConfigMap(getServiceSettingsMap("mistral-completion"), getSecretSettingsMap("secret")), + TaskType.COMPLETION + ); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig_ChatCompletion() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig( + 
getRequestConfigMap(getServiceSettingsMap("mistral-chat-completion"), getSecretSettingsMap("secret")), + TaskType.CHAT_COMPLETION + ); + } + + private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig(Map secret, TaskType chatCompletion) + throws IOException { try (var service = createService()) { - var config = getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - getSecretSettingsMap("secret") - ); - config.put("extra_key", "value"); + secret.put("extra_key", "value"); ActionListener modelVerificationListener = ActionListener.wrap( model -> fail("Expected exception, but got model: " + model), @@ -222,20 +521,40 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + service.parseRequestConfig("id", chatCompletion, secret, modelVerificationListener); } } public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap( + getEmbeddingsServiceSettingsMap(null, null), + TaskType.TEXT_EMBEDDING + ); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionTaskSettingsMap() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap( + getServiceSettingsMap("mistral-completion"), + TaskType.COMPLETION + ); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionTaskSettingsMap() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap( + getServiceSettingsMap("mistral-chat-completion"), + TaskType.CHAT_COMPLETION + ); + } + + private void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap( + Map serviceSettingsMap, + TaskType chatCompletion + ) throws IOException { try (var service = createService()) { var 
taskSettings = new HashMap(); taskSettings.put("extra_key", "value"); - var config = getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - taskSettings, - getSecretSettingsMap("secret") - ); + var config = getRequestConfigMap(serviceSettingsMap, taskSettings, getSecretSettingsMap("secret")); ActionListener modelVerificationListener = ActionListener.wrap( model -> fail("Expected exception, but got model: " + model), @@ -248,7 +567,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSett } ); - service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); + service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener); } } @@ -257,11 +576,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe var secretSettings = getSecretSettingsMap("secret"); secretSettings.put("extra_key", "value"); - var config = getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - secretSettings - ); + var config = getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), secretSettings); ActionListener modelVerificationListener = ActionListener.wrap( model -> fail("Expected exception, but got model: " + model), @@ -278,11 +593,42 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSe } } + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-completion", TaskType.COMPLETION); + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException { + testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap("mistral-chat-completion", TaskType.CHAT_COMPLETION); + } + + private void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap(String modelId, TaskType chatCompletion) + throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(getServiceSettingsMap(modelId), secretSettings); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + ); + } + ); + + service.parseRequestConfig("id", chatCompletion, config, modelVerificationListener); + } + } + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(1024, 512), + getTaskSettingsMap(), getSecretSettingsMap("secret") ); @@ -298,11 +644,33 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOE } } + public void testParsePersistedConfig_CreatesAMistralCompletionModel() throws IOException { + testParsePersistedConfig_CreatesAMistralModel("mistral-completion", TaskType.COMPLETION); + } + + public void testParsePersistedConfig_CreatesAMistralChatCompletionModel() throws IOException { + testParsePersistedConfig_CreatesAMistralModel("mistral-chat-completion", TaskType.CHAT_COMPLETION); + } + + private void testParsePersistedConfig_CreatesAMistralModel(String modelId, TaskType chatCompletion) throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap(getServiceSettingsMap(modelId), getTaskSettingsMap(), getSecretSettingsMap("secret")); + + var 
model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralChatCompletionModel.class)); + + var embeddingsModel = (MistralChatCompletionModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(1024, 512), + getTaskSettingsMap(), createRandomChunkingSettingsMap(), getSecretSettingsMap("secret") ); @@ -323,8 +691,8 @@ public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingS public void testParsePersistedConfig_CreatesAMistralEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(1024, 512), + getTaskSettingsMap(), getSecretSettingsMap("secret") ); @@ -354,11 +722,7 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep service.parseRequestConfig( "id", TaskType.SPARSE_EMBEDDING, - getRequestConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(getEmbeddingsServiceSettingsMap(null, null), getTaskSettingsMap(), getSecretSettingsMap("secret")), modelVerificationListener ); } @@ -367,8 +731,8 @@ public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOExcep public void 
testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(null, null), + getTaskSettingsMap(), getSecretSettingsMap("secret") ); @@ -384,38 +748,92 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM } } - public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigEmbeddings() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig( + getEmbeddingsServiceSettingsMap(1024, 512), + TaskType.TEXT_EMBEDDING, + instanceOf(MistralEmbeddingsModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigCompletion() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig( + getServiceSettingsMap("mistral-completion"), + TaskType.COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfigChatCompletion() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig( + getServiceSettingsMap("mistral-chat-completion"), + TaskType.CHAT_COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig( + Map serviceSettingsMap, + TaskType chatCompletion, + Matcher matcher + ) throws IOException { try (var service = createService()) { - var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); - var taskSettings = getEmbeddingsTaskSettingsMap(); + var taskSettings = getTaskSettingsMap(); var secretSettings = getSecretSettingsMap("secret"); - var 
config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); config.config().put("extra_key", "value"); - var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets()); - assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + assertThat(model, matcher); } } public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap( + getEmbeddingsServiceSettingsMap(1024, 512), + TaskType.TEXT_EMBEDDING, + instanceOf(MistralEmbeddingsModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInCompletionServiceSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap( + getServiceSettingsMap("mistral-completion"), + TaskType.COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInChatCompletionServiceSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap( + getServiceSettingsMap("mistral-chat-completion"), + TaskType.CHAT_COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + private void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInServiceSettingsMap( + Map serviceSettingsMap, + TaskType chatCompletion, + Matcher matcher + ) throws IOException { try (var service = createService()) { - var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); - serviceSettings.put("extra_key", "value"); + serviceSettingsMap.put("extra_key", "value"); - var taskSettings = 
getEmbeddingsTaskSettingsMap(); + var taskSettings = getTaskSettingsMap(); var secretSettings = getSecretSettingsMap("secret"); - var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets()); - assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + assertThat(model, matcher); } } public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { try (var service = createService()) { - var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); + var serviceSettings = getEmbeddingsServiceSettingsMap(1024, 512); var taskSettings = new HashMap(); taskSettings.put("extra_key", "value"); @@ -429,27 +847,50 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbedding } public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap( + getEmbeddingsServiceSettingsMap(1024, 512), + TaskType.TEXT_EMBEDDING, + instanceOf(MistralEmbeddingsModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInCompletionSecretSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap( + getServiceSettingsMap("mistral-completion"), + TaskType.COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInChatCompletionSecretSettingsMap() throws IOException { + testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap( + 
getServiceSettingsMap("mistral-chat-completion"), + TaskType.CHAT_COMPLETION, + instanceOf(MistralChatCompletionModel.class) + ); + } + + private void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsSecretSettingsMap( + Map serviceSettingsMap, + TaskType chatCompletion, + Matcher matcher + ) throws IOException { try (var service = createService()) { - var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); - var taskSettings = getEmbeddingsTaskSettingsMap(); + var taskSettings = getTaskSettingsMap(); var secretSettings = getSecretSettingsMap("secret"); secretSettings.put("extra_key", "value"); - var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + var config = getPersistedConfigMap(serviceSettingsMap, taskSettings, secretSettings); - var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + var model = service.parsePersistedConfigWithSecrets("id", chatCompletion, config.config(), config.secrets()); - assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + assertThat(model, matcher); } } public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException { try (var service = createService()) { - var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), - getEmbeddingsTaskSettingsMap(), - Map.of() - ); + var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of()); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); @@ -465,8 +906,8 @@ public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() thro public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createService()) { var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 
1024, 512, null), - getEmbeddingsTaskSettingsMap(), + getEmbeddingsServiceSettingsMap(1024, 512), + getTaskSettingsMap(), createRandomChunkingSettingsMap(), Map.of() ); @@ -485,11 +926,7 @@ public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenC public void testParsePersistedConfig_WithoutSecretsCreatesAnEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createService()) { - var config = getPersistedConfigMap( - getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), - getEmbeddingsTaskSettingsMap(), - Map.of() - ); + var config = getPersistedConfigMap(getEmbeddingsServiceSettingsMap(1024, 512), getTaskSettingsMap(), Map.of()); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); @@ -780,7 +1217,7 @@ public void testGetConfiguration() throws Exception { { "service": "mistral", "name": "Mistral", - "task_types": ["text_embedding"], + "task_types": ["text_embedding", "completion", "chat_completion"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -789,7 +1226,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "model": { "description": "Refer to the Mistral models documentation for the list of available text embedding models.", @@ -798,7 +1235,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -807,7 +1244,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - 
"supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] }, "max_input_tokens": { "description": "Allows you to specify the maximum number of tokens per input.", @@ -816,7 +1253,7 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding"] + "supported_task_types": ["text_embedding", "completion", "chat_completion"] } } } @@ -868,16 +1305,19 @@ private Map getRequestConfigMap( ); } - private static Map getEmbeddingsServiceSettingsMap( - String model, - @Nullable Integer dimensions, - @Nullable Integer maxTokens, - @Nullable SimilarityMeasure similarityMeasure - ) { - return createRequestSettingsMap(model, dimensions, maxTokens, similarityMeasure); + private Map getRequestConfigMap(Map serviceSettings, Map secretSettings) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings)); } - private static Map getEmbeddingsTaskSettingsMap() { + private static Map getEmbeddingsServiceSettingsMap(@Nullable Integer dimensions, @Nullable Integer maxTokens) { + return createRequestSettingsMap("mistral-embed", dimensions, maxTokens, null); + } + + private static Map getTaskSettingsMap() { // no task settings for Mistral embeddings return Map.of(); } @@ -886,25 +1326,4 @@ private static Map getSecretSettingsMap(String apiKey) { return new HashMap<>(Map.of(API_KEY_FIELD, apiKey)); } - private static final String testEmbeddingResultJson = """ - { - "object": "list", - "data": [ - { - "object": "embedding", - "index": 0, - "embedding": [ - 0.0123, - -0.0123 - ] - } - ], - "model": "text-embedding-ada-002-v2", - "usage": { - "prompt_tokens": 8, - "total_tokens": 8 - } - } - """; - } diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..7564fcf106898 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandlerTests.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MistralUnifiedChatCompletionResponseHandlerTests extends ESTestCase { + private final 
MistralUnifiedChatCompletionResponseHandler responseHandler = new MistralUnifiedChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailNotFound() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "detail": "Not Found" + } + """); + + var errorJson = invalidResponseJson(responseJson, 404); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "not_found", + "message" : "Resource not found at [https://api.mistral.ai/v1/chat/completions] for request from inference entity id [id] \ + status [404]. Error message: [{\\"detail\\":\\"Not Found\\"}]", + "type" : "mistral_error" + } + }"""))); + } + + public void testFailUnauthorized() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "message": "Unauthorized", + "request_id": "a580d263fb1521778782b22104efb415" + } + """); + + var errorJson = invalidResponseJson(responseJson, 401); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "unauthorized", + "message" : "Received an authentication error status code for request from inference entity id [id] status [401]. Error \ + message: [{\\"message\\":\\"Unauthorized\\",\\"request_id\\":\\"a580d263fb1521778782b22104efb415\\"}]", + "type" : "mistral_error" + } + }"""))); + } + + public void testFailBadRequest() throws IOException { + var responseJson = XContentHelper.stripWhitespace(""" + { + "object": "error", + "message": "Invalid model: mistral-small-l2atest", + "type": "invalid_model", + "param": null, + "code": "1500" + } + """); + + var errorJson = invalidResponseJson(responseJson, 400); + + assertThat(errorJson, is(XContentHelper.stripWhitespace(""" + { + "error" : { + "code" : "bad_request", + "message" : "Received a bad request status code for request from inference entity id [id] status [400]. 
Error message: \ + [{\\"object\\":\\"error\\",\\"message\\":\\"Invalid model: mistral-small-l2atest\\",\\"type\\":\\"invalid_model\\",\\"par\ + am\\":null,\\"code\\":\\"1500\\"}]", + "type" : "mistral_error" + } + }"""))); + } + + private String invalidResponseJson(String responseJson, int statusCode) throws IOException { + var exception = invalidResponse(responseJson, statusCode); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson, int statusCode) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() throws URISyntaxException { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + when(request.getURI()).thenReturn(new URI("https://api.mistral.ai/v1/chat/completions")); + return request; + } + + private static HttpResponse mockErrorResponse(int statusCode) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java new file mode 100644 index 0000000000000..7a943e91af5dc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralActionCreatorTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import 
org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.logging.ThrottlerManagerTests.mockThrottlerManager; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class MistralActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void 
testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "chat.completion", + "id": "", + "created": 1745855316, + "model": "/repository", + "system_fingerprint": "3.2.3-sha-a1f3ebe", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 50, + "total_tokens": 58 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertChatCompletionRequest(); + } + } + + public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createChatCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + 
thrownException.getMessage(), + is("Failed to send Mistral completion request from inference entity id " + "[id]. Cause: Required [choices]") + ); + + assertChatCompletionRequest(); + } + } + + private PlainActionFuture createChatCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = MistralChatCompletionModelTests.createCompletionModel("secret", "model"); + model.setURI(getUrl(webServer)); + var actionCreator = new MistralActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertChatCompletionRequest() throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java new file mode 100644 index 0000000000000..e585092b1f780 --- /dev/null +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/action/MistralChatCompletionActionTests.java @@ -0,0 +1,236 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import 
org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.mistral.action.MistralActionCreator.USER_ROLE; +import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class MistralChatCompletionActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + 
private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "9d80f26810ac4e9582f927fcf0512ec7", + "object": "chat.completion", + "created": 1748596419, + "model": "mistral-small-latest", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "tool_calls": null, + "content": "result content" + }, + "finish_reason": "length", + "logprobs": null + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 11, + "completion_tokens": 1 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); + assertThat(webServer.requests(), hasSize(1)); + + MockRequest request = webServer.requests().get(0); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); + 
assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. 
Cause: failed")); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "9d80f26810ac4e9582f927fcf0512ec7", + "object": "chat.completion", + "created": 1748596419, + "model": "mistral-small-latest", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "tool_calls": null, + "content": "result content" + }, + "finish_reason": "length", + "logprobs": null + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 11, + "completion_tokens": 1 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private ExecutableAction createAction(String url, Sender sender) { + var model = createCompletionModel("secret", "model"); + model.setURI(url); + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + inputs -> new MistralChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + var errorMessage = constructFailedToSendRequestMessage("mistral chat completions"); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions"); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModelTests.java new file mode 100644 index 0000000000000..55019cea5a524 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionModelTests.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class MistralChatCompletionModelTests extends ESTestCase { + + public static MistralChatCompletionModel createCompletionModel(String apiKey, String modelId) { + return new MistralChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new MistralChatCompletionServiceSettings(modelId, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static MistralChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) { + MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new MistralChatCompletionServiceSettings(modelId, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + 
mistralChatCompletionModel.setURI(url); + return mistralChatCompletionModel; + } + + public static MistralChatCompletionModel createChatCompletionModel(String apiKey, String modelId) { + return new MistralChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "service", + new MistralChatCompletionServiceSettings(modelId, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static MistralChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) { + MistralChatCompletionModel mistralChatCompletionModel = new MistralChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "service", + new MistralChatCompletionServiceSettings(modelId, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + mistralChatCompletionModel.setURI(url); + return mistralChatCompletionModel; + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = MistralChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel("api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = MistralChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void 
testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel("api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = MistralChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = MistralChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..5bf7477e9c828 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/completion/MistralChatCompletionServiceSettingsTests.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.mistral.MistralConstants; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class MistralChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase { + + public static final String MODEL_ID = "some model"; + public static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = MistralChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + MistralConstants.MODEL_FIELD, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new MistralChatCompletionServiceSettings( + MODEL_ID, + + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_MissingModelId_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> MistralChatCompletionServiceSettings.fromMap( + new HashMap<>( + 
Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT))) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [model];") + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = MistralChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new MistralChatCompletionServiceSettings(MODEL_ID, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = MistralChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + MistralConstants.MODEL_FIELD, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model": "some model", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = MistralChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(MistralConstants.MODEL_FIELD, MODEL_ID)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model": "some model", + "rate_limit": { + "requests_per_minute": 
240 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return MistralChatCompletionServiceSettings::new; + } + + @Override + protected MistralChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected MistralChatCompletionServiceSettings mutateInstance(MistralChatCompletionServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, MistralChatCompletionServiceSettingsTests::createRandom); + } + + @Override + protected MistralChatCompletionServiceSettings mutateInstanceForVersion( + MistralChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static MistralChatCompletionServiceSettings createRandom() { + var modelId = randomAlphaOfLength(8); + + return new MistralChatCompletionServiceSettings(modelId, RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String model) { + var map = new HashMap(); + + map.put(MistralConstants.MODEL_FIELD, model); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntityTests.java index f833a53f85323..e75ea151a6d82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestEntityTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import 
org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequestEntity; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestTests.java index e0ef2d99d7c8c..e0ae3493af712 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/MistralEmbeddingsRequestTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.services.mistral.MistralConstants; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests; +import org.elasticsearch.xpack.inference.services.mistral.request.embeddings.MistralEmbeddingsRequest; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..f968f1b84d75b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestEntityTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.mistral.request.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; + +import static org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests.createCompletionModel; + +public class MistralChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + MistralChatCompletionModel model = createCompletionModel("api-key", "test-endpoint"); + + MistralChatCompletionRequestEntity entity = new MistralChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true + } + """; + assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); + } +} diff 
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java new file mode 100644 index 0000000000000..4a70861932d28 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/request/completion/MistralChatCompletionRequestTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.request.completion; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.mistral.MistralConstants; +import org.elasticsearch.xpack.inference.services.mistral.completion.MistralChatCompletionModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class MistralChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + var request = createRequest("secret", randomAlphaOfLength(15), "model", true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var 
requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(true)); + } + + public void testTruncate_DoesNotReduceInputTextSize() throws IOException { + String input = randomAlphaOfLength(5); + var request = createRequest("secret", input, "model", true); + var truncatedRequest = request.truncate(); + assertThat(request.getURI().toString(), is(MistralConstants.API_COMPLETIONS_PATH)); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(4)); + + // We do not truncate for Mistral chat completions + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertNull(requestMap.get("stream_options")); // Mistral does not use stream options + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("secret", randomAlphaOfLength(5), "model", true); + assertNull(request.getTruncationInfo()); + } + + public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model) { + return createRequest(apiKey, input, model, false); + } + + public static MistralChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream) { + var chatCompletionModel = MistralChatCompletionModelTests.createCompletionModel(apiKey, model); + return new MistralChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponseTests.java
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponseTests.java new file mode 100644 index 0000000000000..34cf667846996 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/response/MistralErrorResponseTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.response; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Mockito.mock; + +public class MistralErrorResponseTests extends ESTestCase { + + public static final String ERROR_RESPONSE_JSON = """ + { + "error": "A valid user token is required" + } + """; + + public void testFromResponse() { + var errorResponse = MistralErrorResponse.fromResponse( + new HttpResult(mock(HttpResponse.class), ERROR_RESPONSE_JSON.getBytes(StandardCharsets.UTF_8)) + ); + assertNotNull(errorResponse); + assertEquals(ERROR_RESPONSE_JSON, errorResponse.getErrorMessage()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandlerTests.java index ce69473692d39..d8f4157a2fbc6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiResponseHandlerTests.java @@ -100,7 +100,7 @@ public void 
testCheckForFailureStatusCode() { assertFalse(retryException.shouldRetry()); assertThat( retryException.getCause().getMessage(), - containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]") + containsString("Received a bad request status code for request from inference entity id [id] status [400]") ); assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST)); // 400 is not flagged as a content too large when the error message is different @@ -112,7 +112,7 @@ public void testCheckForFailureStatusCode() { assertFalse(retryException.shouldRetry()); assertThat( retryException.getCause().getMessage(), - containsString("Received an unsuccessful status code for request from inference entity id [id] status [400]") + containsString("Received a bad request status code for request from inference entity id [id] status [400]") ); assertThat(((ElasticsearchStatusException) retryException.getCause()).status(), is(RestStatus.BAD_REQUEST)); // 401 diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 26a0a5b6ef770..c19eb664e88ac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -65,6 +65,7 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -1153,14 +1154,14 @@ public void testUnifiedCompletionError() throws Exception { }); var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); - assertThat(json, is(""" + assertThat(json, 
is(String.format(Locale.ROOT, """ {\ "error":{\ "code":"model_not_found",\ - "message":"Received an unsuccessful status code for request from inference entity id [id] status \ + "message":"Resource not found at [%s] for request from inference entity id [id] status \ [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\ "type":"invalid_request_error"\ - }}""")); + }}""", getUrl(webServer)))); } catch (IOException ex) { throw new RuntimeException(ex); }