diff --git a/docs/changelog/105393.yaml b/docs/changelog/105393.yaml new file mode 100644 index 0000000000000..4a4cc299b7bd7 --- /dev/null +++ b/docs/changelog/105393.yaml @@ -0,0 +1,5 @@ +pr: 105393 +summary: Adding support for hex-encoded byte vectors on knn-search +area: Vector Search +type: feature +issues: [] diff --git a/docs/reference/query-dsl/knn-query.asciidoc b/docs/reference/query-dsl/knn-query.asciidoc index e9aeea68c06f7..c11782f524950 100644 --- a/docs/reference/query-dsl/knn-query.asciidoc +++ b/docs/reference/query-dsl/knn-query.asciidoc @@ -87,8 +87,8 @@ the top `size` results. `query_vector`:: + -- -(Required, array of floats) Query vector. Must have the same number of dimensions -as the vector field you are searching against. +(Required, array of floats or string) Query vector. Must have the same number of dimensions +as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector. -- `num_candidates`:: diff --git a/docs/reference/rest-api/common-parms.asciidoc b/docs/reference/rest-api/common-parms.asciidoc index 6757b6be24207..062f832b6f79d 100644 --- a/docs/reference/rest-api/common-parms.asciidoc +++ b/docs/reference/rest-api/common-parms.asciidoc @@ -597,7 +597,7 @@ end::knn-num-candidates[] tag::knn-query-vector[] Query vector. Must have the same number of dimensions as the vector field you -are searching against. +are searching against. Must be either an array of floats or a hex-encoded byte vector. end::knn-query-vector[] tag::knn-similarity[] diff --git a/docs/reference/search/knn-search.asciidoc b/docs/reference/search/knn-search.asciidoc index 136b53388baf9..7947c688a807c 100644 --- a/docs/reference/search/knn-search.asciidoc +++ b/docs/reference/search/knn-search.asciidoc @@ -121,7 +121,7 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k] include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates] `query_vector`:: -(Required, array of floats) +(Required, array of floats or string) include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector] ==== diff --git a/docs/reference/search/search-your-data/knn-search.asciidoc b/docs/reference/search/search-your-data/knn-search.asciidoc index ab65b834c0ce7..030c10a91d005 100644 --- a/docs/reference/search/search-your-data/knn-search.asciidoc +++ b/docs/reference/search/search-your-data/knn-search.asciidoc @@ -242,6 +242,27 @@ POST byte-image-index/_search // TEST[s/"k": 10/"k": 3/] // TEST[s/"num_candidates": 100/"num_candidates": 3/] + +_Note_: In addition to the standard byte array, one can also provide a hex-encoded string value +for the `query_vector` param. As an example, the search request above can also be expressed as follows, +which would yield the same results +[source,console] +---- +POST byte-image-index/_search +{ + "knn": { + "field": "byte-image-vector", + "query_vector": "fb09", + "k": 10, + "num_candidates": 100 + }, + "fields": [ "title" ] +} +---- +// TEST[continued] +// TEST[s/"k": 10/"k": 3/] +// TEST[s/"num_candidates": 100/"num_candidates": 3/] + [discrete] [[knn-search-quantized-example]] ==== Byte quantized kNN search diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml new file mode 100644 index 0000000000000..71f65220eba1e --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml @@ -0,0 +1,163 @@ +setup: + - skip: + version: ' - 8.13.99' + reason: 'hex encoding for byte vectors was added in 8.14' + + - do: + indices.create: + index: knn_hex_vector_index + body: + settings: + number_of_shards: 1 + mappings: + dynamic: false + properties: + my_vector_byte: + type: dense_vector + dims: 3 + index : true + similarity : l2_norm + element_type: byte + my_vector_float: + type: dense_vector + dims: 3 + index: true + element_type: float + similarity : l2_norm + + # [-128, 127, 10] - is encoded as '807f0a' + - do: + index: + index: knn_hex_vector_index + id: "1" + body: + my_vector_byte: "807f0a" + + + # [0, 1, 0] - is encoded as '000100' + - do: + index: + index: knn_hex_vector_index + id: "2" + body: + my_vector_byte: "000100" + + # [64, -10, -30] - is encoded as '40f6e2' + - do: + index: + index: knn_hex_vector_index + id: "3" + body: + my_vector_byte: "40f6e2" + + - do: + index: + index: knn_hex_vector_index + id: "4" + body: + my_vector_float: [10.5, -10, 1024] + + - do: + indices.refresh: {} + +--- +"Fail to index hex-encoded vector on float field": + + # [-128, 127, 10] - is encoded as '807f0a' + - do: + catch: /Failed to parse object./ + index: + index: knn_hex_vector_index + id: "5" + body: + my_vector_float: "807f0a" + +--- +"Knn search with hex string for float field" : + # [64, 10, -30] - is encoded as '400ae2' + # this will be properly decoded but only because: + # (i) the provided input is compatible as the values are within [Byte.MIN_VALUE, Byte.MAX_VALUE] range + # (ii) we do not differentiate between byte and float fields when initially parsing a query even for hex + # (iii) we support expansion from byte to float + + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + knn: + field: my_vector_float + query_vector: "400ae2" + num_candidates: 100 + k: 10 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "4" } + +--- +"Knn search with hex string for byte field" : + # [64, 10, -30] - is encoded as '400ae2' + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + knn: + field: my_vector_byte + query_vector: "400ae2" + num_candidates: 100 + k: 10 + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.2._id: "1" } + +--- +"Knn search with hex string for byte field - dimensions mismatch" : + # [64, 10, -30, 10] - is encoded as '400ae20a' + - do: + catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/ + search: + index: knn_hex_vector_index + body: + size: 3 + knn: + field: my_vector_byte + query_vector: "400ae20a" + num_candidates: 100 + k: 10 + + +--- +"Knn search with hex string for byte field - cannot decode string" : + # '40af20a' is garbage :) + - do: + catch: /failed to parse field \[query_vector\]/ + search: + index: knn_hex_vector_index + body: + size: 3 + knn: + field: my_vector_byte + query_vector: "40af20a" + num_candidates: 100 + k: 10 + +--- +"Knn search with standard byte vector matching against hex-encoded indexed docs" : + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + knn: + field: my_vector_byte + query_vector: [64, 10, -30] + num_candidates: 100 + k: 10 + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.2._id: "1" } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml new file mode 100644 index 0000000000000..9f850400a09cd --- /dev/null +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml @@ -0,0 +1,162 @@ +setup: + - skip: + version: ' - 8.13.99' + reason: 'hex encoding for byte vectors was added in 8.14' + + - do: + indices.create: + index: knn_hex_vector_index + body: + settings: + number_of_shards: 1 + mappings: + dynamic: false + properties: + my_vector_byte: + type: dense_vector + dims: 3 + index : true + similarity : l2_norm + element_type: byte + my_vector_float: + type: dense_vector + dims: 3 + index: true + element_type: float + similarity : l2_norm + + # [-128, 127, 10] - is encoded as '807f0a' + - do: + index: + index: knn_hex_vector_index + id: "1" + body: + my_vector_byte: "807f0a" + + + # [0, 1, 0] - is encoded as '000100' + - do: + index: + index: knn_hex_vector_index + id: "2" + body: + my_vector_byte: "000100" + + # [64, -10, -30] - is encoded as '40f6e2' + - do: + index: + index: knn_hex_vector_index + id: "3" + body: + my_vector_byte: "40f6e2" + + - do: + index: + index: knn_hex_vector_index + id: "4" + body: + my_vector_float: [10.5, -10, 1024] + + - do: + indices.refresh: {} + +--- +"Fail to index hex-encoded vector on float field": + + # [-128, 127, 10] - is encoded as '807f0a' + - do: + catch: /Failed to parse object./ + index: + index: knn_hex_vector_index + id: "5" + body: + my_vector_float: "807f0a" + +--- +"Knn query with hex string for float field" : + # [64, 10, -30] - is encoded as '400ae2' + # this will be properly decoded but only because: + # (i) the provided input is compatible as the values are within [Byte.MIN_VALUE, Byte.MAX_VALUE] range + # (ii) we do not differentiate between byte and float fields when initially parsing a query even for hex + # (iii) we support expansion from byte to float + + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + query: + knn: + field: my_vector_float + query_vector: "400ae2" + num_candidates: 100 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "4" } + +--- +"Knn query with hex string for byte field" : + # [64, 10, -30] - is encoded as '400ae2' + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + query: + knn: + field: my_vector_byte + query_vector: "400ae2" + num_candidates: 100 + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.2._id: "1" } + +--- +"Knn query with hex string for byte field - dimensions mismatch" : + # [64, 10, -30, 10] - is encoded as '400ae20a' + - do: + catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/ + search: + index: knn_hex_vector_index + body: + size: 3 + query: + knn: + field: my_vector_byte + query_vector: "400ae20a" + num_candidates: 100 + +--- +"Knn query with hex string for byte field - cannot decode string" : + # '40af20a' is garbage :) + - do: + catch: /failed to parse field \[query_vector\]/ + search: + index: knn_hex_vector_index + body: + size: 3 + query: + knn: + field: my_vector_byte + query_vector: "40af20a" + num_candidates: 100 + +--- +"Knn query with standard byte vector matching against hex-encoded indexed docs" : + - do: + search: + index: knn_hex_vector_index + body: + size: 3 + query: + knn: + field: my_vector_byte + query_vector: [64, 10, -30] + num_candidates: 100 + + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "3" } + - match: { hits.hits.1._id: "2" } + - match: { hits.hits.2._id: "1" } diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 6ac2c24739805..a83b0ea0c90e5 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -143,6 +143,7 @@ static TransportVersion def(int id) { public static final TransportVersion ADD_DATA_STREAM_GLOBAL_RETENTION = def(8_603_00_0); public static final TransportVersion ALLOCATION_STATS = def(8_604_00_0); public static final TransportVersion ESQL_EXTENDED_ENRICH_TYPES = def(8_605_00_0); + public static final TransportVersion KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING = def(8_606_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java index 7281616a8d25f..c4952b8cae51d 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java @@ -693,6 +693,20 @@ public byte[] readOptionalByteArray() throws IOException { return null; } + /** + * Reads an optional float array. It's effectively the same as readFloatArray, except + * it supports null. + * @return a float array or null + * @throws IOException + */ + @Nullable + public float[] readOptionalFloatArray() throws IOException { + if (readBoolean()) { + return readFloatArray(); + } + return null; + } + /** * Same as {@link #readMap(Writeable.Reader, Writeable.Reader)} but always reading string keys. */ diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index 69a5135215eba..33fb000c1bca2 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -534,6 +534,19 @@ public void writeOptionalByteArray(@Nullable byte[] array) throws IOException { } } + /** + * Writes a float array, for null arrays it writes false. + * @param array an array or null + */ + public void writeOptionalFloatArray(@Nullable float[] array) throws IOException { + if (array == null) { + writeBoolean(false); + } else { + writeBoolean(true); + writeFloatArray(array); + } + } + public void writeGenericMap(@Nullable Map map) throws IOException { writeGenericValue(map); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 47efa0ca49771..22b8549e14969 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -43,6 +43,7 @@ import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; @@ -69,6 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery; import org.elasticsearch.search.vectors.ESKnnByteVectorQuery; import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery; +import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.search.vectors.VectorSimilarityQuery; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; @@ -80,6 +82,7 @@ import java.nio.ByteOrder; import java.time.ZoneId; import java.util.Arrays; +import java.util.HexFormat; import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -88,6 +91,7 @@ import java.util.function.Supplier; import java.util.stream.Stream; +import static org.elasticsearch.common.Strings.format; import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; /** @@ -338,11 +342,16 @@ void checkVectorMagnitude( } @Override - public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + public double computeDotProduct(VectorData vectorData) { + return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector()); + } + + private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { int index = 0; byte[] vector = new byte[fieldMapper.fieldType().dims]; float squaredMagnitude = 0; - for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { + for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser() + .nextToken()) { fieldMapper.checkDimensionExceeded(index, context); ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); final int value; @@ -383,44 +392,49 @@ public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFie } fieldMapper.checkDimensionMatches(index, context); checkVectorMagnitude(fieldMapper.fieldType().similarity, errorByteElementsAppender(vector), squaredMagnitude); + return VectorData.fromBytes(vector); + } + + private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + byte[] decodedVector = HexFormat.of().parseHex(context.parser().text()); + fieldMapper.checkDimensionMatches(decodedVector.length, context); + VectorData vectorData = VectorData.fromBytes(decodedVector); + double squaredMagnitude = computeDotProduct(vectorData); + checkVectorMagnitude( + fieldMapper.fieldType().similarity, + errorByteElementsAppender(decodedVector), + (float) squaredMagnitude + ); + return vectorData; + } + + @Override + VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + XContentParser.Token token = context.parser().currentToken(); + return switch (token) { + case START_ARRAY -> parseVectorArray(context, fieldMapper); + case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper); + default -> throw new ParsingException( + context.parser().getTokenLocation(), + format("Unsupported type [%s] for provided value [%s]", token, context.parser().text()) + ); + }; + } + + @Override + public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { + VectorData vectorData = parseKnnVector(context, fieldMapper); Field field = createKnnVectorField( fieldMapper.fieldType().name(), - vector, + vectorData.asByteVector(), fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this) ); context.doc().addWithKey(fieldMapper.fieldType().name(), field); } @Override - double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer) - throws IOException { - double dotProduct = 0f; - int index = 0; - for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { - fieldMapper.checkDimensionExceeded(index, context); - ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); - int value = context.parser().intValue(true); - if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { - throw new IllegalArgumentException( - "element_type [" - + this - + "] vectors only support integers between [" - + Byte.MIN_VALUE - + ", " - + Byte.MAX_VALUE - + "] but found [" - + value - + "] at dim [" - + index - + "];" - ); - } - byteBuffer.put((byte) value); - dotProduct += value * value; - index++; - } - fieldMapper.checkDimensionMatches(index, context); - return dotProduct; + int getNumBytes(int dimensions) { + return dimensions * elementBytes; } @Override @@ -530,6 +544,11 @@ void checkVectorMagnitude( } } + @Override + public double computeDotProduct(VectorData vectorData) { + return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector()); + } + @Override public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { int index = 0; @@ -566,23 +585,27 @@ && isNotUnitVector(squaredMagnitude)) { } @Override - double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer) - throws IOException { - double dotProduct = 0f; + VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException { int index = 0; + float squaredMagnitude = 0; float[] vector = new float[fieldMapper.fieldType().dims]; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { fieldMapper.checkDimensionExceeded(index, context); ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()); float value = context.parser().floatValue(true); vector[index] = value; - byteBuffer.putFloat(value); - dotProduct += value * value; + squaredMagnitude += value * value; index++; } fieldMapper.checkDimensionMatches(index, context); checkVectorBounds(vector); - return dotProduct; + checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude); + return VectorData.fromFloats(vector); + } + + @Override + int getNumBytes(int dimensions) { + return dimensions * elementBytes; } @Override @@ -607,8 +630,9 @@ ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes) { abstract void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException; - abstract double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer) - throws IOException; + abstract VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException; + + abstract int getNumBytes(int dimensions); abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes); @@ -699,6 +723,8 @@ static Function errorFloatElementsAppender(float[] static Function errorByteElementsAppender(byte[] vector) { return sb -> appendErrorElements(sb, vector); } + + public abstract double computeDotProduct(VectorData vectorData); } static final Map namesToElementType = Map.of( @@ -1158,66 +1184,120 @@ public Query createKnnQuery( return knnQuery; } - public Query createExactKnnQuery(float[] queryVector) { - queryVector = validateAndNormalize(queryVector); - VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType); + public Query createExactKnnQuery(VectorData queryVector) { + if (isIndexed() == false) { + throw new IllegalArgumentException( + "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" + ); + } return switch (elementType) { - case BYTE -> { - byte[] bytes = new byte[queryVector.length]; + case BYTE -> createExactKnnByteQuery(queryVector.asByteVector()); + case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector()); + }; + } + + private Query createExactKnnByteQuery(byte[] queryVector) { + if (queryVector.length != dims) { + throw new IllegalArgumentException( + "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" + ); + } + if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); + elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); + } + VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType); + return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) + .add( + new FunctionQuery( + new ByteVectorSimilarityFunction( + vectorSimilarityFunction, + new ByteKnnVectorFieldSource(name()), + new ConstKnnByteVectorValueSource(queryVector) + ) + ), + BooleanClause.Occur.SHOULD + ) + .build(); + } + + private Query createExactKnnFloatQuery(float[] queryVector) { + if (queryVector.length != dims) { + throw new IllegalArgumentException( + "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" + ); + } + elementType.checkVectorBounds(queryVector); + if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); + elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); + if (similarity == VectorSimilarity.COSINE + && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) + && isNotUnitVector(squaredMagnitude)) { + float length = (float) Math.sqrt(squaredMagnitude); + queryVector = Arrays.copyOf(queryVector, queryVector.length); for (int i = 0; i < queryVector.length; i++) { - bytes[i] = (byte) queryVector[i]; + queryVector[i] /= length; } - yield new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) - .add( - new FunctionQuery( - new ByteVectorSimilarityFunction( - vectorSimilarityFunction, - new ByteKnnVectorFieldSource(name()), - new ConstKnnByteVectorValueSource(bytes) - ) - ), - BooleanClause.Occur.SHOULD - ) - .build(); } - case FLOAT -> new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) - .add( - new FunctionQuery( - new FloatVectorSimilarityFunction( - vectorSimilarityFunction, - new FloatKnnVectorFieldSource(name()), - new ConstKnnFloatValueSource(queryVector) - ) - ), - BooleanClause.Occur.SHOULD - ) - .build(); - }; + } + VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType); + return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER) + .add( + new FunctionQuery( + new FloatVectorSimilarityFunction( + vectorSimilarityFunction, + new FloatKnnVectorFieldSource(name()), + new ConstKnnFloatValueSource(queryVector) + ) + ), + BooleanClause.Occur.SHOULD + ) + .build(); + } + + Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) { + return createKnnQuery(VectorData.fromFloats(queryVector), numCands, filter, similarityThreshold, parentFilter); } public Query createKnnQuery( - float[] queryVector, + VectorData queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter ) { - queryVector = validateAndNormalize(queryVector); - Query knnQuery = switch (elementType) { - case BYTE -> { - byte[] bytes = new byte[queryVector.length]; - for (int i = 0; i < queryVector.length; i++) { - bytes[i] = (byte) queryVector[i]; - } - yield parentFilter != null - ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), bytes, filter, numCands, parentFilter) - : new ESKnnByteVectorQuery(name(), bytes, numCands, filter); - } - case FLOAT -> parentFilter != null - ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter) - : new ESKnnFloatVectorQuery(name(), queryVector, numCands, filter); + if (isIndexed() == false) { + throw new IllegalArgumentException( + "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" + ); + } + return switch (getElementType()) { + case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter); + case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter); }; + } + + private Query createKnnByteQuery( + byte[] queryVector, + int numCands, + Query filter, + Float similarityThreshold, + BitSetProducer parentFilter + ) { + if (queryVector.length != dims) { + throw new IllegalArgumentException( + "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" + ); + } + if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) { + float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); + elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude); + } + Query knnQuery = parentFilter != null + ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter) + : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter); if (similarityThreshold != null) { knnQuery = new VectorSimilarityQuery( knnQuery, @@ -1228,12 +1308,13 @@ public Query createKnnQuery( return knnQuery; } - private float[] validateAndNormalize(float[] queryVector) { - if (isIndexed() == false) { - throw new IllegalArgumentException( - "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]" - ); - } + private Query createKnnFloatQuery( + float[] queryVector, + int numCands, + Query filter, + Float similarityThreshold, + BitSetProducer parentFilter + ) { if (queryVector.length != dims) { throw new IllegalArgumentException( "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]" @@ -1244,7 +1325,6 @@ private float[] validateAndNormalize(float[] queryVector) { float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector); elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude); if (similarity == VectorSimilarity.COSINE - && ElementType.FLOAT.equals(elementType) && indexVersionCreated.onOrAfter(NORMALIZE_COSINE) && isNotUnitVector(squaredMagnitude)) { float length = (float) Math.sqrt(squaredMagnitude); @@ -1254,7 +1334,17 @@ && isNotUnitVector(squaredMagnitude)) { } } } - return queryVector; + Query knnQuery = parentFilter != null + ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter) + : new ESKnnFloatVectorQuery(name(), queryVector, numCands, filter); + if (similarityThreshold != null) { + knnQuery = new VectorSimilarityQuery( + knnQuery, + similarityThreshold, + similarity.score(similarityThreshold, elementType, dims) + ); + } + return knnQuery; } VectorSimilarity getSimilarity() { @@ -1349,13 +1439,15 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t int dims = fieldType().dims; ElementType elementType = fieldType().elementType; int numBytes = indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION) - ? dims * elementType.elementBytes + MAGNITUDE_BYTES - : dims * elementType.elementBytes; + ? elementType.getNumBytes(dims) + MAGNITUDE_BYTES + : elementType.getNumBytes(dims); ByteBuffer byteBuffer = elementType.createByteBuffer(indexCreatedVersion, numBytes); - double dotProduct = elementType.parseKnnVectorToByteBuffer(context, this, byteBuffer); + VectorData vectorData = elementType.parseKnnVector(context, this); + vectorData.addToBuffer(byteBuffer); if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) { // encode vector magnitude at the end + double dotProduct = elementType.computeDotProduct(vectorData); float vectorMagnitude = (float) Math.sqrt(dotProduct); byteBuffer.putFloat(vectorMagnitude); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index fc2d4218ea1ec..3c4355e56d21d 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -13,6 +13,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.VectorData; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; @@ -121,7 +122,14 @@ public String getName() { @Override public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(field, queryVector, queryVectorBuilder, k, numCands, similarity); + KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder( + field, + VectorData.fromFloats(queryVector), + queryVectorBuilder, + k, + numCands, + similarity + ); if (preFilterQueryBuilders != null) { knnSearchBuilder.addFilterQueries(preFilterQueryBuilders); } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java index d292f61dcb085..60b0d259961da 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java @@ -22,7 +22,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Arrays; import java.util.Objects; /** @@ -32,7 +31,7 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "exact_knn"; private final String field; - private final float[] query; + private final VectorData query; /** * Creates a query builder. @@ -41,13 +40,27 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>("knn", args -> { // TODO optimize parsing for when BYTE values are provided - List vector = (List) args[1]; - final float[] vectorArray; - if (vector != null) { - vectorArray = new float[vector.size()]; - for (int i = 0; i < vector.size(); i++) { - vectorArray[i] = vector.get(i); - } - } else { - vectorArray = null; - } return new Builder().field((String) args[0]) - .queryVector(vectorArray) + .queryVector((VectorData) args[1]) .queryVectorBuilder((QueryVectorBuilder) args[4]) .k((Integer) args[2]) .numCandidates((Integer) args[3]) @@ -79,9 +68,15 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea static { PARSER.declareString(constructorArg(), FIELD_FIELD); - PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> VectorData.parseXContent(p), + QUERY_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER + ); PARSER.declareInt(optionalConstructorArg(), K_FIELD); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); + PARSER.declareNamedObject( optionalConstructorArg(), (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c), @@ -108,7 +103,7 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw } final String field; - final float[] queryVector; + final VectorData queryVector; final QueryVectorBuilder queryVectorBuilder; private final Supplier querySupplier; final int k; @@ -127,7 +122,26 @@ public static KnnSearchBuilder.Builder fromXContent(XContentParser parser) throw * @param numCands the number of nearest neighbor candidates to consider per shard */ public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) { - this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, similarity); + this( + field, + Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)), + null, + k, + numCands, + similarity + ); + } + + /** + * Defines a kNN search. + * + * @param field the name of the vector field to search against + * @param queryVector the query vector + * @param k the final number of nearest neighbors to return as top hits + * @param numCands the number of nearest neighbor candidates to consider per shard + */ + public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) { + this(field, queryVector, null, k, numCands, similarity); } /** @@ -151,7 +165,7 @@ public KnnSearchBuilder(String field, QueryVectorBuilder queryVectorBuilder, int public KnnSearchBuilder( String field, - float[] queryVector, + VectorData queryVector, QueryVectorBuilder queryVectorBuilder, int k, int numCands, @@ -169,7 +183,7 @@ private KnnSearchBuilder( Float similarity ) { this.field = field; - this.queryVector = new float[0]; + this.queryVector = VectorData.fromFloats(new float[0]); this.queryVectorBuilder = null; this.k = k; this.numCands = numCands; @@ -181,7 +195,7 @@ private KnnSearchBuilder( private KnnSearchBuilder( String field, QueryVectorBuilder queryVectorBuilder, - float[] queryVector, + VectorData queryVector, List filterQueries, int k, int numCandidates, @@ -219,7 +233,7 @@ private KnnSearchBuilder( ); } this.field = field; - this.queryVector = queryVector == null ? new float[0] : queryVector; + this.queryVector = queryVector == null ? VectorData.fromFloats(new float[0]) : queryVector; this.queryVectorBuilder = queryVectorBuilder; this.k = k; this.numCands = numCandidates; @@ -234,7 +248,11 @@ public KnnSearchBuilder(StreamInput in) throws IOException { this.field = in.readString(); this.k = in.readVInt(); this.numCands = in.readVInt(); - this.queryVector = in.readFloatArray(); + if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) { + this.queryVector = in.readOptionalWriteable(VectorData::new); + } else { + this.queryVector = VectorData.fromFloats(in.readFloatArray()); + } this.filterQueries = in.readNamedWriteableCollectionAsList(QueryBuilder.class); this.boost = in.readFloat(); if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) { @@ -262,7 +280,7 @@ public QueryVectorBuilder getQueryVectorBuilder() { } // for testing only - public float[] getQueryVector() { + public VectorData getQueryVector() { return queryVector; } @@ -365,7 +383,7 @@ public boolean equals(Object o) { return k == that.k && numCands == that.numCands && Objects.equals(field, that.field) - && Arrays.equals(queryVector, that.queryVector) + && Objects.equals(queryVector, that.queryVector) && Objects.equals(queryVectorBuilder, that.queryVectorBuilder) && Objects.equals(querySupplier, that.querySupplier) && Objects.equals(filterQueries, that.filterQueries) @@ -383,7 +401,7 @@ public int hashCode() { querySupplier, queryVectorBuilder, similarity, - Arrays.hashCode(queryVector), + Objects.hashCode(queryVector), Objects.hashCode(filterQueries), innerHitBuilder, boost @@ -401,7 +419,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder); builder.endObject(); } else { - builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); + builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector); } if (similarity != null) { builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity); @@ -434,7 +452,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(field); out.writeVInt(k); out.writeVInt(numCands); - out.writeFloatArray(queryVector); + if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) { + out.writeOptionalWriteable(queryVector); + } else { + out.writeFloatArray(queryVector.asFloatVector()); + } out.writeNamedWriteableCollection(filterQueries); out.writeFloat(boost); if (out.getTransportVersion().before(TransportVersions.V_8_7_0) && queryVectorBuilder != null) { @@ -460,7 +482,7 @@ public void writeTo(StreamOutput out) throws IOException { public static class Builder { private String field; - private float[] queryVector; + private VectorData queryVector; private QueryVectorBuilder queryVectorBuilder; private Integer k; private Integer numCandidates; @@ -490,7 +512,7 @@ public Builder innerHit(InnerHitBuilder innerHitBuilder) { return this; } - public Builder queryVector(float[] queryVector) { + public Builder queryVector(VectorData queryVector) { this.queryVector = queryVector; return this; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 7e65cd19638ce..149dedd59df46 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -37,7 +37,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -62,23 +61,19 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder PARSER = new ConstructingObjectParser<>("knn", args -> { - List vector = (List) args[1]; - final float[] vectorArray; - if (vector != null) { - vectorArray = new float[vector.size()]; - for (int i = 0; i < vector.size(); i++) { - vectorArray[i] = vector.get(i); - } - } else { - vectorArray = null; - } - return new KnnVectorQueryBuilder((String) args[0], vectorArray, (Integer) args[2], (Float) args[3]); - }); + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "knn", + args -> new KnnVectorQueryBuilder((String) args[0], (VectorData) args[1], (Integer) args[2], (Float) args[3]) + ); static { PARSER.declareString(constructorArg(), FIELD_FIELD); - PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> VectorData.parseXContent(p), + QUERY_VECTOR_FIELD, + ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER + ); PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD); PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD); PARSER.declareFieldArray( @@ -95,12 +90,20 @@ public static KnnVectorQueryBuilder fromXContent(XContentParser parser) { } private final String fieldName; - private final float[] queryVector; + private final VectorData queryVector; private Integer numCands; private final List filterQueries = new ArrayList<>(); private final Float vectorSimilarity; public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) { + this(fieldName, VectorData.fromFloats(queryVector), numCands, vectorSimilarity); + } + + public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) { + this(fieldName, VectorData.fromBytes(queryVector), numCands, vectorSimilarity); + } + + public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) { if (numCands != null && numCands > NUM_CANDS_LIMIT) { throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"); } @@ -121,12 +124,17 @@ public KnnVectorQueryBuilder(StreamInput in) throws IOException { } else { this.numCands = in.readVInt(); } - if (in.getTransportVersion().before(TransportVersions.V_8_7_0) || in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { - this.queryVector = in.readFloatArray(); + if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) { + this.queryVector = in.readOptionalWriteable(VectorData::new); } else { - in.readBoolean(); - this.queryVector = in.readFloatArray(); - in.readBoolean(); // used for byteQueryVector, which was always null + if (in.getTransportVersion().before(TransportVersions.V_8_7_0) + || in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { + this.queryVector = VectorData.fromFloats(in.readFloatArray()); + } else { + in.readBoolean(); + this.queryVector = VectorData.fromFloats(in.readFloatArray()); + in.readBoolean(); // used for byteQueryVector, which was always null + } } if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) { this.filterQueries.addAll(readQueries(in)); @@ -143,7 +151,7 @@ public String getFieldName() { } @Nullable - public float[] queryVector() { + public VectorData queryVector() { return queryVector; } @@ -190,13 +198,17 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(numCands); } } - if (out.getTransportVersion().before(TransportVersions.V_8_7_0) - || out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { - out.writeFloatArray(queryVector); + if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) { + out.writeOptionalWriteable(queryVector); } else { - out.writeBoolean(true); - out.writeFloatArray(queryVector); - out.writeBoolean(false); // used for byteQueryVector, which was always null + if (out.getTransportVersion().before(TransportVersions.V_8_7_0) + || out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { + out.writeFloatArray(queryVector.asFloatVector()); + } else { + out.writeBoolean(true); + out.writeFloatArray(queryVector.asFloatVector()); + out.writeBoolean(false); // used for byteQueryVector, which was always null + } } if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) { writeQueries(out, filterQueries); @@ -326,13 +338,13 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { @Override protected int doHashCode() { - return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands, filterQueries, vectorSimilarity); + return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity); } @Override protected boolean doEquals(KnnVectorQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) - && Arrays.equals(queryVector, other.queryVector) + && Objects.equals(queryVector, other.queryVector) && Objects.equals(numCands, other.numCands) && Objects.equals(filterQueries, other.filterQueries) && Objects.equals(vectorSimilarity, other.vectorSimilarity); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java b/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java new file mode 100644 index 0000000000000..a92644af1fcf5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/vectors/VectorData.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HexFormat; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.Strings.format; + +public record VectorData(float[] floatVector, byte[] byteVector) implements Writeable, ToXContentFragment { + + private VectorData(float[] floatVector) { + this(floatVector, null); + } + + private VectorData(byte[] byteVector) { + this(null, byteVector); + } + + public VectorData(StreamInput in) throws IOException { + this(in.readOptionalFloatArray(), in.readOptionalByteArray()); + } + + public VectorData { + if (false == (floatVector == null ^ byteVector == null)) { + throw new IllegalArgumentException("please supply exactly either a float or a byte vector"); + } + } + + public byte[] asByteVector() { + if (byteVector != null) { + return byteVector; + } + DenseVectorFieldMapper.ElementType.BYTE.checkVectorBounds(floatVector); + byte[] vec = new byte[floatVector.length]; + for (int i = 0; i < floatVector.length; i++) { + vec[i] = (byte) floatVector[i]; + } + return vec; + } + + public float[] asFloatVector() { + if (floatVector != null) { + return floatVector; + } + float[] vec = new float[byteVector.length]; + for (int i = 0; i < byteVector.length; i++) { + vec[i] = byteVector[i]; + } + return vec; + } + + public void addToBuffer(ByteBuffer byteBuffer) { + if (floatVector != null) { + for (float val : floatVector) { + byteBuffer.putFloat(val); + } + } else { + byteBuffer.put(byteVector); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalFloatArray(floatVector); + out.writeOptionalByteArray(byteVector); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (floatVector != null) { + builder.startArray(); + for (float v : floatVector) { + builder.value(v); + } + builder.endArray(); + } else { + builder.value(HexFormat.of().formatHex(byteVector)); + } + return builder; + } + + @Override + public String toString() { + return floatVector != null ? Arrays.toString(floatVector) : Arrays.toString(byteVector); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + VectorData other = (VectorData) obj; + return Arrays.equals(floatVector, other.floatVector) && Arrays.equals(byteVector, other.byteVector); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(floatVector), Arrays.hashCode(byteVector)); + } + + public static VectorData parseXContent(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + return switch (token) { + case START_ARRAY -> parseQueryVectorArray(parser); + case VALUE_STRING -> parseHexEncodedVector(parser); + case VALUE_NUMBER -> parseNumberVector(parser); + default -> throw new ParsingException(parser.getTokenLocation(), format("Unknown type [%s] for parsing vector", token)); + }; + } + + private static VectorData parseQueryVectorArray(XContentParser parser) throws IOException { + XContentParser.Token token; + List vectorArr = new ArrayList<>(); + while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) { + if (token == XContentParser.Token.VALUE_NUMBER || token == XContentParser.Token.VALUE_STRING) { + vectorArr.add(parser.floatValue()); + } else { + throw new ParsingException(parser.getTokenLocation(), format("Type [%s] not supported for query vector", token)); + } + } + float[] floatVector = new float[vectorArr.size()]; + for (int i = 0; i < vectorArr.size(); i++) { + floatVector[i] = vectorArr.get(i); + } + return VectorData.fromFloats(floatVector); + } + + private static VectorData parseHexEncodedVector(XContentParser parser) throws IOException { + return VectorData.fromBytes(HexFormat.of().parseHex(parser.text())); + } + + private static VectorData parseNumberVector(XContentParser parser) throws IOException { + return VectorData.fromFloats(new float[] { parser.floatValue() }); + } + + public static VectorData fromFloats(float[] vec) { + return vec == null ? null : new VectorData(vec); + } + + public static VectorData fromBytes(byte[] vec) { + return vec == null ? null : new VectorData(vec); + } + +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java index c3d2d6a3f194b..27adc72fb5ed8 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; +import org.elasticsearch.search.vectors.VectorData; import java.io.IOException; import java.util.Collections; @@ -179,7 +180,7 @@ public void testExactKnnQuery() { for (int i = 0; i < dims; i++) { queryVector[i] = randomFloat(); } - Query query = field.createExactKnnQuery(queryVector); + Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector)); assertTrue(query instanceof BooleanQuery); BooleanQuery booleanQuery = (BooleanQuery) query; boolean foundFunction = false; @@ -202,12 +203,10 @@ public void testExactKnnQuery() { Collections.emptyMap() ); byte[] queryVector = new byte[dims]; - float[] floatQueryVector = new float[dims]; for (int i = 0; i < dims; i++) { queryVector[i] = randomByte(); - floatQueryVector[i] = queryVector[i]; } - Query query = field.createExactKnnQuery(floatQueryVector); + Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector)); assertTrue(query instanceof BooleanQuery); BooleanQuery booleanQuery = (BooleanQuery) query; boolean foundFunction = false; diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index ad9c95b5b80c7..45ad9d514ba82 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -43,10 +43,12 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase { private static final String VECTOR_FIELD = "vector"; private static final String VECTOR_ALIAS_FIELD = "vector_alias"; - private static final int VECTOR_DIMENSION = 3; + static final int VECTOR_DIMENSION = 3; abstract DenseVectorFieldMapper.ElementType elementType(); + abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity); + @Override protected void initializeAdditionalMappings(MapperService mapperService) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -75,12 +77,9 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws @Override protected KnnVectorQueryBuilder doCreateTestQueryBuilder() { String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD; - float[] vector = new float[VECTOR_DIMENSION]; - for (int i = 0; i < vector.length; i++) { - vector[i] = elementType().equals(DenseVectorFieldMapper.ElementType.BYTE) ? randomByte() : randomFloat(); - } int numCands = randomIntBetween(DEFAULT_SIZE, 1000); - KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(fieldName, vector, numCands, randomBoolean() ? null : randomFloat()); + KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, numCands, randomBoolean() ? null : randomFloat()); + if (randomBoolean()) { List filters = new ArrayList<>(); int numFilters = randomIntBetween(1, 5); @@ -120,11 +119,16 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que Query knnVectorQueryBuilt = switch (elementType()) { case BYTE -> new ESKnnByteVectorQuery( VECTOR_FIELD, - getByteQueryVector(queryBuilder.queryVector()), + queryBuilder.queryVector().asByteVector(), + queryBuilder.numCands(), + filterQuery + ); + case FLOAT -> new ESKnnFloatVectorQuery( + VECTOR_FIELD, + queryBuilder.queryVector().asFloatVector(), queryBuilder.numCands(), filterQuery ); - case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector(), queryBuilder.numCands(), filterQuery); }; if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) { query = vectorSimilarityQuery.getInnerKnnQuery(); @@ -193,7 +197,8 @@ public void testMustRewrite() throws IOException { public void testBWCVersionSerializationFilters() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); - KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), query.queryVector(), query.numCands(), null) + VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); + KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null) .queryName(query.queryName()) .boost(query.boost()); TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween( @@ -206,12 +211,11 @@ public void testBWCVersionSerializationFilters() throws IOException { public void testBWCVersionSerializationSimilarity() throws IOException { KnnVectorQueryBuilder query = createTestQueryBuilder(); - KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder( - query.getFieldName(), - query.queryVector(), - query.numCands(), - null - ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); + KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null) + .queryName(query.queryName()) + .boost(query.boost()) + .addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0); } @@ -223,12 +227,11 @@ public void testBWCVersionSerializationQuery() throws IOException { TransportVersions.V_8_12_0 ); Float similarity = differentQueryVersion.before(TransportVersions.V_8_8_0) ? null : query.getVectorSimilarity(); - KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder( - query.getFieldName(), - query.queryVector(), - query.numCands(), - similarity - ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries()); + VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector()); + KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), similarity) + .queryName(query.queryName()) + .boost(query.boost()) + .addFilterQueries(query.filterQueries()); assertBWCSerialization(query, queryOlderVersion, differentQueryVersion); } @@ -245,12 +248,4 @@ private void assertBWCSerialization(QueryBuilder newQuery, QueryBuilder bwcQuery } } } - - private static byte[] getByteQueryVector(float[] queryVector) { - byte[] byteQueryVector = new byte[queryVector.length]; - for (int i = 0; i < queryVector.length; i++) { - byteQueryVector[i] = (byte) queryVector[i]; - } - return byteQueryVector; - } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java index 3bf92a60275d8..6c83700d0b29a 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java @@ -15,4 +15,13 @@ public class KnnByteVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilde DenseVectorFieldMapper.ElementType elementType() { return DenseVectorFieldMapper.ElementType.BYTE; } + + @Override + protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) { + byte[] vector = new byte[VECTOR_DIMENSION]; + for (int i = 0; i < vector.length; i++) { + vector[i] = randomByte(); + } + return new KnnVectorQueryBuilder(fieldName, vector, numCands, similarity); + } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java similarity index 56% rename from server/src/test/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilderTests.java rename to server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java index 6b1166cdd16dc..eeb5244d57943 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java @@ -10,9 +10,18 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -public class KnnVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilderTestCase { +public class KnnFloatVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilderTestCase { @Override DenseVectorFieldMapper.ElementType elementType() { return DenseVectorFieldMapper.ElementType.FLOAT; } + + @Override + KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) { + float[] vector = new float[VECTOR_DIMENSION]; + for (int i = 0; i < vector.length; i++) { + vector[i] = randomFloat(); + } + return new KnnVectorQueryBuilder(fieldName, vector, numCands, similarity); + } } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java index c650f54321060..564c8b9d0db11 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java @@ -106,7 +106,7 @@ protected KnnSearchBuilder mutateInstance(KnnSearchBuilder instance) { instance.boost ); case 1: - float[] newVector = randomValueOtherThan(instance.queryVector, () -> randomVector(5)); + float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5)); return new KnnSearchBuilder(instance.field, newVector, instance.k, instance.numCands, instance.similarity).boost( instance.boost ); @@ -213,7 +213,7 @@ public void testRewrite() throws Exception { assertThat(rewritten.field, equalTo(searchBuilder.field)); assertThat(rewritten.boost, equalTo(searchBuilder.boost)); - assertThat(rewritten.queryVector, equalTo(expectedArray)); + assertThat(rewritten.queryVector.asFloatVector(), equalTo(expectedArray)); assertThat(rewritten.queryVectorBuilder, nullValue()); assertThat(rewritten.filterQueries, hasSize(1)); assertThat(rewritten.similarity, equalTo(1f)); diff --git a/server/src/test/java/org/elasticsearch/search/vectors/VectorDataTests.java b/server/src/test/java/org/elasticsearch/search/vectors/VectorDataTests.java new file mode 100644 index 0000000000000..feabc100d1007 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/vectors/VectorDataTests.java @@ -0,0 +1,199 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.vectors; + +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; + +import static org.hamcrest.Matchers.containsString; + +public class VectorDataTests extends ESTestCase { + + private static final float DELTA = 1e-5f; + + public void testThrowsIfBothVectorsAreNull() { + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> new VectorData(null, null)); + assertThat(ex.getMessage(), containsString("please supply exactly either a float or a byte vector")); + } + + public void testThrowsIfBothVectorsAreNonNull() { + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> new VectorData(new float[] { 0f }, new byte[] { 1 }) + ); + assertThat(ex.getMessage(), containsString("please supply exactly either a float or a byte vector")); + } + + public void testShouldCorrectlyConvertByteToFloatIfExplicitlyRequested() { + byte[] byteVector = new byte[] { 1, 2, -127 }; + float[] expected = new float[] { 1f, 2f, -127f }; + + VectorData vectorData = new VectorData(null, byteVector); + float[] actual = vectorData.asFloatVector(); + assertArrayEquals(expected, actual, DELTA); + } + + public void testShouldThrowForDecimalsWhenConvertingToByte() { + float[] vec = new float[] { 1f, 2f, 3.1f }; + + VectorData vectorData = new VectorData(vec, null); + expectThrows(IllegalArgumentException.class, vectorData::asByteVector); + } + + public void testShouldThrowForOutsideRangeWhenConvertingToByte() { + float[] vec = new float[] { 1f, 2f, 200f }; + + VectorData vectorData = new VectorData(vec, null); + expectThrows(IllegalArgumentException.class, vectorData::asByteVector); + } + + public void testEqualsAndHashCode() { + VectorData v1 = new VectorData(new float[] { 1, 2, 3 }, null); + VectorData v2 = new VectorData(null, new byte[] { 1, 2, 3 }); + assertNotEquals(v1, v2); + assertNotEquals(v1.hashCode(), v2.hashCode()); + + VectorData v3 = new VectorData(null, new byte[] { 1, 2, 3 }); + assertEquals(v2, v3); + assertEquals(v2.hashCode(), v3.hashCode()); + } + + public void testParseHexCorrectly() throws IOException { + byte[] expected = new byte[] { 64, 10, -30, 10 }; + String toParse = "\"400ae20a\""; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + assertArrayEquals(expected, parsed.asByteVector()); + } + } + + public void testParseFloatArray() throws IOException { + float[] expected = new float[] { 1f, -1f, .1f }; + String toParse = "[1.0, -1.0, 0.1]"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + assertArrayEquals(expected, parsed.asFloatVector(), DELTA); + } + } + + public void testParseByteArray() throws IOException { + byte[] expected = new byte[] { 64, 10, -30, 10 }; + String toParse = "[64,10,-30,10]"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + assertArrayEquals(expected, parsed.asByteVector()); + } + } + + public void testByteThrowsForOutsideRange() throws IOException { + String toParse = "[1000]"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, parsed::asByteVector); + assertThat(ex.getMessage(), containsString("vectors only support integers between [-128, 127]")); + } + } + + public void testAsByteThrowsForDecimals() throws IOException { + String toParse = "[0.1]"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, parsed::asByteVector); + assertThat(ex.getMessage(), containsString("vectors only support non-decimal values but found decimal value")); + } + } + + public void testParseSingleNumber() throws IOException { + float[] expected = new float[] { 0.1f }; + String toParse = "0.1"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + VectorData parsed = VectorData.parseXContent(parser); + assertArrayEquals(expected, parsed.asFloatVector(), DELTA); + } + } + + public void testParseThrowsForUnknown() throws IOException { + String unknown = "{\"foo\":\"bar\"}"; + try ( + XContentParser parser = XContentHelper.createParser( + XContentParserConfiguration.EMPTY, + new BytesArray(unknown), + XContentType.JSON + ) + ) { + parser.nextToken(); + ParsingException ex = expectThrows(ParsingException.class, () -> VectorData.parseXContent(parser)); + assertThat(ex.getMessage(), containsString("Unknown type [" + XContentParser.Token.START_OBJECT + "] for parsing vector")); + } + } + + public void testFailForUnknownArrayValue() throws IOException { + String toParse = "[0.1, true]"; + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + XContentParserConfiguration.EMPTY, + new BytesArray(toParse), + XContentType.JSON + ) + ) { + parser.nextToken(); + ParsingException ex = expectThrows(ParsingException.class, () -> VectorData.parseXContent(parser)); + assertThat(ex.getMessage(), containsString("Type [" + XContentParser.Token.VALUE_BOOLEAN + "] not supported for query vector")); + } + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java index b6e5c7161edc8..b327aee0931f9 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java @@ -132,7 +132,7 @@ public final void testKnnSearchRewrite() throws Exception { PlainActionFuture future = new PlainActionFuture<>(); Rewriteable.rewriteAndFetch(randomFrom(serialized, searchBuilder), context, future); KnnSearchBuilder rewritten = future.get(); - assertThat(rewritten.getQueryVector(), equalTo(expected)); + assertThat(rewritten.getQueryVector().asFloatVector(), equalTo(expected)); assertThat(rewritten.getQueryVectorBuilder(), nullValue()); } }