Skip to content

Commit 3833517

Browse files
committed
addressing PR comments - removing duplicated code and opting for switch statement
1 parent 0e4d7e6 commit 3833517

File tree

2 files changed

+16
-60
lines changed

2 files changed

+16
-60
lines changed

server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import org.apache.lucene.util.SetOnce;
1212
import org.elasticsearch.TransportVersions;
13-
import org.elasticsearch.common.ParsingException;
1413
import org.elasticsearch.common.io.stream.StreamInput;
1514
import org.elasticsearch.common.io.stream.StreamOutput;
1615
import org.elasticsearch.common.io.stream.Writeable;
@@ -29,7 +28,6 @@
2928
import java.io.IOException;
3029
import java.util.ArrayList;
3130
import java.util.Arrays;
32-
import java.util.HexFormat;
3331
import java.util.List;
3432
import java.util.Objects;
3533
import java.util.function.Supplier;
@@ -38,6 +36,7 @@
3836
import static org.elasticsearch.common.Strings.format;
3937
import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST;
4038
import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
39+
import static org.elasticsearch.search.vectors.KnnVectorQueryBuilder.parseQueryVector;
4140
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
4241
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
4342

@@ -68,52 +67,6 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
6867
.similarity((Float) args[5]);
6968
});
7069

71-
private static float[] parseQueryVector(XContentParser parser) throws IOException {
72-
XContentParser.Token token = parser.currentToken();
73-
final float[] vector;
74-
if (token == XContentParser.Token.START_ARRAY) {
75-
vector = parseQueryVectorArray(parser);
76-
} else if (token == XContentParser.Token.VALUE_STRING) {
77-
vector = parseHexEncodedVector(parser);
78-
} else if (token == XContentParser.Token.VALUE_NUMBER) {
79-
vector = parseNumberVector(parser);
80-
} else {
81-
throw new ParsingException(parser.getTokenLocation(), format("Unknown type for provided value [%s]", parser.text()));
82-
}
83-
return vector;
84-
}
85-
86-
private static float[] parseQueryVectorArray(XContentParser parser) throws IOException {
87-
XContentParser.Token token;
88-
List<Float> vectorArr = new ArrayList<>();
89-
while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
90-
if (token == XContentParser.Token.VALUE_NUMBER || token == XContentParser.Token.VALUE_STRING) {
91-
vectorArr.add(parser.floatValue());
92-
} else {
93-
throw new ParsingException(parser.getTokenLocation(), format("Type [%s] not supported for query vector"));
94-
}
95-
}
96-
float[] floatVector = new float[vectorArr.size()];
97-
for (int i = 0; i < vectorArr.size(); i++) {
98-
floatVector[i] = vectorArr.get(i);
99-
}
100-
return floatVector;
101-
}
102-
103-
private static float[] parseNumberVector(XContentParser parser) throws IOException {
104-
return new float[] { parser.floatValue() };
105-
}
106-
107-
private static float[] parseHexEncodedVector(XContentParser parser) throws IOException {
108-
// TODO optimize this as the array returned will be recomputed later again as a byte array
109-
byte[] decodedByteQueryVector = HexFormat.of().parseHex(parser.text());
110-
float[] floatVector = new float[decodedByteQueryVector.length];
111-
for (int i = 0; i < decodedByteQueryVector.length; i++) {
112-
floatVector[i] = decodedByteQueryVector[i];
113-
}
114-
return floatVector;
115-
}
116-
11770
static {
11871
PARSER.declareString(constructorArg(), FIELD_FIELD);
11972
PARSER.declareField(

server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,22 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
6969
args -> new KnnVectorQueryBuilder((String) args[0], (float[]) args[1], (Integer) args[2], (Float) args[3])
7070
);
7171

72-
private static float[] parseQueryVector(XContentParser parser) throws IOException {
72+
/**
73+
* Utility method to parse the provided {query_vector} parameter. Supports the following formats:
74+
* - array of floats, as an n-dimensional vector
75+
* - single number, as a 1-dimensional vector
76+
* - string, as a hex-encoded byte vector
77+
*
78+
* @return an array of floats representing the provided query vector
79+
*/
80+
public static float[] parseQueryVector(XContentParser parser) throws IOException {
7381
XContentParser.Token token = parser.currentToken();
74-
final float[] vector;
75-
if (token == XContentParser.Token.START_ARRAY) {
76-
vector = parseQueryVectorArray(parser);
77-
} else if (token == XContentParser.Token.VALUE_STRING) {
78-
vector = parseHexEncodedVector(parser);
79-
} else if (token == XContentParser.Token.VALUE_NUMBER) {
80-
vector = parseNumberVector(parser);
81-
} else {
82-
throw new ParsingException(parser.getTokenLocation(), format("Unknown type for provided value [%s]", parser.text()));
83-
}
84-
return vector;
82+
return switch (token) {
83+
case START_ARRAY -> parseQueryVectorArray(parser);
84+
case VALUE_STRING -> parseHexEncodedVector(parser);
85+
case VALUE_NUMBER -> parseNumberVector(parser);
86+
default -> throw new ParsingException(parser.getTokenLocation(), format("Unknown type for provided value [%s]", parser.text()));
87+
};
8588
}
8689

8790
private static float[] parseQueryVectorArray(XContentParser parser) throws IOException {

0 commit comments

Comments
 (0)