Skip to content

Commit e20fded

Browse files
authored
[8.x] Adding support for binary embedding type to Cohere service embedding type (#120751) (#121584)
* Adding support for binary embedding type to Cohere service embedding type (#120751) * Adding support for binary embedding type to Cohere service embedding type * Returning response in separate text_embedding_bits field * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceByteEmbedding.java Co-authored-by: David Kyle <[email protected]> * Update docs/changelog/120751.yaml * Reverting docs change --------- Co-authored-by: David Kyle <[email protected]> (cherry picked from commit 89d71e1) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java * Adding docs
1 parent a3ab655 commit e20fded

File tree

19 files changed

+706
-122
lines changed

19 files changed

+706
-122
lines changed

docs/changelog/120751.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 120751
2+
summary: Adding support for binary embedding type to Cohere service embedding type
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

docs/reference/inference/service-cohere.asciidoc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ include::inference-shared.asciidoc[tag=chunking-settings-strategy]
6262

6363
`service`::
6464
(Required, string)
65-
The type of service supported for the specified task type. In this case,
65+
The type of service supported for the specified task type. In this case,
6666
`cohere`.
6767

6868
`service_settings`::
@@ -127,6 +127,8 @@ Valid values are:
127127
* `byte`: use it for signed int8 embeddings (this is a synonym of `int8`).
128128
* `float`: use it for the default float embeddings.
129129
* `int8`: use it for signed int8 embeddings.
130+
* `binary`: use it for binary embeddings, which are encoded as bytes with signed int8 precision.
131+
* `bit`: use it for binary embeddings, which are encoded as bytes with signed int8 precision (this is a synonym of `binary`).
130132
131133
`model_id`:::
132134
(Optional, string)
@@ -228,4 +230,4 @@ PUT _inference/rerank/cohere-rerank
228230
// TEST[skip:TBD]
229231

230232
For more examples, also review the
231-
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].
233+
https://docs.cohere.com/docs/elasticsearch-and-cohere#rerank-search-results-with-cohere-and-elasticsearch[Cohere documentation].

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ static TransportVersion def(int id) {
178178
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_0_00);
179179
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_0_00);
180180
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_0_00);
181+
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01);
181182

