Skip to content

Commit 523728d

Browse files
committed
ClassificationAnalyzer
1 parent e4c8b5c commit 523728d

File tree

4 files changed

+180
-1
lines changed

4 files changed

+180
-1
lines changed

src/main/java/com/arangodb/entity/arangosearch/AnalyzerType.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,18 @@
2424
* @author Michele Rastelli
2525
*/
2626
public enum AnalyzerType {
27-
identity, delimiter, stem, norm, ngram, text, pipeline, stopwords, aql, geojson, geopoint, segmentation, collation
27+
identity,
28+
delimiter,
29+
stem,
30+
norm,
31+
ngram,
32+
text,
33+
pipeline,
34+
stopwords,
35+
aql,
36+
geojson,
37+
geopoint,
38+
segmentation,
39+
collation,
40+
classification
2841
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* DISCLAIMER
3+
*
4+
* Copyright 2016 ArangoDB GmbH, Cologne, Germany
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
* Copyright holder is ArangoDB GmbH, Cologne, Germany
19+
*/
20+
21+
package com.arangodb.entity.arangosearch.analyzer;
22+
23+
24+
import com.arangodb.entity.arangosearch.AnalyzerType;
25+
26+
import java.util.Objects;
27+
28+
/**
29+
* An Analyzer capable of classifying tokens in the input text. It applies a user-provided supervised fastText word
30+
* embedding model to classify the input text. It is able to classify individual tokens as well as entire inputs.
31+
*
32+
* @author Michele Rastelli
33+
* @see <a href= "https://www.arangodb.com/docs/stable/analyzers.html#classification">API Documentation</a>
34+
* @since ArangoDB 3.10
35+
*/
36+
public class ClassificationAnalyzer extends SearchAnalyzer {
37+
public ClassificationAnalyzer() {
38+
setType(AnalyzerType.classification);
39+
}
40+
41+
private ClassificationAnalyzerProperties properties;
42+
43+
public ClassificationAnalyzerProperties getProperties() {
44+
return properties;
45+
}
46+
47+
public void setProperties(ClassificationAnalyzerProperties properties) {
48+
this.properties = properties;
49+
}
50+
51+
@Override
52+
public boolean equals(Object o) {
53+
if (this == o) return true;
54+
if (o == null || getClass() != o.getClass()) return false;
55+
if (!super.equals(o)) return false;
56+
ClassificationAnalyzer that = (ClassificationAnalyzer) o;
57+
return Objects.equals(properties, that.properties);
58+
}
59+
60+
@Override
61+
public int hashCode() {
62+
return Objects.hash(super.hashCode(), properties);
63+
}
64+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* DISCLAIMER
3+
*
4+
* Copyright 2016 ArangoDB GmbH, Cologne, Germany
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*
18+
* Copyright holder is ArangoDB GmbH, Cologne, Germany
19+
*/
20+
21+
package com.arangodb.entity.arangosearch.analyzer;
22+
23+
24+
import com.arangodb.velocypack.annotations.SerializedName;
25+
26+
import java.util.Objects;
27+
28+
/**
29+
* @author Michele Rastelli
30+
* @since ArangoDB 3.10
31+
*/
32+
public class ClassificationAnalyzerProperties {
33+
34+
@SerializedName("model_location")
35+
private String modelLocation;
36+
37+
@SerializedName("top_k")
38+
private Integer topK;
39+
40+
private Double threshold;
41+
42+
public String getModelLocation() {
43+
return modelLocation;
44+
}
45+
46+
public void setModelLocation(String modelLocation) {
47+
this.modelLocation = modelLocation;
48+
}
49+
50+
public Integer getTopK() {
51+
return topK;
52+
}
53+
54+
public void setTopK(Integer topK) {
55+
this.topK = topK;
56+
}
57+
58+
public Double getThreshold() {
59+
return threshold;
60+
}
61+
62+
public void setThreshold(Double threshold) {
63+
this.threshold = threshold;
64+
}
65+
66+
@Override
67+
public boolean equals(Object o) {
68+
if (this == o) return true;
69+
if (o == null || getClass() != o.getClass()) return false;
70+
ClassificationAnalyzerProperties that = (ClassificationAnalyzerProperties) o;
71+
return Objects.equals(modelLocation, that.modelLocation) && Objects.equals(topK, that.topK) && Objects.equals(threshold, that.threshold);
72+
}
73+
74+
@Override
75+
public int hashCode() {
76+
return Objects.hash(modelLocation, topK, threshold);
77+
}
78+
}

src/test/java/com/arangodb/ArangoSearchTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,30 @@ void collationAnalyzer(ArangoDatabase db) {
987987
createGetAndDeleteTypedAnalyzer(db, collationAnalyzer);
988988
}
989989

990+
@ParameterizedTest(name = "{index}")
991+
@MethodSource("dbs")
992+
void classificationAnalyzer(ArangoDatabase db) {
993+
assumeTrue(isAtLeastVersion(3, 10));
994+
assumeTrue(isEnterprise());
995+
996+
ClassificationAnalyzerProperties properties = new ClassificationAnalyzerProperties();
997+
properties.setModelLocation("/foo/bar");
998+
properties.setTopK(2);
999+
properties.setThreshold(.5);
1000+
1001+
Set<AnalyzerFeature> features = new HashSet<>();
1002+
features.add(AnalyzerFeature.frequency);
1003+
features.add(AnalyzerFeature.norm);
1004+
features.add(AnalyzerFeature.position);
1005+
1006+
ClassificationAnalyzer analyzer = new ClassificationAnalyzer();
1007+
analyzer.setName("test-" + UUID.randomUUID());
1008+
analyzer.setProperties(properties);
1009+
analyzer.setFeatures(features);
1010+
1011+
createGetAndDeleteTypedAnalyzer(db, analyzer);
1012+
}
1013+
9901014
@ParameterizedTest(name = "{index}")
9911015
@MethodSource("dbs")
9921016
void offsetFeature(ArangoDatabase db) {

0 commit comments

Comments
 (0)