Scoring plugin, unexpected behavior

Hey. I am writing a plugin for calculating the distance of the EMD vector for Elastic 7.10.0. I took the repository as an example https://github.com/elastic/elasticsearch/tree/master/plugins/examples/script-expert-scoring.
My code:

    package com.liorkn.elasticsearch.plugin;

    import java.io.UncheckedIOException;
    import java.io.IOException;

    import java.util.Collections;
    import java.util.Map;
    import java.util.Set;
    import java.util.Collection;
    import java.util.ArrayList;
    import java.util.List;

    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.plugins.Plugin;
    import org.elasticsearch.plugins.ScriptPlugin;
    import org.elasticsearch.script.ScriptContext;
    import org.elasticsearch.script.ScriptEngine;
    import org.elasticsearch.script.ScoreScript.LeafFactory;
    import org.elasticsearch.script.ScoreScript;
    import org.elasticsearch.search.lookup.SearchLookup;

    import org.apache.lucene.index.LeafReaderContext;
    import org.apache.lucene.index.BinaryDocValues;
    import org.apache.lucene.store.ByteArrayDataInput;

    import java.security.AccessController;
    import java.security.PrivilegedAction;

    import org.opencv.imgproc.Imgproc;
    import org.opencv.core.Mat;
    import org.opencv.core.CvType;
    import nu.pattern.*;

    /**
     * This class is instantiated when Elasticsearch loads the plugin for the first
     * time. If you change the name of this plugin, make sure to update
     * src/main/resources/es-plugin.properties file that points to this class.
     */
    public class VectorScoringPlugin extends Plugin implements ScriptPlugin {
        static {
            AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
                // Privileged code goes here, for example:
                OpenCV.loadLocally();
                return null; // nothing to return
            });
        }

        @Override
        public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
            return new VectorScoringScriptEngine();
        }

        private static class VectorScoringScriptEngine implements ScriptEngine {
            public static final String NAME = "knn";
            private static final String SCRIPT_SOURCE = "vector_score";

            @Override
            public String getType() {
                return NAME;
            }

            @Override
            public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context,
                    Map<String, String> params) {

                if (!context.equals(ScoreScript.CONTEXT)) {
                    throw new IllegalArgumentException(
                            getType() + " scripts cannot be used for context [" + context.name + "]");

                }

                // we use the script "source" as the script identifier
                if (!SCRIPT_SOURCE.equals(scriptSource)) {
                    throw new IllegalArgumentException("Unknown script name " + scriptSource);
                }

                ScoreScript.Factory factory = new VectorScoreScriptFactory();
                return context.factoryClazz.cast(factory);
            }

            @Override
            public void close() throws IOException {
                // Methods should not be empty.
            }

            @Override
            public Set<ScriptContext<?>> getSupportedContexts() {
                return Collections.singleton(ScoreScript.CONTEXT);
            }

            private static class VectorScoreScriptFactory implements ScoreScript.Factory {

                @Override
                public boolean isResultDeterministic() {
                    // PureDfLeafFactory only uses deterministic APIs, this
                    // implies the results are cacheable.
                    return true;
                }

                @Override
                public LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
                    return new VectorScoreScriptLeafFactory(params, lookup);
                }
            }

            private static class VectorScoreScriptLeafFactory implements LeafFactory {
                private static final String PARAM_COLOR_VECTOR = "color";
                private static final String DOC_COLOR_VECTOR = "colorVector";
                private final Map<String, Object> params;
                private final SearchLookup lookup;
                private final Mat signatureInput;

                private VectorScoreScriptLeafFactory(Map<String, Object> params, SearchLookup lookup) {
                    if (!params.containsKey(PARAM_COLOR_VECTOR)) {
                        throw new IllegalArgumentException("Must have at '" + PARAM_COLOR_VECTOR + "' as a parameter");
                    }

                    this.params = params;
                    this.lookup = lookup;

                    // get query inputVector - convert to primitive
                    final Object vector = params.get(PARAM_COLOR_VECTOR);
                    final ArrayList<Double> tmp = (ArrayList<Double>) vector;

                    float[] inputVector = new float[tmp.size()];
                    for (int i = 0; i < inputVector.length; i++) {
                        inputVector[i] = tmp.get(i).floatValue();
                    }

                    if (inputVector.length % 4 != 0) {
                        throw new IllegalArgumentException("Input vector must be a multiple of 4");
                    }

                    int rowsInputVector = inputVector.length / 4;
                    this.signatureInput = new Mat(rowsInputVector, 4, CvType.CV_32F);
                    this.signatureInput.put(0, 0, inputVector);
                    this.signatureInput.reshape(4, rowsInputVector);
                }

                @Override
                public boolean needs_score() {
                    return false;
                }

                @Override
                public ScoreScript newInstance(LeafReaderContext ctx) throws IOException {
                    BinaryDocValues binaryEmbeddingReader = ctx.reader().getBinaryDocValues(DOC_COLOR_VECTOR);
                    if (binaryEmbeddingReader == null) {
                        return new ScoreScript(params, lookup, ctx) {

                            @Override
                            public double execute(ExplanationHolder explanation) {
                                return 0.0d;
                            }
                        };
                    }

                    return new ScoreScript(params, lookup, ctx) {
                        private int currentDocid = -1;

                        @Override
                        public void setDocument(int docId) {
                            if (binaryEmbeddingReader.docID() < docId) {
                                try {
                                    binaryEmbeddingReader.advance(docId);
                                } catch (IOException e) {
                                    throw new UncheckedIOException(e);
                                }
                            }
                            currentDocid = docId;
                        }

                        @Override
                        public double execute(ExplanationHolder explanationHolder) {
                            if (binaryEmbeddingReader.docID() != currentDocid) {
                                /*
                                 * advance moved past the current doc, so this doc has no occurrences of the
                                 * term
                                 */
                                return 0.0d;
                            }

                            try {
                                final byte[] bytes = binaryEmbeddingReader.binaryValue().bytes;
                                final ByteArrayDataInput input = new ByteArrayDataInput(bytes);
                                List<Float> listFloatVector = new ArrayList<>();

                                while (!input.eof())
                                    listFloatVector.add(Float.intBitsToFloat(input.readInt()));

                                float[] l2 = new float[listFloatVector.size()];
                                int i = 0;

                                for (Float f : listFloatVector)
                                    l2[i++] = f;

                                int rowsDocVector = l2.length / 4;
                                Mat signature2 = new Mat(rowsDocVector, 4, CvType.CV_32F);

                                signature2.put(0, 0, l2);
                                signature2.reshape(4, rowsDocVector);

                                float distEMD = Imgproc.EMD(signatureInput, signature2, Imgproc.DIST_L2);

                                if (distEMD == 0)
                                    return 9999.9999;

                                return 1 / distEMD;

                            } catch (IOException e) {
                                throw new UncheckedIOException(e); // again - Failing in order not to hide potential bugs
                            }
                        }
                    };
                }
            }
        }
    }