182183
/*
183184
* STOP! READ THIS FIRST! No, really,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*
7+
* this file was contributed to by a generative AI
8+
*/
9+
10+
package org.elasticsearch.xpack.core.inference.results;
11+
12+
import org.elasticsearch.common.Strings;
13+
import org.elasticsearch.common.io.stream.StreamInput;
14+
import org.elasticsearch.common.io.stream.StreamOutput;
15+
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.xcontent.ToXContentObject;
17+
import org.elasticsearch.xcontent.XContentBuilder;
18+
19+
import java.io.IOException;
20+
import java.util.Arrays;
21+
import java.util.List;
22+
23+
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
24+
public static final String EMBEDDING = "embedding";
25+
26+
public InferenceByteEmbedding(StreamInput in) throws IOException {
27+
this(in.readByteArray());
28+
}
29+
30+
@Override
31+
public void writeTo(StreamOutput out) throws IOException {
32+
out.writeByteArray(values);
33+
}
34+
35+
public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
36+
byte[] embeddingValues = new byte[embeddingValuesList.size()];
37+
for (int i = 0; i < embeddingValuesList.size(); i++) {
38+
embeddingValues[i] = embeddingValuesList.get(i);
39+
}
40+
return new InferenceByteEmbedding(embeddingValues);
41+
}
42+
43+
@Override
44+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
45+
builder.startObject();
46+
47+
builder.startArray(EMBEDDING);
48+
for (byte value : values) {
49+
builder.value(value);
50+
}
51+
builder.endArray();
52+
53+
builder.endObject();
54+
return builder;
55+
}
56+
57+
@Override
58+
public String toString() {
59+
return Strings.toString(this);
60+
}
61+
62+
float[] toFloatArray() {
63+
float[] floatArray = new float[values.length];
64+
for (int i = 0; i < values.length; i++) {
65+
floatArray[i] = ((Byte) values[i]).floatValue();
66+
}
67+
return floatArray;
68+
}
69+
70+
double[] toDoubleArray() {
71+
double[] doubleArray = new double[values.length];
72+
for (int i = 0; i < values.length; i++) {
73+
doubleArray[i] = ((Byte) values[i]).doubleValue();
74+
}
75+
return doubleArray;
76+
}
77+
78+
@Override
79+
public int getSize() {
80+
return values().length;
81+
}
82+
83+
@Override
84+
public boolean equals(Object o) {
85+
if (this == o) return true;
86+
if (o == null || getClass() != o.getClass()) return false;
87+
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
88+
return Arrays.equals(values, embedding.values);
89+
}
90+
91+
@Override
92+
public int hashCode() {
93+
return Arrays.hashCode(values);
94+
}
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*
7+
* this file was contributed to by a generative AI
8+
*/
9+
10+
package org.elasticsearch.xpack.core.inference.results;
11+
12+
import org.elasticsearch.common.io.stream.StreamInput;
13+
import org.elasticsearch.common.io.stream.StreamOutput;
14+
import org.elasticsearch.common.xcontent.ChunkedToXContent;
15+
import org.elasticsearch.inference.InferenceResults;
16+
import org.elasticsearch.inference.InferenceServiceResults;
17+
import org.elasticsearch.xcontent.ToXContent;
18+
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
19+
20+
import java.io.IOException;
21+
import java.util.ArrayList;
22+
import java.util.Iterator;
23+
import java.util.LinkedHashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.Objects;
27+
28+
/**
29+
* Writes a text embedding result in the following JSON format
30+
* {
31+
* "text_embedding_bits": [
32+
* {
33+
* "embedding": [
34+
* 23
35+
* ]
36+
* },
37+
* {
38+
* "embedding": [
39+
* -23
40+
* ]
41+
* }
42+
* ]
43+
* }
44+
*/
45+
public record InferenceTextEmbeddingBitResults(List<InferenceByteEmbedding> embeddings) implements InferenceServiceResults, TextEmbedding {
46+
public static final String NAME = "text_embedding_service_bit_results";
47+
public static final String TEXT_EMBEDDING_BITS = "text_embedding_bits";
48+
49+
public InferenceTextEmbeddingBitResults(StreamInput in) throws IOException {
50+
this(in.readCollectionAsList(InferenceByteEmbedding::new));
51+
}
52+
53+
@Override
54+
public int getFirstEmbeddingSize() {
55+
return TextEmbeddingUtils.getFirstEmbeddingSize(new ArrayList<>(embeddings));
56+
}
57+
58+
@Override
59+
public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
60+
return ChunkedToXContent.builder(params).array(TEXT_EMBEDDING_BITS, embeddings.iterator());
61+
}
62+
63+
@Override
64+
public void writeTo(StreamOutput out) throws IOException {
65+
out.writeCollection(embeddings);
66+
}
67+
68+
@Override
69+
public String getWriteableName() {
70+
return NAME;
71+
}
72+
73+
@Override
74+
public List<? extends InferenceResults> transformToCoordinationFormat() {
75+
return embeddings.stream()
76+
.map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING_BITS, embedding.toDoubleArray(), false))
77+
.toList();
78+
}
79+
80+
@Override
81+
@SuppressWarnings("deprecation")
82+
public List<? extends InferenceResults> transformToLegacyFormat() {
83+
var legacyEmbedding = new LegacyTextEmbeddingResults(
84+
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
85+
);
86+
87+
return List.of(legacyEmbedding);
88+
}
89+
90+
public Map<String, Object> asMap() {
91+
Map<String, Object> map = new LinkedHashMap<>();
92+
map.put(TEXT_EMBEDDING_BITS, embeddings);
93+
94+
return map;
95+
}
96+
97+
@Override
98+
public boolean equals(Object o) {
99+
if (this == o) return true;
100+
if (o == null || getClass() != o.getClass()) return false;
101+
InferenceTextEmbeddingBitResults that = (InferenceTextEmbeddingBitResults) o;
102+
return Objects.equals(embeddings, that.embeddings);
103+
}
104+
105+
@Override
106+
public int hashCode() {
107+
return Objects.hash(embeddings);
108+
}
109+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/InferenceTextEmbeddingByteResults.java

Lines changed: 1 addition & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,16 @@
99

1010
package org.elasticsearch.xpack.core.inference.results;
1111

12-
import org.elasticsearch.common.Strings;
1312
import org.elasticsearch.common.io.stream.StreamInput;
1413
import org.elasticsearch.common.io.stream.StreamOutput;
15-
import org.elasticsearch.common.io.stream.Writeable;
1614
import org.elasticsearch.common.xcontent.ChunkedToXContent;
1715
import org.elasticsearch.inference.InferenceResults;
1816
import org.elasticsearch.inference.InferenceServiceResults;
1917
import org.elasticsearch.xcontent.ToXContent;
20-
import org.elasticsearch.xcontent.ToXContentObject;
21-
import org.elasticsearch.xcontent.XContentBuilder;
2218
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
2319

