Hey. I am writing a plugin for calculating the distance of the EMD vector for Elastic 7.10.0. I took the repository as an example https://github.com/elastic/elasticsearch/tree/master/plugins/examples/script-expert-scoring.
My code:
package com.liorkn.elasticsearch.plugin;
import java.io.UncheckedIOException;
import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.Collection;
import java.util.ArrayList;
import java.util.List;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.ScoreScript.LeafFactory;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.search.lookup.SearchLookup;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.store.ByteArrayDataInput;
import java.security.AccessController;
import java.security.PrivilegedAction;
import org.opencv.imgproc.Imgproc;
import org.opencv.core.Mat;
import org.opencv.core.CvType;
import nu.pattern.*;
/**
* This class is instantiated when Elasticsearch loads the plugin for the first
* time. If you change the name of this plugin, make sure to update
* src/main/resources/es-plugin.properties file that points to this class.
*/
public class VectorScoringPlugin extends Plugin implements ScriptPlugin {
static {
AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
// Privileged code goes here, for example:
OpenCV.loadLocally();
return null; // nothing to return
});
}
@Override
public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
return new VectorScoringScriptEngine();
}
private static class VectorScoringScriptEngine implements ScriptEngine {
public static final String NAME = "knn";
private static final String SCRIPT_SOURCE = "vector_score";
@Override
public String getType() {
return NAME;
}
@Override
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context,
Map<String, String> params) {
if (!context.equals(ScoreScript.CONTEXT)) {
throw new IllegalArgumentException(
getType() + " scripts cannot be used for context [" + context.name + "]");
}
// we use the script "source" as the script identifier
if (!SCRIPT_SOURCE.equals(scriptSource)) {
throw new IllegalArgumentException("Unknown script name " + scriptSource);
}
ScoreScript.Factory factory = new VectorScoreScriptFactory();
return context.factoryClazz.cast(factory);
}
@Override
public void close() throws IOException {
// Methods should not be empty.
}
@Override
public Set<ScriptContext<?>> getSupportedContexts() {
return Collections.singleton(ScoreScript.CONTEXT);
}
private static class VectorScoreScriptFactory implements ScoreScript.Factory {
@Override
public boolean isResultDeterministic() {
// PureDfLeafFactory only uses deterministic APIs, this
// implies the results are cacheable.
return true;
}
@Override
public LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
return new VectorScoreScriptLeafFactory(params, lookup);
}
}
private static class VectorScoreScriptLeafFactory implements LeafFactory {
private static final String PARAM_COLOR_VECTOR = "color";
private static final String DOC_COLOR_VECTOR = "colorVector";
private final Map<String, Object> params;
private final SearchLookup lookup;
private final Mat signatureInput;
private VectorScoreScriptLeafFactory(Map<String, Object> params, SearchLookup lookup) {
if (!params.containsKey(PARAM_COLOR_VECTOR)) {
throw new IllegalArgumentException("Must have at '" + PARAM_COLOR_VECTOR + "' as a parameter");
}
this.params = params;
this.lookup = lookup;
// get query inputVector - convert to primitive
final Object vector = params.get(PARAM_COLOR_VECTOR);
final ArrayList<Double> tmp = (ArrayList<Double>) vector;
float[] inputVector = new float[tmp.size()];
for (int i = 0; i < inputVector.length; i++) {
inputVector[i] = tmp.get(i).floatValue();
}
if (inputVector.length % 4 != 0) {
throw new IllegalArgumentException("Input vector must be a multiple of 4");
}
int rowsInputVector = inputVector.length / 4;
this.signatureInput = new Mat(rowsInputVector, 4, CvType.CV_32F);
this.signatureInput.put(0, 0, inputVector);
this.signatureInput.reshape(4, rowsInputVector);
}
@Override
public boolean needs_score() {
return false;
}
@Override
public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
BinaryDocValues binaryEmbeddingReader = ctx.reader().getBinaryDocValues(DOC_COLOR_VECTOR);
if (binaryEmbeddingReader == null) {
return new ScoreScript(params, lookup, ctx) {
@Override
public double execute(ExplanationHolder explanation) {
return 0.0d;
}
};
}
return new ScoreScript(params, lookup, ctx) {
private int currentDocid = -1;
@Override
public void setDocument(int docId) {
if (binaryEmbeddingReader.docID() < docId) {
try {
binaryEmbeddingReader.advance(docId);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
currentDocid = docId;
}
@Override
public double execute(ExplanationHolder explanationHolder) {
if (binaryEmbeddingReader.docID() != currentDocid) {
/*
* advance moved past the current doc, so this doc has no occurrences of the
* term
*/
return 0.0d;
}
try {
final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes;
final ByteArrayDataInput input = new ByteArrayDataInput(bytes);
List<Float> listFloatVector = new ArrayList<>();
while (!input.eof())
listFloatVector.add(Float.intBitsToFloat(input.readInt()));
float[] l2 = new float[listFloatVector.size()];
int i = 0;
for (Float f : listFloatVector)
l2[i++] = f;
int rowsDocVector = l2.length / 4;
Mat signature2 = new Mat(rowsDocVector, 4, CvType.CV_32F);
signature2.put(0, 0, l2);
signature2.reshape(4, rowsDocVector);
float distEMD = Imgproc.EMD(signatureInput, signature2, Imgproc.DIST_L2);
if (distEMD == 0)
return 9999.9999;
return 1 / distEMD;
} catch (IOException e) {
throw new UncheckedIOException(e); // again - Failing in order not to hide potential bugs
}
}
};
}
}
}
}
The index has a colorVector field of type dense_vector. dims 24. But the following problem arises: I add 4 records to the index. Then I search for the view
{
"query": {
"function_score": {
"query": {
"match_all": {}
},
"functions": [
{
"script_score": {
"script": {
"source": "vector_score",
"lang": "knn",
"params": {
"color": [
0.052521251142024994,
246.47422790527344,
135.3640899658203,
143.88287353515625,
0.7877722382545471,
247.39590454101562,
247.3125457763672,
247.4180908203125,
0.05891700088977814,
237.54730224609375,
199.69053649902344,
182.0277099609375,
0.05097400024533272,
217.9976348876953,
109.98776245117188,
109.67768096923828,
0.034476250410079956,
222.72132873535156,
164.4148406982422,
140.03741455078125,
0.01533924974501133,
180.0423126220703,
65.70858764648438,
67.94438171386719
]
}
}
}
}
]
}
}
}
I get expected results without errors. Then I add 2 more records to the index. Making a search request again and I get the following result:
. The newly added records have the same distance, although the index contains different vectors in the colorVector field. How can you fix this?