The index has a colorVector field of type dense_vector. dims 24. But the following problem arises: I add 4 records to the index. Then I search for the view

    {
        "query": {
            "function_score": {
                "query": {
                    "match_all": {}
                },
                "functions": [
                    {
                        "script_score": {
                            "script": {
                                "source": "vector_score",
                                "lang": "knn",
                                "params": {
                                    "color": [
                                        0.052521251142024994,
                                        246.47422790527344,
                                        135.3640899658203,
                                        143.88287353515625,
                                        0.7877722382545471,
                                        247.39590454101562,
                                        247.3125457763672,
                                        247.4180908203125,
                                        0.05891700088977814,
                                        237.54730224609375,
                                        199.69053649902344,
                                        182.0277099609375,
                                        0.05097400024533272,
                                        217.9976348876953,
                                        109.98776245117188,
                                        109.67768096923828,
                                        0.034476250410079956,
                                        222.72132873535156,
                                        164.4148406982422,
                                        140.03741455078125,
                                        0.01533924974501133,
                                        180.0423126220703,
                                        65.70858764648438,
                                        67.94438171386719
                                    ]
                                }
                            }
                        }
                    }
                ]
            }
        }
    }

I get expected results without errors. Then I add 2 more records to the index. Making a search request again and I get the following result:

. The newly added records have the same distance, although the index contains different vectors in the colorVector field. How can you fix this?

This topic was automatically closed 28 days after the last reply. New replies are no longer allowed.