From e4b3d5619a80f314fc9b1acd080d07626259b54d Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 23 Jan 2025 13:39:29 -0500 Subject: [PATCH 1/5] Adding support for binary embedding type to Cohere service embedding type --- .../org/elasticsearch/TransportVersions.java | 1 + .../CohereEmbeddingsResponseEntity.java | 3 + .../embeddings/CohereEmbeddingType.java | 19 +- .../CohereEmbeddingsRequestEntityTests.java | 32 +++ .../cohere/CohereEmbeddingsRequestTests.java | 47 +++++ .../CohereEmbeddingsResponseEntityTests.java | 182 +++++++++++++++++- .../embeddings/CohereEmbeddingTypeTests.java | 25 +++ .../CohereEmbeddingsServiceSettingsTests.java | 16 +- 8 files changed, 320 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 6fb4703f5153d..4f575c9f384af 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -164,6 +164,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_ROLLOVER_LEGACY_INDICES = def(8_830_00_0); public static final TransportVersion ADD_INCLUDE_FAILURE_INDICES_OPTION = def(8_831_00_0); public static final TransportVersion ESQL_RESPONSE_PARTIAL = def(8_832_00_0); + public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(8_833_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index 3fa9635d38e8c..2d338e43600d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -43,6 +43,9 @@ public class CohereEmbeddingsResponseEntity { toLowerCase(CohereEmbeddingType.FLOAT), CohereEmbeddingsResponseEntity::parseFloatEmbeddingsArray, toLowerCase(CohereEmbeddingType.INT8), + CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray, + toLowerCase(CohereEmbeddingType.BINARY), + // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java index 11e405df3cde9..d756533ee7311 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingType.java @@ -36,18 +36,29 @@ public enum CohereEmbeddingType { /** * This is a synonym for INT8 */ - BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8); + BYTE(DenseVectorFieldMapper.ElementType.BYTE, RequestConstants.INT8), + /** + * Use this when you want to get back binary embeddings. Valid only for v3 models. + */ + BIT(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT), + /** + * This is a synonym for BIT + */ + BINARY(DenseVectorFieldMapper.ElementType.BIT, RequestConstants.BIT); private static final class RequestConstants { private static final String FLOAT = "float"; private static final String INT8 = "int8"; + private static final String BIT = "binary"; } private static final Map ELEMENT_TYPE_TO_COHERE_EMBEDDING = Map.of( DenseVectorFieldMapper.ElementType.FLOAT, FLOAT, DenseVectorFieldMapper.ElementType.BYTE, - BYTE + BYTE, + DenseVectorFieldMapper.ElementType.BIT, + BIT ); static final EnumSet SUPPORTED_ELEMENT_TYPES = EnumSet.copyOf( ELEMENT_TYPE_TO_COHERE_EMBEDDING.keySet() @@ -116,6 +127,10 @@ public static CohereEmbeddingType translateToVersion(CohereEmbeddingType embeddi return INT8; } + if (version.before(TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED) && embeddingType == BIT) { + return INT8; + } + return embeddingType; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java index 8c8aeba4a0a07..8ca5a91e83429 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestEntityTests.java @@ -72,6 +72,38 @@ public void testXContent_InputTypeSearch_EmbeddingTypesByte_TruncateNone() throw {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); } + public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { + var entity = new CohereEmbeddingsRequestEntity( + List.of("abc"), + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + "model", + CohereEmbeddingType.BINARY + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { + var entity = new CohereEmbeddingsRequestEntity( + List.of("abc"), + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + "model", + CohereEmbeddingType.BIT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java index d30b809603eef..f3664b1e23a03 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/cohere/CohereEmbeddingsRequestTests.java @@ -145,6 +145,53 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th ); } + public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException { + var request = createRequest( + List.of("abc"), + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + null, + null, + "model", + CohereEmbeddingType.BIT + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + "search_query", + "embedding_types", + List.of("binary"), + "truncate", + "end" + ) + ) + ); + } + public void testCreateRequest_TruncateNone() throws IOException { var request = createRequest( List.of("abc"), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java index 691064b947e23..5b7231999f88e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -226,6 +226,53 @@ public void testFromResponse_ParsesBytes() throws IOException { ); } + public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "binary": [ + [ + -55, + 74, + 101, + 67, + 83 + ] + ] + }, + "meta": { + "api_version": { + "version": "2" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + + InferenceTextEmbeddingByteResults parsedResults = (InferenceTextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is( + List.of( + new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( + new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 } + ) + ) + ) + ); + } + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { String responseJson = """ { @@ -318,6 +365,63 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat() throw ); } + public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary() throws IOException { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello", + "goodbye" + ], + "embeddings": { + "binary": [ + [ + -55, + 74, + 101, + 67 + ], + [ + 34, + -64, + 97, + 65, + -42 + ] + ] + }, + "meta": { + "api_version": { + "version": "2" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + + InferenceTextEmbeddingByteResults parsedResults = (InferenceTextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + MatcherAssert.assertThat( + parsedResults.embeddings(), + is( + List.of( + new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( + new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 } + ), + new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( + new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 } + ) + ) + ) + ); + } + public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { String responseJson = """ { @@ -433,6 +537,82 @@ public void testFromResponse_FailsWhenEmbeddingsByteValue_IsOutsideByteRange_Pos MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte")); } + public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Negative() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "binary": [ + [ + -129, + 127 + ] + ] + }, + "meta": { + "api_version": { + "version": "2" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Value [-129] is out of range for a byte")); + } + + public void testFromResponse_FailsWhenEmbeddingsBinaryValue_IsOutsideByteRange_Positive() { + String responseJson = """ + { + "id": "3198467e-399f-4d4a-aa2c-58af93bd6dc4", + "texts": [ + "hello" + ], + "embeddings": { + "binary": [ + [ + -128, + 128 + ] + ] + }, + "meta": { + "api_version": { + "version": "2" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + + var thrownException = expectThrows( + IllegalArgumentException.class, + () -> CohereEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + MatcherAssert.assertThat(thrownException.getMessage(), is("Value [128] is out of range for a byte")); + } + public void testFromResponse_FailsToFindAValidEmbeddingType() { String responseJson = """ { @@ -470,7 +650,7 @@ public void testFromResponse_FailsToFindAValidEmbeddingType() { MatcherAssert.assertThat( thrownException.getMessage(), - is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [float, int8]") + is("Failed to find a supported embedding type in the Cohere embeddings response. Supported types are [binary, float, int8]") ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java index 3aa423d5bbafd..e751e3df800cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingTypeTests.java @@ -50,6 +50,27 @@ public void testTranslateToVersion_ReturnsFloat_WhenVersionOnByteEnumAddition_Wh ); } + public void testTranslateToVersion_ReturnsInt8_WhenVersionIsBeforeBitEnumAddition_WhenSpecifyingBit() { + assertThat( + CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, new TransportVersion(8_832_00_0)), + is(CohereEmbeddingType.INT8) + ); + } + + public void testTranslateToVersion_ReturnsBit_WhenVersionOnBitEnumAddition_WhenSpecifyingBit() { + assertThat( + CohereEmbeddingType.translateToVersion(CohereEmbeddingType.BIT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED), + is(CohereEmbeddingType.BIT) + ); + } + + public void testTranslateToVersion_ReturnsFloat_WhenVersionOnBitEnumAddition_WhenSpecifyingFloat() { + assertThat( + CohereEmbeddingType.translateToVersion(CohereEmbeddingType.FLOAT, TransportVersions.COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED), + is(CohereEmbeddingType.FLOAT) + ); + } + public void testFromElementType_CovertsFloatToCohereEmbeddingTypeFloat() { assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.FLOAT), is(CohereEmbeddingType.FLOAT)); } @@ -57,4 +78,8 @@ public void testFromElementType_CovertsFloatToCohereEmbeddingTypeFloat() { public void testFromElementType_CovertsByteToCohereEmbeddingTypeByte() { assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BYTE), is(CohereEmbeddingType.BYTE)); } + + public void testFromElementType_ConvertsBitToCohereEmbeddingTypeBinary() { + assertThat(CohereEmbeddingType.fromElementType(DenseVectorFieldMapper.ElementType.BIT), is(CohereEmbeddingType.BIT)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java index 73ebd6c6c0505..544676cfa7cc7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -218,7 +218,7 @@ public void testFromMap_InvalidEmbeddingType_ThrowsError_ForRequest() { is( Strings.format( "Validation Failed: 1: [service_settings] Invalid value [abc] received. " - + "[embedding_type] must be one of [byte, float, int8];" + + "[embedding_type] must be one of [binary, bit, byte, float, int8];" ) ) ); @@ -238,7 +238,7 @@ public void testFromMap_InvalidEmbeddingType_ThrowsError_ForPersistent() { is( Strings.format( "Validation Failed: 1: [service_settings] Invalid value [abc] received. " - + "[embedding_type] must be one of [byte, float];" + + "[embedding_type] must be one of [bit, byte, float];" ) ) ); @@ -289,6 +289,16 @@ public void testFromMap_ConvertsInt8_ToCohereEmbeddingTypeInt8() { ); } + public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() { + assertThat( + CohereEmbeddingsServiceSettings.fromMap( + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())), + ConfigurationParseContext.REQUEST + ), + is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT)) + ); + } + public void testFromMap_PreservesEmbeddingTypeFloat() { assertThat( CohereEmbeddingsServiceSettings.fromMap( @@ -314,6 +324,8 @@ public void testFromCohereOrDenseVectorEnumValues() { assertEquals(CohereEmbeddingType.BYTE, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("byte", validation)); assertEquals(CohereEmbeddingType.INT8, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("int8", validation)); assertEquals(CohereEmbeddingType.FLOAT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("float", validation)); + assertEquals(CohereEmbeddingType.BINARY, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("binary", validation)); + assertEquals(CohereEmbeddingType.BIT, CohereEmbeddingsServiceSettings.fromCohereOrDenseVectorEnumValues("bit", validation)); assertTrue(validation.validationErrors().isEmpty()); } From 9b687c823b4127937b76663fdf0e132ce7611ef1 Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 23 Jan 2025 16:33:18 -0500 Subject: [PATCH 2/5] Returning response in separate text_embedding_bits field --- .../inference/service-cohere.asciidoc | 6 +- .../results/InferenceByteEmbedding.java | 95 ++++++++++++ .../InferenceTextEmbeddingBitResults.java | 109 ++++++++++++++ .../InferenceTextEmbeddingByteResults.java | 81 +---------- .../chunking/EmbeddingRequestChunker.java | 7 +- .../CohereEmbeddingsResponseEntity.java | 16 ++- .../EmbeddingRequestChunkerTests.java | 9 +- .../CohereEmbeddingsResponseEntityTests.java | 32 ++--- .../rest/BaseInferenceActionTests.java | 5 +- ...InferenceTextEmbeddingBitResultsTests.java | 135 ++++++++++++++++++ ...nferenceTextEmbeddingByteResultsTests.java | 26 ++-- .../results/TextEmbeddingResultsTests.java | 3 +- 12 files changed, 388 insertions(+), 136 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java diff --git a/docs/reference/inference/service-cohere.asciidoc b/docs/reference/inference/service-cohere.asciidoc index e95f0810fd29d..486054961043a 100644 --- a/docs/reference/inference/service-cohere.asciidoc +++ b/docs/reference/inference/service-cohere.asciidoc @@ -62,7 +62,7 @@ include::inference-shared.asciidoc[tag=chunking-settings-strategy] `service`:: (Required, string) -The type of service supported for the specified task type. In this case, +The type of service supported for the specified task type. In this case, `cohere`. `service_settings`:: @@ -127,6 +127,8 @@ Valid values are: * `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). * `float`: use it for the default float embeddings. * `int8`: use it for signed int8 embeddings. +* `binary`: use it for binary embeddings, which are encoded as bytes with signed int8 precision. +* `bit`: use it for binary embeddings, which are encoded as bytes with signed int8 precision (this is a synonym of `binary`). `model_id`::: (Optional, string) @@ -228,4 +230,4 @@ PUT _inference/rerank/cohere-rerank // TEST[skip:TBD] For more examples, also review the -https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. \ No newline at end of file +https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java new file mode 100644 index 0000000000000..242d16a17829a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt { + public static final String EMBEDDING = "embedding"; + + public InferenceByteEmbedding(StreamInput in) throws IOException { + this(in.readByteArray()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeByteArray(values); + } + + public static InferenceByteEmbedding of(List embeddingValuesList) { + byte[] embeddingValues = new byte[embeddingValuesList.size()]; + for (int i = 0; i < embeddingValuesList.size(); i++) { + embeddingValues[i] = embeddingValuesList.get(i); + } + return new InferenceByteEmbedding(embeddingValues); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.startArray(EMBEDDING); + for (byte value : values) { + builder.value(value); + } + builder.endArray(); + + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + float[] toFloatArray() { + float[] floatArray = new float[values.length]; + for (int i = 0; i < values.length; i++) { + floatArray[i] = ((Byte) values[i]).floatValue(); + } + return floatArray; + } + + double[] toDoubleArray() { + double[] doubleArray = new double[values.length]; + for (int i = 0; i < values.length; i++) { + doubleArray[i] = ((Byte) values[i]).floatValue(); + } + return doubleArray; + } + + @Override + public int getSize() { + return values().length; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceByteEmbedding embedding = (InferenceByteEmbedding) o; + return Arrays.equals(values, embedding.values); + } + + @Override + public int hashCode() { + return Arrays.hashCode(values); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java new file mode 100644 index 0000000000000..887c07558ab71 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingBitResults.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Writes a text embedding result in the follow json format + * { + * "text_embedding_bytes": [ + * { + * "embedding": [ + * 23 + * ] + * }, + * { + * "embedding": [ + * -23 + * ] + * } + * ] + * } + */ +public record InferenceTextEmbeddingBitResults(List embeddings) implements InferenceServiceResults, TextEmbedding { + public static final String NAME = "text_embedding_service_bit_results"; + public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits"; + + public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException { + this(in.readCollectionAsList(InferenceByteEmbedding::new)); + } + + @Override + public int getFirstEmbeddingSize() { + return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings)); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return ChunkedToXContentHelper.array(TEXT_EMBEDDING_BITS, embeddings.iterator()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(embeddings); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public List transformToCoordinationFormat() { + return embeddings.stream() + .map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false)) + .toList(); + } + + @Override + @SuppressWarnings("deprecation") + public List transformToLegacyFormat() { + var legacyEmbedding = new LegacyTextEmbeddingResults( + embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList() + ); + + return List.of(legacyEmbedding); + } + + public Map asMap() { + Map map = new LinkedHashMap<>(); + map.put(TEXT_EMBEDDING_BITS, embeddings); + + return map; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o; + return Objects.equals(embeddings, that.embeddings); + } + + @Override + public int hashCode() { + return Objects.hash(embeddings); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java index 16dca7b04d526..1ae54220508c5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java @@ -9,21 +9,16 @@ package org.elasticsearch.xpack.core.inference.results; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -33,7 +28,7 @@ /** * Writes a text embedding result in the follow json format * { - * "text_embedding": [ + * "text_embedding_bytes": [ * { * "embedding": [ * 23 @@ -111,78 +106,4 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(embeddings); } - - public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt { - public static final String EMBEDDING = "embedding"; - - public InferenceByteEmbedding(StreamInput in) throws IOException { - this(in.readByteArray()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeByteArray(values); - } - - public static InferenceByteEmbedding of(List embeddingValuesList) { - byte[] embeddingValues = new byte[embeddingValuesList.size()]; - for (int i = 0; i < embeddingValuesList.size(); i++) { - embeddingValues[i] = embeddingValuesList.get(i); - } - return new InferenceByteEmbedding(embeddingValues); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - builder.startArray(EMBEDDING); - for (byte value : values) { - builder.value(value); - } - builder.endArray(); - - builder.endObject(); - return builder; - } - - @Override - public String toString() { - return Strings.toString(this); - } - - private float[] toFloatArray() { - float[] floatArray = new float[values.length]; - for (int i = 0; i < values.length; i++) { - floatArray[i] = ((Byte) values[i]).floatValue(); - } - return floatArray; - } - - private double[] toDoubleArray() { - double[] doubleArray = new double[values.length]; - for (int i = 0; i < values.length; i++) { - doubleArray[i] = ((Byte) values[i]).floatValue(); - } - return doubleArray; - } - - @Override - public int getSize() { - return values().length; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - InferenceByteEmbedding embedding = (InferenceByteEmbedding) o; - return Arrays.equals(values, embedding.values); - } - - @Override - public int hashCode() { - return Arrays.hashCode(values); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 9b0b1104df660..fb796c2afdfeb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -69,7 +70,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private List chunkedOffsets; private List>> floatResults; - private List>> byteResults; + private List>> byteResults; private List>> sparseResults; private AtomicArray errors; private ActionListener> finalListener; @@ -389,9 +390,9 @@ private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs( private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs( ChunkOffsetsAndInput chunks, - AtomicArray> debatchedResults + AtomicArray> debatchedResults ) { - var all = new ArrayList(); + var all = new ArrayList(); for (int i = 0; i < debatchedResults.length(); i++) { var subBatch = debatchedResults.get(i); all.addAll(subBatch); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java index 2d338e43600d6..2e574d477b057 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntity.java @@ -17,6 +17,8 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -45,8 +47,7 @@ public class CohereEmbeddingsResponseEntity { toLowerCase(CohereEmbeddingType.INT8), CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray, toLowerCase(CohereEmbeddingType.BINARY), - // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser - CohereEmbeddingsResponseEntity::parseByteEmbeddingsArray + CohereEmbeddingsResponseEntity::parseBitEmbeddingsArray ); private static final String VALID_EMBEDDING_TYPES_STRING = supportedEmbeddingTypes(); @@ -187,17 +188,24 @@ private static InferenceServiceResults parseEmbeddingsObject(XContentParser pars ); } + private static InferenceServiceResults parseBitEmbeddingsArray(XContentParser parser) throws IOException { + // Cohere returns array of binary embeddings encoded as bytes with int8 precision so we can reuse the byte parser + var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); + + return new InferenceTextEmbeddingBitResults(embeddingList); + } + private static InferenceServiceResults parseByteEmbeddingsArray(XContentParser parser) throws IOException { var embeddingList = parseList(parser, CohereEmbeddingsResponseEntity::parseByteArrayEntry); return new InferenceTextEmbeddingByteResults(embeddingList); } - private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { + private static InferenceByteEmbedding parseByteArrayEntry(XContentParser parser) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); List embeddingValuesList = parseList(parser, CohereEmbeddingsResponseEntity::parseEmbeddingInt8Entry); - return InferenceTextEmbeddingByteResults.InferenceByteEmbedding.of(embeddingValuesList); + return InferenceByteEmbedding.of(embeddingValuesList); } private static Byte parseEmbeddingInt8Entry(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 03249163c7f82..f0b82f49d4e98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; @@ -368,16 +369,16 @@ public void testMergingListener_Byte() { // 4 inputs in 2 batches { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < batchSize; i++) { - embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() })); + embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() })); } batches.get(0).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings)); } { - var embeddings = new ArrayList(); + var embeddings = new ArrayList(); for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch - embeddings.add(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { randomByte() })); + embeddings.add(new InferenceByteEmbedding(new byte[] { randomByte() })); } batches.get(1).listener().onResponse(new InferenceTextEmbeddingByteResults(embeddings)); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java index 5b7231999f88e..42dab0a9021bf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereEmbeddingsResponseEntityTests.java @@ -10,6 +10,8 @@ import org.apache.http.HttpResponse; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpResult; @@ -182,10 +184,7 @@ public void testFromResponse_UsesTheFirstValidEmbeddingsEntryInt8_WithInvalidFir new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))); } public void testFromResponse_ParsesBytes() throws IOException { @@ -220,10 +219,7 @@ public void testFromResponse_ParsesBytes() throws IOException { new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); - MatcherAssert.assertThat( - parsedResults.embeddings(), - is(List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 }))) - ); + MatcherAssert.assertThat(parsedResults.embeddings(), is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1, (byte) 0 })))); } public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOException { @@ -256,20 +252,14 @@ public void testFromResponse_ParsesBytes_FromBinaryEmbeddingsEntry() throws IOEx } """; - InferenceTextEmbeddingByteResults parsedResults = (InferenceTextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); MatcherAssert.assertThat( parsedResults.embeddings(), - is( - List.of( - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( - new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 } - ) - ) - ) + is(List.of(new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67, (byte) 83 }))) ); } @@ -402,7 +392,7 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary( } """; - InferenceTextEmbeddingByteResults parsedResults = (InferenceTextEmbeddingByteResults) CohereEmbeddingsResponseEntity.fromResponse( + InferenceTextEmbeddingBitResults parsedResults = (InferenceTextEmbeddingBitResults) CohereEmbeddingsResponseEntity.fromResponse( mock(Request.class), new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) ); @@ -411,12 +401,8 @@ public void testFromResponse_CreatesResultsForMultipleItems_ObjectFormat_Binary( parsedResults.embeddings(), is( List.of( - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( - new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 } - ), - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding( - new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 } - ) + new InferenceByteEmbedding(new byte[] { (byte) -55, (byte) 74, (byte) 101, (byte) 67 }), + new InferenceByteEmbedding(new byte[] { (byte) 34, (byte) -64, (byte) 97, (byte) 65, (byte) -42 }) ) ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 5528c80066b0a..9e6718172a2f0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.test.rest.RestActionTestCase; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.junit.Before; @@ -142,9 +143,7 @@ public void testUses3SecondTimeoutFromParams() { static InferenceAction.Response createResponse() { return new InferenceAction.Response( - new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) -1 })) - ) + new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) -1 }))) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java new file mode 100644 index 0000000000000..45b9627371575 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingBitResultsTests.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingBitResults; +import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class InferenceTextEmbeddingBitResultsTests extends AbstractWireSerializingTestCase { + public static InferenceTextEmbeddingBitResults createRandomResults() { + int embeddings = randomIntBetween(1, 10); + List embeddingResults = new ArrayList<>(embeddings); + + for (int i = 0; i < embeddings; i++) { + embeddingResults.add(createRandomEmbedding()); + } + + return new InferenceTextEmbeddingBitResults(embeddingResults); + } + + private static InferenceByteEmbedding createRandomEmbedding() { + int columns = randomIntBetween(1, 10); + byte[] bytes = new byte[columns]; + + for (int i = 0; i < columns; i++) { + bytes[i] = randomByte(); + } + + return new InferenceByteEmbedding(bytes); + } + + public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { + var entity = new InferenceTextEmbeddingBitResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }))); + + String xContentResult = Strings.toString(entity, true, true); + assertThat(xContentResult, is(""" + { + "text_embedding_bits" : [ + { + "embedding" : [ + 23 + ] + } + ] + }""")); + } + + public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { + var entity = new InferenceTextEmbeddingBitResults( + List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 })) + ); + + String xContentResult = Strings.toString(entity, true, true); + assertThat(xContentResult, is(""" + { + "text_embedding_bits" : [ + { + "embedding" : [ + 23 + ] + }, + { + "embedding" : [ + 24 + ] + } + ] + }""")); + } + + public void testTransformToCoordinationFormat() { + var results = new InferenceTextEmbeddingBitResults( + List.of( + new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }), + new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 }) + ) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 23F, 24F }, false), + new MlTextEmbeddingResults(InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, new double[] { 25F, 26F }, false) + ) + ) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceTextEmbeddingBitResults::new; + } + + @Override + protected InferenceTextEmbeddingBitResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected InferenceTextEmbeddingBitResults mutateInstance(InferenceTextEmbeddingBitResults instance) throws IOException { + // if true we reduce the embeddings list by a random amount, if false we add an embedding to the list + if (randomBoolean()) { + // -1 to remove at least one item from the list + int end = randomInt(instance.embeddings().size() - 1); + return new InferenceTextEmbeddingBitResults(instance.embeddings().subList(0, end)); + } else { + List embeddings = new ArrayList<>(instance.embeddings()); + embeddings.add(createRandomEmbedding()); + return new InferenceTextEmbeddingBitResults(embeddings); + } + } + + public static Map buildExpectationByte(List> embeddings) { + return Map.of( + InferenceTextEmbeddingBitResults.TEXT_EMBEDDING_BITS, + embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java index c6749e9822cf4..d932f36fb25a7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/InferenceTextEmbeddingByteResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -23,7 +24,7 @@ public class InferenceTextEmbeddingByteResultsTests extends AbstractWireSerializingTestCase { public static InferenceTextEmbeddingByteResults createRandomResults() { int embeddings = randomIntBetween(1, 10); - List embeddingResults = new ArrayList<>(embeddings); + List embeddingResults = new ArrayList<>(embeddings); for (int i = 0; i < embeddings; i++) { embeddingResults.add(createRandomEmbedding()); @@ -32,7 +33,7 @@ public static InferenceTextEmbeddingByteResults createRandomResults() { return new InferenceTextEmbeddingByteResults(embeddingResults); } - private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding createRandomEmbedding() { + private static InferenceByteEmbedding createRandomEmbedding() { int columns = randomIntBetween(1, 10); byte[] bytes = new byte[columns]; @@ -40,13 +41,11 @@ private static InferenceTextEmbeddingByteResults.InferenceByteEmbedding createRa bytes[i] = randomByte(); } - return new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(bytes); + return new InferenceByteEmbedding(bytes); } public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOException { - var entity = new InferenceTextEmbeddingByteResults( - List.of(new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 })) - ); + var entity = new InferenceTextEmbeddingByteResults(List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }))); String xContentResult = Strings.toString(entity, true, true); assertThat(xContentResult, is(""" @@ -63,10 +62,7 @@ public void testToXContent_CreatesTheRightFormatForASingleEmbedding() throws IOE public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws IOException { var entity = new InferenceTextEmbeddingByteResults( - List.of( - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23 }), - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 24 }) - ) + List.of(new InferenceByteEmbedding(new byte[] { (byte) 23 }), new InferenceByteEmbedding(new byte[] { (byte) 24 })) ); String xContentResult = Strings.toString(entity, true, true); @@ -90,8 +86,8 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I public void testTransformToCoordinationFormat() { var results = new InferenceTextEmbeddingByteResults( List.of( - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }), - new InferenceTextEmbeddingByteResults.InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 }) + new InferenceByteEmbedding(new byte[] { (byte) 23, (byte) 24 }), + new InferenceByteEmbedding(new byte[] { (byte) 25, (byte) 26 }) ) ).transformToCoordinationFormat(); @@ -124,7 +120,7 @@ protected InferenceTextEmbeddingByteResults mutateInstance(InferenceTextEmbeddin int end = randomInt(instance.embeddings().size() - 1); return new InferenceTextEmbeddingByteResults(instance.embeddings().subList(0, end)); } else { - List embeddings = new ArrayList<>(instance.embeddings()); + List embeddings = new ArrayList<>(instance.embeddings()); embeddings.add(createRandomEmbedding()); return new InferenceTextEmbeddingByteResults(embeddings); } @@ -133,9 +129,7 @@ protected InferenceTextEmbeddingByteResults mutateInstance(InferenceTextEmbeddin public static Map buildExpectationByte(List> embeddings) { return Map.of( InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream() - .map(embedding -> Map.of(InferenceTextEmbeddingByteResults.InferenceByteEmbedding.EMBEDDING, embedding)) - .toList() + embeddings.stream().map(embedding -> Map.of(InferenceByteEmbedding.EMBEDDING, embedding)).toList() ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 2c405aaeaba3f..56bd690a9cdbf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; @@ -141,7 +142,7 @@ public static Map buildExpectationFloat(List embeddings public static Map buildExpectationByte(List embeddings) { return Map.of( InferenceTextEmbeddingByteResults.TEXT_EMBEDDING_BYTES, - embeddings.stream().map(InferenceTextEmbeddingByteResults.InferenceByteEmbedding::new).toList() + embeddings.stream().map(InferenceByteEmbedding::new).toList() ); } From e6034d8cef970989724c147785fd1001a4c11796 Mon Sep 17 00:00:00 2001 From: Ying Mao Date: Thu, 30 Jan 2025 07:42:37 -0500 Subject: [PATCH 3/5] Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java Co-authored-by: David Kyle --- .../xpack/core/inference/results/InferenceByteEmbedding.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java index 242d16a17829a..7d7176a9a5a51 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java @@ -70,7 +70,7 @@ float[] toFloatArray() { double[] toDoubleArray() { double[] doubleArray = new double[values.length]; for (int i = 0; i < values.length; i++) { - doubleArray[i] = ((Byte) values[i]).floatValue(); + doubleArray[i] = ((Byte) values[i]).doubleValue(); } return doubleArray; } From b170eafb22a6245050f8e44a7b0a8174b7882981 Mon Sep 17 00:00:00 2001 From: Ying Mao Date: Thu, 30 Jan 2025 07:45:48 -0500 Subject: [PATCH 4/5] Update docs/changelog/120751.yaml --- docs/changelog/120751.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/120751.yaml diff --git a/docs/changelog/120751.yaml b/docs/changelog/120751.yaml new file mode 100644 index 0000000000000..0c1dffc0e527b --- /dev/null +++ b/docs/changelog/120751.yaml @@ -0,0 +1,5 @@ +pr: 120751 +summary: Adding support for binary embedding type to Cohere service embedding type +area: Machine Learning +type: enhancement +issues: [] From a13caa0374c7919913090d154c99b7edfc368178 Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 30 Jan 2025 09:03:23 -0500 Subject: [PATCH 5/5] Reverting docs change --- docs/reference/inference/service-cohere.asciidoc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/reference/inference/service-cohere.asciidoc b/docs/reference/inference/service-cohere.asciidoc index d9ba5b4e2505d..289f03787580f 100644 --- a/docs/reference/inference/service-cohere.asciidoc +++ b/docs/reference/inference/service-cohere.asciidoc @@ -62,7 +62,7 @@ include::inference-shared.asciidoc[tag=chunking-settings-strategy] `service`:: (Required, string) -The type of service supported for the specified task type. In this case, +The type of service supported for the specified task type. In this case, `cohere`. `service_settings`:: @@ -127,8 +127,6 @@ Valid values are: * `byte`: use it for signed int8 embeddings (this is a synonym of `int8`). * `float`: use it for the default float embeddings. * `int8`: use it for signed int8 embeddings. -* `binary`: use it for binary embeddings, which are encoded as bytes with signed int8 precision. -* `bit`: use it for binary embeddings, which are encoded as bytes with signed int8 precision (this is a synonym of `binary`). `model_id`::: (Optional, string) @@ -230,4 +228,4 @@ PUT _inference/rerank/cohere-rerank // TEST[skip:TBD] For more examples, also review the -https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. +https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation]. \ No newline at end of file