2420
import java.io.IOException;
2521
import java.util.ArrayList;
26-
import java.util.Arrays;
2722
import java.util.Iterator;
2823
import java.util.LinkedHashMap;
2924
import java.util.List;
@@ -33,7 +28,7 @@
3328
/**
3429
* Writes a text embedding result in the following JSON format
3530
* {
36-
* "text_embedding": [
31+
* "text_embedding_bytes": [
3732
* {
3833
* "embedding": [
3934
* 23
@@ -111,78 +106,4 @@ public boolean equals(Object o) {
111106
public int hashCode() {
112107
return Objects.hash(embeddings);
113108
}
114-
115-
public record InferenceByteEmbedding(byte[] values) implements Writeable, ToXContentObject, EmbeddingInt {
116-
public static final String EMBEDDING = "embedding";
117-
118-
public InferenceByteEmbedding(StreamInput in) throws IOException {
119-
this(in.readByteArray());
120-
}
121-
122-
@Override
123-
public void writeTo(StreamOutput out) throws IOException {
124-
out.writeByteArray(values);
125-
}
126-
127-
public static InferenceByteEmbedding of(List<Byte> embeddingValuesList) {
128-
byte[] embeddingValues = new byte[embeddingValuesList.size()];
129-
for (int i = 0; i < embeddingValuesList.size(); i++) {
130-
embeddingValues[i] = embeddingValuesList.get(i);
131-
}
132-
return new InferenceByteEmbedding(embeddingValues);
133-
}
134-
135-
@Override
136-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
137-
builder.startObject();
138-
139-
builder.startArray(EMBEDDING);
140-
for (byte value : values) {
141-
builder.value(value);
142-
}
143-
builder.endArray();
144-
145-
builder.endObject();
146-
return builder;
147-
}
148-
149-
@Override
150-
public String toString() {
151-
return Strings.toString(this);
152-
}
153-
154-
private float[] toFloatArray() {
155-
float[] floatArray = new float[values.length];
156-
for (int i = 0; i < values.length; i++) {
157-
floatArray[i] = ((Byte) values[i]).floatValue();
158-
}
159-
return floatArray;
160-
}
161-
162-
private double[] toDoubleArray() {
163-
double[] doubleArray = new double[values.length];
164-
for (int i = 0; i < values.length; i++) {
165-
doubleArray[i] = ((Byte) values[i]).floatValue();
166-
}
167-
return doubleArray;
168-
}
169-
170-
@Override
171-
public int getSize() {
172-
return values().length;
173-
}
174-
175-
@Override
176-
public boolean equals(Object o) {
177-
if (this == o) return true;
178-
if (o == null || getClass() != o.getClass()) return false;
179-
InferenceByteEmbedding embedding = (InferenceByteEmbedding) o;
180-
return Arrays.equals(values, embedding.values);
181-
}
182-
183-
@Override
184-
public int hashCode() {
185-
return Arrays.hashCode(values);
186-
}
187-
}
188109
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingFloat;
2020
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbeddingSparse;
2121
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
22+
import org.elasticsearch.xpack.core.inference.results.InferenceByteEmbedding;
2223
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
2324
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
2425
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
@@ -69,7 +70,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El
6970

7071
private List<ChunkOffsetsAndInput> chunkedOffsets;
7172
private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
72-
private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
73+
private List<AtomicArray<List<InferenceByteEmbedding>>> byteResults;
7374
private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
7475
private AtomicArray<Exception> errors;
7576
private ActionListener<List<ChunkedInference>> finalListener;
@@ -389,9 +390,9 @@ private ChunkedInferenceEmbeddingFloat mergeFloatResultsWithInputs(
389390

390391
private ChunkedInferenceEmbeddingByte mergeByteResultsWithInputs(
391392
ChunkOffsetsAndInput chunks,
392-
AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults
393+
AtomicArray<List<InferenceByteEmbedding>> debatchedResults
393394
) {
394-
var all = new ArrayList<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>();
395+
var all = new ArrayList<InferenceByteEmbedding>();
395396
for (int i = 0; i < debatchedResults.length(); i++) {
396397
var subBatch = debatchedResults.get(i);
397398
all.addAll(subBatch);

0 commit comments

Comments
 (0)