Skip to content

Commit d1ad917

Browse files
Jan-Kazlouski-elasticelasticsearchmachinejonathan-buttner
authored
Add Hugging Face Chat Completion support to Inference Plugin (#127254)
* Add Hugging Face Chat Completion support to Inference Plugin * Add support for streaming chat completion task for HuggingFace * [CI] Auto commit changes from spotless * Add support for non-streaming completion task for HuggingFace * Remove RequestManager for HF Chat Completion Task * Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests * Refactored Hugging Face Action Creator, added Unit Tests * Add Hugging Face Server Test * [CI] Auto commit changes from spotless * Removed parameters from media type for Chat Completion Request and unit tests * Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT * Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService * Update minimal supported version and add Hugging Face transport version constants * Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests * Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests * Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param * Removed getFirst() method calls for backport compatibility * Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing * Refactored tests to use stripWhitespace method for readability * Refactored javadoc for HuggingFaceService * Renamed HF chat completion TransportVersion constant names * Added random string generation in unit test * Refactored javadocs for HuggingFace requests * Refactored tests to reduce duplication * Added changelog file * Add HuggingFaceChatCompletionResponseHandler and associated tests * Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types * Refactor HuggingFace error handling to improve response structure and add streaming support * Allowing null function name for hugging face models --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Jonathan Buttner <[email protected]>
1 parent 54f2668 commit d1ad917

File tree

30 files changed

+2334
-49
lines changed

30 files changed

+2334
-49
lines changed

docs/changelog/127254.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 127254
2+
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ static TransportVersion def(int id) {
175175
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
176176
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
177177
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
178+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
178179
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
179180
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
180181
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -255,6 +256,7 @@ static TransportVersion def(int id) {
255256
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
256257
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
257258
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
259+
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);
258260

259261
/*
260262
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
123123

124124
public void testGetServicesWithCompletionTaskType() throws IOException {
125125
List<Object> services = getServices(TaskType.COMPLETION);
126-
assertThat(services.size(), equalTo(10));
126+
assertThat(services.size(), equalTo(11));
127127

128128
var providers = providers(services);
129129

@@ -140,19 +140,23 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
140140
"deepseek",
141141
"googleaistudio",
142142
"openai",
143-
"streaming_completion_test_service"
143+
"streaming_completion_test_service",
144+
"hugging_face"
144145
).toArray()
145146
)
146147
);
147148
}
148149

149150
public void testGetServicesWithChatCompletionTaskType() throws IOException {
150151
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
151-
assertThat(services.size(), equalTo(4));
152+
assertThat(services.size(), equalTo(5));
152153

153154
var providers = providers(services);
154155

155-
assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
156+
assertThat(
157+
providers,
158+
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
159+
);
156160
}
157161

158162
public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
7979
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
8080
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
81+
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
8182
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
8283
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
8384
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
@@ -357,6 +358,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
357358
namedWriteables.add(
358359
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
359360
);
361+
namedWriteables.add(
362+
new NamedWriteableRegistry.Entry(
363+
ServiceSettings.class,
364+
HuggingFaceChatCompletionServiceSettings.NAME,
365+
HuggingFaceChatCompletionServiceSettings::new
366+
)
367+
);
360368
}
361369

362370
private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.huggingface;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.rest.RestStatus;
12+
import org.elasticsearch.xcontent.ConstructingObjectParser;
13+
import org.elasticsearch.xcontent.ParseField;
14+
import org.elasticsearch.xcontent.XContentFactory;
15+
import org.elasticsearch.xcontent.XContentParser;
16+
import org.elasticsearch.xcontent.XContentParserConfiguration;
17+
import org.elasticsearch.xcontent.XContentType;
18+
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
19+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
20+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
21+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
22+
import org.elasticsearch.xpack.inference.external.request.Request;
23+
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
24+
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
25+
26+
import java.util.Locale;
27+
import java.util.Optional;
28+
29+
import static org.elasticsearch.core.Strings.format;
30+
31+
/**
32+
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
33+
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
34+
*/
35+
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
36+
37+
private static final String HUGGING_FACE_ERROR = "hugging_face_error";
38+
39+
public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
40+
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
41+
}
42+
43+
@Override
44+
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
45+
assert request.isStreaming() : "Only streaming requests support this format";
46+
var responseStatusCode = result.response().getStatusLine().getStatusCode();
47+
if (request.isStreaming()) {
48+
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
49+
var restStatus = toRestStatus(responseStatusCode);
50+
return errorResponse instanceof HuggingFaceErrorResponseEntity
51+
? new UnifiedChatCompletionException(
52+
restStatus,
53+
errorMessage,
54+
HUGGING_FACE_ERROR,
55+
restStatus.name().toLowerCase(Locale.ROOT)
56+
)
57+
: new UnifiedChatCompletionException(
58+
restStatus,
59+
errorMessage,
60+
createErrorType(errorResponse),
61+
restStatus.name().toLowerCase(Locale.ROOT)
62+
);
63+
} else {
64+
return super.buildError(message, request, result, errorResponse);
65+
}
66+
}
67+
68+
@Override
69+
protected Exception buildMidStreamError(Request request, String message, Exception e) {
70+
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
71+
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
72+
return new UnifiedChatCompletionException(
73+
RestStatus.INTERNAL_SERVER_ERROR,
74+
format(
75+
"%s for request from inference entity id [%s]. Error message: [%s]",
76+
SERVER_ERROR_OBJECT,
77+
request.getInferenceEntityId(),
78+
errorResponse.getErrorMessage()
79+
),
80+
HUGGING_FACE_ERROR,
81+
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
82+
);
83+
} else if (e != null) {
84+
return UnifiedChatCompletionException.fromThrowable(e);
85+
} else {
86+
return new UnifiedChatCompletionException(
87+
RestStatus.INTERNAL_SERVER_ERROR,
88+
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
89+
createErrorType(errorResponse),
90+
"stream_error"
91+
);
92+
}
93+
}
94+
95+
private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
96+
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
97+
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
98+
: null;
99+
}
100+
101+
/**
102+
* Represents a structured error response specifically for streaming operations
103+
* using HuggingFace APIs. This is separate from non-streaming error responses,
104+
* which are handled by {@link HuggingFaceErrorResponseEntity}.
105+
* An example error response for failed field validation for streaming operation would look like
106+
* <code>
107+
* {
108+
* "error": "Input validation error: cannot compile regex from schema",
109+
* "http_status_code": 422
110+
* }
111+
* </code>
112+
*/
113+
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
114+
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
115+
HUGGING_FACE_ERROR,
116+
true,
117+
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
118+
);
119+
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
120+
new ConstructingObjectParser<>(
121+
HUGGING_FACE_ERROR,
122+
true,
123+
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
124+
);
125+
126+
static {
127+
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
128+
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));
129+
130+
ERROR_PARSER.declareObjectOrNull(
131+
ConstructingObjectParser.optionalConstructorArg(),
132+
ERROR_BODY_PARSER,
133+
null,
134+
new ParseField("error")
135+
);
136+
}
137+
138+
/**
139+
* Parses a streaming HuggingFace error response from a JSON string.
140+
*
141+
* @param response the raw JSON string representing an error
142+
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
143+
*/
144+
private static ErrorResponse fromString(String response) {
145+
try (
146+
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
147+
.createParser(XContentParserConfiguration.EMPTY, response)
148+
) {
149+
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
150+
} catch (Exception e) {
151+
// swallow the error
152+
}
153+
154+
return ErrorResponse.UNDEFINED_ERROR;
155+
}
156+
157+
@Nullable
158+
private final Integer httpStatusCode;
159+
160+
StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
161+
super(errorMessage);
162+
this.httpStatusCode = httpStatusCode;
163+
}
164+
165+
@Nullable
166+
public Integer httpStatusCode() {
167+
return httpStatusCode;
168+
}
169+
170+
}
171+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99

