/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.classifier;

import hivemall.annotations.Experimental;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;

import java.util.ArrayList;
import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * Degree-2 polynomial kernel expansion Passive Aggressive.
 * 
 * <pre>
 * Hideki Isozaki and Hideto Kazawa: Efficient Support Vector Classifiers for Named Entity Recognition, Proc.COLING, 2002
 * </pre>
 * 
 * @since v0.5-rc.1
 */
@Description(name = "train_kpa",
        value = "_FUNC_(array<string|int|bigint> features, int label [, const string options])"
                + " - returns a relation <h int, hk int, float w0, float w1, float w2, float w3>")
@Experimental
public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClassifierUDTF {

    // ------------------------------------
    // Hyper parameters

    /** Constant c of the polynomial kernel K = (dot(xi,xj) + c)^2; set from the "pkc" option. */
    private float _pkc;
    /** Step-size strategy (PA, PA-1 or PA-2) selected by the "algo" option. */
    private Algorithm _algo;

    // ------------------------------------
    // Model parameters

    /** Constant term of the expanded model. */
    private float _w0;
    /** Per-feature weights multiplied by the feature value. */
    private Int2FloatMap _w1;
    /** Per-feature weights multiplied by the squared feature value. */
    private Int2FloatMap _w2;
    /** Weights multiplied by the product of two feature values, keyed by the hash of the feature pair. */
    private Int2FloatMap _w3;

    // ------------------------------------

    /** Hinge loss computed by the most recent call to train(). */
    private float _loss;

    /**
     * Does nothing; hyper parameters are set in processOptions() and the weight tables are
     * allocated in createModel().
     */
    public KernelExpansionPassiveAggressiveUDTF() {}

    /**
     * Returns the hinge loss recorded by the latest training step.
     * Currently exposed for tests only.
     */
    @VisibleForTesting
    float getLoss() {
        return this._loss;
    }

    /**
     * Declares the command line options understood by this UDTF: the kernel constant
     * ({@code pkc}), the PA variant ({@code algo}) and the aggressiveness parameter ({@code c}).
     *
     * @return the option definitions
     */
    @Override
    protected Options getOptions() {
        final Options options = new Options();
        return options.addOption("pkc", true,
            "Constant c inside polynomial kernel K = (dot(xi,xj) + c)^2 [default 1.0]")
                      .addOption("algo", "algorithm", true,
                          "Algorithm for calculating loss [pa, pa1 (default), pa2]")
                      .addOption("c", "aggressiveness", true,
                          "Aggressiveness parameter C for PA-1 and PA-2 [default 1.0]");
    }

    /**
     * Parses the optional option string and configures the kernel constant and the PA variant.
     *
     * <p>Defaults are {@code pkc = 1.0}, {@code c = 1.0} and {@code algo = pa1}. The algorithm
     * name is matched case-insensitively.
     *
     * @param argOIs object inspectors of the UDTF arguments, forwarded to the parent parser
     * @return the parsed command line as returned by the parent class (may be null)
     * @throws UDFArgumentException if a numeric option is not a valid float, if C is not
     *         strictly positive (NaN included), or if the algorithm name is unknown
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        float pkc = 1.f;
        float c = 1.f;
        String algo = "pa1";

        final CommandLine cl = super.processOptions(argOIs);
        if (cl != null) {
            pkc = parseFloatOption(cl, "pkc", pkc);
            c = parseFloatOption(cl, "c", c);
            // Written as !(c > 0) rather than (c <= 0) so that NaN is rejected as well.
            if (!(c > 0.f)) {
                throw new UDFArgumentException(
                    "Aggressiveness parameter C must be C > 0: " + c);
            }
            algo = cl.getOptionValue("algo", algo);
        }

        if ("pa1".equalsIgnoreCase(algo)) {
            this._algo = new PA1(c);
        } else if ("pa2".equalsIgnoreCase(algo)) {
            this._algo = new PA2(c);
        } else if ("pa".equalsIgnoreCase(algo)) {
            this._algo = new PA();
        } else {
            throw new UDFArgumentException("Unsupported algorithm: " + algo);
        }
        this._pkc = pkc;

        return cl;
    }

    /**
     * Reads a float-valued option, reporting a malformed value as an argument error instead of
     * letting an unchecked NumberFormatException escape.
     *
     * @param cl parsed command line
     * @param opt short option name
     * @param defaultValue value returned when the option is absent
     * @return the parsed value, or {@code defaultValue} when the option was not given
     * @throws UDFArgumentException if the option value cannot be parsed as a float
     */
    private static float parseFloatOption(@Nonnull final CommandLine cl,
            @Nonnull final String opt, final float defaultValue) throws UDFArgumentException {
        final String value = cl.getOptionValue(opt);
        if (value == null) {
            return defaultValue;
        }
        try {
            return Float.parseFloat(value);
        } catch (NumberFormatException e) {
            UDFArgumentException ex = new UDFArgumentException(
                "Invalid float value for option -" + opt + ": " + value);
            ex.initCause(e); // keep the original parse failure as the cause
            throw ex;
        }
    }

    /**
     * Strategy for computing the step size of a Passive Aggressive update.
     */
    interface Algorithm {
        /**
         * @param loss loss of the current example
         * @param margin prediction result whose squared norm is used by the implementations
         * @return the step size to apply to the update
         */
        float eta(float loss, @Nonnull PredictionResult margin);
    }

    /**
     * Plain PA: the step size is the loss divided by the squared norm, with no upper bound.
     */
    static class PA implements Algorithm {

        PA() {}

        @Override
        public float eta(float loss, PredictionResult margin) {
            final float squaredNorm = margin.getSquaredNorm();
            return loss / squaredNorm;
        }
    }

    /**
     * PA-1: the step size is loss / squared norm, capped at the aggressiveness parameter.
     */
    static class PA1 implements Algorithm {
        /** Upper bound of the step size. */
        private final float aggressiveness;

        PA1(float c) {
            this.aggressiveness = c;
        }

        @Override
        public float eta(float loss, PredictionResult margin) {
            return Math.min(aggressiveness, loss / margin.getSquaredNorm());
        }
    }

    /**
     * PA-2: the step size is loss / (squared norm + 1 / (2C)).
     */
    static class PA2 implements Algorithm {
        /** Aggressiveness parameter C. */
        private final float aggressiveness;

        PA2(float c) {
            this.aggressiveness = c;
        }

        @Override
        public float eta(float loss, PredictionResult margin) {
            final float denominator = margin.getSquaredNorm() + (0.5f / aggressiveness);
            return loss / denominator;
        }
    }

    /**
     * Resets the constant term and allocates empty weight tables for the three kinds of terms.
     *
     * @return always null; the weights are kept in the fields of this class instead of a
     *         PredictionModel
     */
    @Override
    protected PredictionModel createModel() {
        this._w0 = 0.f;
        this._w1 = newWeightTable();
        this._w2 = newWeightTable();
        this._w3 = newWeightTable();
        return null;
    }

    /** Creates an empty int-to-float table whose lookups of missing keys return 0. */
    private static Int2FloatMap newWeightTable() {
        final Int2FloatMap table = new Int2FloatOpenHashMap(16384);
        table.defaultReturnValue(0.f);
        return table;
    }

    /**
     * Describes the output rows as the struct (h int, w0 float, w1 float, w2 float, hk int,
     * w3 float), in that order.
     *
     * @param featureRawOI object inspector of the raw feature; not used here
     * @return struct object inspector of the forwarded rows
     */
    @Override
    protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
        final ObjectInspector intOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        final ObjectInspector floatOI =
                PrimitiveObjectInspectorFactory.writableFloatObjectInspector;

        // column names and their inspectors, kept side by side in output order
        final String[] names = {"h", "w0", "w1", "w2", "hk", "w3"};
        final ObjectInspector[] inspectors = {intOI, floatOI, floatOI, floatOI, intOI, floatOI};

        final List<String> fieldNames = new ArrayList<String>(names.length);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(names.length);
        for (int col = 0; col < names.length; col++) {
            fieldNames.add(names[col]);
            fieldOIs.add(inspectors[col]);
        }

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Converts the raw feature list into a feature vector of the same length.
     *
     * <p>A null element of the input leaves a null slot at the same position of the result.
     *
     * @param features raw features
     * @return the parsed vector, or null when {@code features} is empty
     */
    @Nullable
    FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
        final int numFeatures = features.size();
        if (numFeatures == 0) {
            return null;
        }

        final FeatureValue[] parsed = new FeatureValue[numFeatures];
        int pos = 0;
        for (Object raw : features) {
            if (raw != null) {
                parsed[pos] = FeatureValue.parse(raw, true);
            }
            pos++;
        }
        return parsed;
    }

    /**
     * Runs one online training step: scores the example, records its hinge loss and updates the
     * model only when the loss is positive.
     *
     * @param features feature vector of the example
     * @param label class label; values greater than 0 are treated as +1, all others as -1
     */
    @Override
    protected void train(@Nonnull final FeatureValue[] features, final int label) {
        final float target = (label > 0) ? 1.f : -1.f;

        final PredictionResult margin = calcScoreWithKernelAndNorm(features);
        final float score = margin.getScore();
        final float hinge = LossFunctions.hingeLoss(score, target); // 1.0 - y * p
        this._loss = hinge;

        // negated form of (hinge > 0) so that a NaN loss still skips the update
        if (!(hinge > 0.f)) {
            return;
        }
        updateKernel(target, hinge, margin, features); // y * p < 1
    }

    /**
     * Scores a feature vector with the expanded model:
     * w0 + sum(w1[h] * x_h) + sum(w2[h] * x_h^2) + sum over pairs (w3[hash(h,k)] * x_h * x_k).
     *
     * <p>The constant term w0 is included so that the result agrees with the score used during
     * training. Null slots of the vector are skipped, both as the first and as the second
     * member of a feature pair.
     *
     * @param features feature vector; may contain null slots
     * @return the model score
     */
    @Override
    float predict(@Nonnull final FeatureValue[] features) {
        float score = _w0;

        for (int i = 0; i < features.length; ++i) {
            final FeatureValue fi = features[i];
            if (fi == null) {
                continue;
            }
            int h = fi.getFeatureAsInt();
            float w1 = _w1.get(h);
            float w2 = _w2.get(h);
            double xi = fi.getValue();
            double xx = xi * xi;
            score += w1 * xi;
            score += w2 * xx;
            for (int j = i + 1; j < features.length; ++j) {
                final FeatureValue fj = features[j];
                if (fj == null) {
                    continue; // null slots are skipped here just like in the outer loop
                }
                int k = fj.getFeatureAsInt();
                int hk = HashFunction.hash(h, k, true);
                float w3 = _w3.get(hk);
                double xj = fj.getValue();
                score += xi * xj * w3;
            }
        }

        return score;
    }

    /**
     * Computes the model score of a feature vector together with its squared norm.
     *
     * <p>The score is w0 + sum(w1[h] * x_h) + sum(w2[h] * x_h^2) + sum over pairs
     * (w3[hash(h,k)] * x_h * x_k); the squared norm is the sum of x_h^2 over the non-null
     * features. Null slots of the vector are skipped, both as the first and as the second member
     * of a feature pair.
     *
     * @param features feature vector; may contain null slots
     * @return prediction result holding the score and the squared norm
     */
    @Nonnull
    final PredictionResult calcScoreWithKernelAndNorm(@Nonnull final FeatureValue[] features) {
        float score = _w0;
        float norm = 0.f;
        for (int i = 0; i < features.length; ++i) {
            final FeatureValue fi = features[i];
            if (fi == null) {
                continue;
            }
            int h = fi.getFeatureAsInt();
            float w1 = _w1.get(h);
            float w2 = _w2.get(h);
            double xi = fi.getValue();
            double xx = xi * xi;
            score += w1 * xi;
            score += w2 * xx;
            norm += xx;
            for (int j = i + 1; j < features.length; ++j) {
                final FeatureValue fj = features[j];
                if (fj == null) {
                    continue; // null slots are skipped here just like in the outer loop
                }
                int k = fj.getFeatureAsInt();
                int hk = HashFunction.hash(h, k, true);
                float w3 = _w3.get(hk);
                double xj = fj.getValue();
                score += xi * xj * w3;
            }
        }
        return new PredictionResult(score).squaredNorm(norm);
    }

    /**
     * Applies one model update for a misclassified or low-margin example.
     *
     * @param label target value of the example (+1 or -1 when called from train())
     * @param loss loss of the example
     * @param margin prediction result passed to the step-size strategy
     * @param features feature vector of the example
     */
    protected void updateKernel(final float label, final float loss,
            @Nonnull final PredictionResult margin, @Nonnull final FeatureValue[] features) {
        final float stepSize = _algo.eta(loss, margin);
        // signed step: the direction of the update is given by the label
        expandKernel(features, stepSize * label);
    }

    /**
     * Adds the contribution of one support vector, weighted by {@code alpha}, to the expanded
     * weights of the degree-2 polynomial kernel.
     *
     * <p>Null slots of the vector are skipped, matching how the scoring code treats them, so a
     * vector that can be scored can also be used for an update.
     *
     * @param supportVector feature vector of the example; may contain null slots
     * @param alpha signed step size of the update
     */
    private void expandKernel(@Nonnull final FeatureValue[] supportVector, final float alpha) {
        final float pkc = _pkc;
        // W0 += α c^2
        this._w0 += alpha * pkc * pkc;

        for (int i = 0; i < supportVector.length; ++i) {
            final FeatureValue si = supportVector[i];
            if (si == null) {
                continue;
            }
            final int h = si.getFeatureAsInt();
            float Zih = si.getValueAsFloat();

            float alphaZih = alpha * Zih;
            final float alphaZih2 = alphaZih * 2.f;

            // W1[h] += 2 c α Zi[h]
            _w1.put(h, _w1.get(h) + pkc * alphaZih2);
            // W2[h] += α Zi[h]^2
            _w2.put(h, _w2.get(h) + alphaZih * Zih);

            for (int j = i + 1; j < supportVector.length; ++j) {
                final FeatureValue sj = supportVector[j];
                if (sj == null) {
                    continue;
                }
                int k = sj.getFeatureAsInt();
                int hk = HashFunction.hash(h, k, true);
                float Zjk = sj.getValueAsFloat();

                // W3 += 2 α Zi[h] Zi[k]
                _w3.put(hk, _w3.get(hk) + alphaZih2 * Zjk);
            }
        }
    }

    /**
     * Forwards the trained model and releases the weight tables.
     *
     * <p>Rows use the layout (h, w0, w1, w2, hk, w3) and are emitted in three groups:
     * <ol>
     * <li>a single row carrying h = 0 and w0,</li>
     * <li>one row per entry of the w1 table carrying h, w1 and the w2 value for the same key,</li>
     * <li>one row per entry of the w3 table carrying hk and w3.</li>
     * </ol>
     * Columns that do not belong to the group are null. The writable objects are reused across
     * rows.
     *
     * @throws HiveException if forwarding fails; the key checks are also given
     *         HiveException.class (presumably the type raised for a non-positive key)
     */
    @Override
    public void close() throws HiveException {
        final IntWritable h = new IntWritable(0); // row[0]
        final FloatWritable w0 = new FloatWritable(_w0); // row[1]
        final FloatWritable w1 = new FloatWritable(); // row[2]
        final FloatWritable w2 = new FloatWritable(); // row[3]
        final IntWritable hk = new IntWritable(0); // row[4]
        final FloatWritable w3 = new FloatWritable(); // row[5]
        final Object[] row = new Object[] {h, w0, null, null, null, null};
        forward(row); // 0(f), w0
        row[1] = null;

        row[2] = w1;
        row[3] = w2;
        final Int2FloatMap w2map = _w2;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w1)) {
            int k = e.getIntKey();
            Preconditions.checkArgument(k > 0, HiveException.class);
            h.set(k);
            w1.set(e.getFloatValue());
            w2.set(w2map.get(k));
            forward(row); // h(f), w1, w2
        }
        // drop the references so the tables can be reclaimed before the w3 pass
        this._w1 = null;
        this._w2 = null;

        row[0] = null;
        row[2] = null;
        row[3] = null;
        row[4] = hk;
        row[5] = w3;

        for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w3)) {
            int k = e.getIntKey();
            Preconditions.checkArgument(k > 0, HiveException.class);
            hk.set(k);
            w3.set(e.getFloatValue());
            forward(row); // hk(f), w3
        }
        this._w3 = null;
    }

}