/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.gosololaw.elasticsearch;

import org.apache.lucene.index.*;

import org.apache.lucene.store.ByteArrayDataInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.SearchScript;
import org.elasticsearch.search.lookup.LeafSearchLookup;

import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;

/**
 * This class is instantiated when Elasticsearch loads the plugin for the
 * first time. If you change the name of this plugin, make sure to update
 * src/main/resources/es-plugin.properties file that points to this class.
 */
public final class VectorScoringPlugin extends Plugin implements ScriptPlugin {

    private static final int DOUBLE_SIZE = 8;

    @Override
    public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
        return new VectorScoreEngine();
    }

    /** An example {@link ScriptEngine} that uses Lucene segment details to implement pure document frequency scoring. */
    // tag::expert_engine
    private static class VectorScoreEngine implements ScriptEngine {

        @Override
        public String getType() {
            return "binary_vector_score";
        }

        @Override
        @SuppressWarnings("unchecked")
        public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {
            if (!context.equals(SearchScript.CONTEXT)) {
                throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
            }
            // we use the script "source" as the script identifier
            if ("vector_scoring".equals(scriptSource)) {
                SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
                    private final double[] inputVector;
                    final String field;
                    {
                        final Object field = p.get("vector_field");
                        if (field == null)
                            throw new IllegalArgumentException("binary_vector_score script requires field vector_field");
                        this.field = field.toString();

                        // get query inputVector - convert to primitive
                        final ArrayList<Double> tmp = (ArrayList<Double>) p.get("vector");
                        this.inputVector = new double[tmp.size()];
                        for (int i = 0; i < inputVector.length; i++) {
                            inputVector[i] = tmp.get(i);
                        }
                    }

                    @Override
                    public SearchScript newInstance(LeafReaderContext context) throws IOException {

                        return new SearchScript(p, lookup, context) {
                            BinaryDocValues accessor = context.reader().getBinaryDocValues(field);
                            Boolean is_value = false;

                            @Override
                            public void setDocument(int docId) {
                                try {
                                    accessor.advanceExact(docId);
                                    is_value = true;
                                } catch (IOException e) {
                                    is_value = false;
                                }
                            }

                            @Override
                            public double runAsDouble() {
                                if (!is_value) return 0;
                                final byte[] bytes;
                                try {
                                     bytes = accessor.binaryValue().bytes;
                                } catch (IOException e) {
                                     return 0;
                                }

                                final int input_vector_size = inputVector.length;

                                final ByteArrayDataInput doc_vector = new ByteArrayDataInput(bytes);
                                doc_vector.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls
                                final int doc_vector_length = doc_vector.readVInt(); // returns the number of bytes to read
                                if(doc_vector_length != input_vector_size * DOUBLE_SIZE) {
                                    return 0.0;
                                }
                                final int position = doc_vector.getPosition();
                                final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, doc_vector_length).asDoubleBuffer();

                                final double[] docVector = new double[input_vector_size];
                                doubleBuffer.get(docVector);

                                double score = 0;
                                for (int i = 0; i < input_vector_size; i++) {
                                    score += docVector[i] * inputVector[i];
                                }
                                return score;
                            }
                        };
                    }

                    @Override
                    public boolean needs_score() {
                        return false;
                    }
                };
                return context.factoryClazz.cast(factory);
            }
            throw new IllegalArgumentException("Unknown script name " + scriptSource);
        }

        @Override
        public void close() {
            // optionally close resources
        }

    }
    // end::expert_engine

}