1010
import org.elasticsearch.common.settings.SecureString;
1111
import org.elasticsearch.core.Nullable;
12-
import org.elasticsearch.inference.Model;
1312
import org.elasticsearch.inference.ModelConfigurations;
1413
import org.elasticsearch.inference.ModelSecrets;
1514
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
15+
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
1616
import org.elasticsearch.xpack.inference.services.ServiceUtils;
1717
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
1818
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
19+
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
1920

2021
import java.util.Objects;
2122

22-
public abstract class HuggingFaceModel extends Model {
23+
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
2324
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
2425
private final SecureString apiKey;
2526

@@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
3839
return rateLimitServiceSettings;
3940
}
4041

42+
@Override
43+
public int rateLimitGroupingHash() {
44+
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
45+
}
46+
47+
@Override
48+
public RateLimitSettings rateLimitSettings() {
49+
return rateLimitServiceSettings.rateLimitSettings();
50+
}
51+
4152
public SecureString apiKey() {
4253
return apiKey;
4354
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
2020
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
2121
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
22-
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
22+
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;
2323

2424
import java.util.List;
2525
import java.util.Objects;
@@ -64,7 +64,7 @@ public void execute(
6464
) {
6565
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
6666
var truncatedInput = truncate(docsInput, model.getTokenLimit());
67-
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
67+
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);
6868

6969
execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
7070
}

0 commit comments

Comments
 (0)