/*
 *
 * ****************
 * This file is part of sparkboost software package (https://github.com/tizfa/sparkboost).
 *
 * Copyright 2016 Tiziano Fagni ([email protected])
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * ******************
 */

package it.tizianofagni.sparkboost;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;

/**
 * A Spark implementation of the AdaBoost.MH learner.<br/><br/>
 * The original article describing the algorithm can be found at
 * <a href="http://link.springer.com/article/10.1023%2FA%3A1007649029923">http://link.springer.com/article/10.1023%2FA%3A1007649029923</a>.
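 * <p>
 * A minimal usage sketch (the Spark configuration, the training file path and the number of
 * iterations below are only placeholders):
 * <pre>{@code
 * JavaSparkContext sc = new JavaSparkContext(new SparkConf().setAppName("AdaBoost.MH"));
 * AdaBoostMHLearner learner = new AdaBoostMHLearner(sc);
 * learner.setNumIterations(100);
 * // Labels in the file are 0-based, multiclass/multilabel problem.
 * BoostClassifier classifier = learner.buildModel("/path/to/train.libsvm", true, false);
 * }</pre>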
 *
 * @author Tiziano Fagni ([email protected])
 */
public class AdaBoostMHLearner {

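    /**
     * The Spark context used to execute the learning jobs.
     */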
    private final JavaSparkContext sc;
    /**
     * The number of iterations.
     */
    private int numIterations;

    /**
     * The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<MultilabelPoint>}.
     */
    private int numDocumentsPartitions;

    /**
     * The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.FeatureDocuments>}.
     */
    private int numFeaturesPartitions;

    /**
     * The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.LabelDocuments>}.
     */
    private int numLabelsPartitions;


    public AdaBoostMHLearner(JavaSparkContext sc) {
        if (sc == null)
            throw new NullPointerException("The SparkContext is 'null'");
        this.sc = sc;
        this.numIterations = 200;
        this.numDocumentsPartitions = -1;
        this.numFeaturesPartitions = -1;
        this.numLabelsPartitions = -1;
    }

    /**
     * Get the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<MultilabelPoint>}.
     *
     * @return The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<MultilabelPoint>}.
     */
    public int getNumDocumentsPartitions() {
        return numDocumentsPartitions;
    }

    /**
     * Set the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<MultilabelPoint>}.
     *
     * @param numDocumentsPartitions The number of partitions (use -1 to rely on the
     *                               default parallelism of the Spark context).
     */
    public void setNumDocumentsPartitions(int numDocumentsPartitions) {
        this.numDocumentsPartitions = numDocumentsPartitions;
    }


    /**
     * Get the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.FeatureDocuments>}.
     *
     * @return The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.FeatureDocuments>}.
     */
    public int getNumFeaturesPartitions() {
        return numFeaturesPartitions;
    }

    /**
     * Set the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.FeatureDocuments>}.
     *
     * @param numFeaturesPartitions The number of partitions (use -1 to rely on the
     *                              default parallelism of the Spark context).
     */
    public void setNumFeaturesPartitions(int numFeaturesPartitions) {
        this.numFeaturesPartitions = numFeaturesPartitions;
    }


    /**
     * Get the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.LabelDocuments>}.
     *
     * @return The number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.LabelDocuments>}.
     */
    public int getNumLabelsPartitions() {
        return numLabelsPartitions;
    }

    /**
     * Set the number of partitions used when analyzing
     * an RDD of type {@link JavaRDD<DataUtils.LabelDocuments>}.
     *
     * @param numLabelsPartitions The number of partitions (use -1 to rely on the
     *                            default parallelism of the Spark context).
     */
    public void setNumLabelsPartitions(int numLabelsPartitions) {
        this.numLabelsPartitions = numLabelsPartitions;
    }

    /**
     * Build a new classifier by analyzing the training data available in the
     * specified document set.
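     * <p>
     * A minimal sketch of how the document set can be obtained and used (assuming an already
     * created learner instance; the input path and the number of partitions are only placeholders):
     * <pre>{@code
     * JavaRDD<MultilabelPoint> docs =
     *         DataUtils.loadLibSvmFileFormatData(sc, "/path/to/train.libsvm", true, false, 8);
     * BoostClassifier classifier = learner.buildModel(docs);
     * }</pre>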
     *
     * @param docs The set of documents used as training data.
     * @return A new AdaBoost.MH classifier.
     */
    public BoostClassifier buildModel(JavaRDD<MultilabelPoint> docs) {
        if (docs == null)
            throw new NullPointerException("The set of input documents is 'null'");


        // Repartition documents. If the number of partitions was not set explicitly, fall back
        // to the default parallelism of the Spark context (repartition() requires a positive value).
        if (numDocumentsPartitions <= 0)
            numDocumentsPartitions = sc.defaultParallelism();
        if (numFeaturesPartitions <= 0)
            numFeaturesPartitions = sc.defaultParallelism();
        if (numLabelsPartitions <= 0)
            numLabelsPartitions = sc.defaultParallelism();
        Logging.l().info("Loading initial data and generating internal data representations...");
        docs = docs.repartition(getNumDocumentsPartitions());
        docs = docs.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Docs: num partitions " + docs.partitions().size());

        int numDocs = DataUtils.getNumDocuments(docs);
        int numLabels = DataUtils.getNumLabels(docs);
        JavaRDD<DataUtils.LabelDocuments> labelDocuments = DataUtils.getLabelDocuments(docs);

        // Repartition labels.
        labelDocuments = labelDocuments.repartition(getNumLabelsPartitions());
        labelDocuments = labelDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Labels: num partitions " + labelDocuments.partitions().size());

        // Repartition features.
        JavaRDD<DataUtils.FeatureDocuments> featureDocuments = DataUtils.getFeatureDocuments(docs);
        featureDocuments = featureDocuments.repartition(getNumFeaturesPartitions());
        featureDocuments = featureDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Features: num partitions " + featureDocuments.partitions().size());
        Logging.l().info("Ok, done!");

        WeakHypothesis[] computedWH = new WeakHypothesis[numIterations];
        double[][] localDM = initDistributionMatrix(numLabels, numDocs);
        for (int i = 0; i < numIterations; i++) {

            // Generate new weak hypothesis.
            WeakHypothesis localWH = learnWeakHypothesis(localDM, labelDocuments, featureDocuments);

            // Update distribution matrix with the new hypothesis.
            updateDistributionMatrix(sc, docs, localDM, localWH);

            // Save the weak hypothesis generated in this iteration.
            computedWH[i] = localWH;

            Logging.l().info("Completed iteration " + (i + 1));
        }

        Logging.l().info("Model built!");

        return new BoostClassifier(computedWH);
    }


    /**
     * Build a new classifier by analyzing the training data available in the
     * specified input file. The file must be in LibSvm data format.
     *
     * @param libSvmFile    The input file containing the documents used as training data.
     * @param labels0Based  True if the label indexes specified in the input file are 0-based (i.e. the first label ID is 0), false if they
     *                      are 1-based (i.e. the first label ID is 1).
     * @param binaryProblem True if the input file contains data for a binary problem, false if the input file contains data for a multiclass multilabel
     *                      problem.
     * @return A new AdaBoost.MH classifier.
     */
    public BoostClassifier buildModel(String libSvmFile, boolean labels0Based, boolean binaryProblem) {
        if (libSvmFile == null || libSvmFile.isEmpty())
            throw new IllegalArgumentException("The input file is 'null' or empty");

        int minNumPartitions = 8;
        if (this.numDocumentsPartitions != -1)
            minNumPartitions = this.numDocumentsPartitions;
        JavaRDD<MultilabelPoint> docs = DataUtils.loadLibSvmFileFormatData(sc, libSvmFile, labels0Based, binaryProblem, minNumPartitions);
        if (this.numDocumentsPartitions == -1)
            this.numDocumentsPartitions = sc.defaultParallelism();
        if (this.numFeaturesPartitions == -1)
            this.numFeaturesPartitions = sc.defaultParallelism();
        if (this.numLabelsPartitions == -1)
            this.numLabelsPartitions = sc.defaultParallelism();
        Logging.l().info("Docs partitions = " + this.numDocumentsPartitions + ", feats partitions = " + this.numFeaturesPartitions + ", labels partitions = " + this.getNumLabelsPartitions());
        return buildModel(docs);
    }

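    /**
     * Update in place the distribution matrix with the contribution of the given weak hypothesis,
     * following the AdaBoost.MH rule D(label, doc) = D(label, doc) * exp(-y * h(doc, label)),
     * and then renormalize the matrix so that all its entries sum to 1.
     *
     * @param sc      The Spark context used to broadcast the hypothesis and the matrix.
     * @param docs    The set of training documents.
     * @param localDM The distribution matrix to be updated, indexed as [labelID][docID].
     * @param localWH The weak hypothesis generated in the current boosting iteration.
     */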
    protected void updateDistributionMatrix(JavaSparkContext sc, JavaRDD<MultilabelPoint> docs, double[][] localDM, WeakHypothesis localWH) {
        Broadcast<WeakHypothesis> distWH = sc.broadcast(localWH);
        Broadcast<double[][]> distDM = sc.broadcast(localDM);
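        // For every document, compute the unnormalized updated distribution value of each label.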
        JavaRDD<DMPartialResult> partialResults = docs.map(doc -> {
            int[] validFeatures = doc.getFeatures().indices();
            HashMap<Integer, Integer> dictFeatures = new HashMap<>();
            for (int featID : validFeatures)
                dictFeatures.put(featID, featID);
            HashMap<Integer, Integer> dictLabels = new HashMap<>();
            for (int idx = 0; idx < doc.getLabels().length; idx++)
                dictLabels.put(doc.getLabels()[idx], doc.getLabels()[idx]);

            double[][] dm = distDM.getValue();
            WeakHypothesis wh = distWH.getValue();
            double[] labelsRes = new double[dm.length];
            for (int labelID = 0; labelID < dm.length; labelID++) {
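                // catValue encodes -y(doc, label): -1 if the document belongs to the label,
                // +1 otherwise, so that exp(catValue * value) == exp(-y * h).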
                float catValue = 1;
                if (dictLabels.containsKey(labelID)) {
                    catValue = -1;
                }

                // Compute the weak hypothesis value.
                double value = 0;
                WeakHypothesis.WeakHypothesisData v = wh.getLabelData(labelID);
                int pivot = v.getFeatureID();
                if (dictFeatures.containsKey(pivot))
                    value = v.getC1();
                else
                    value = v.getC0();


                double partialRes = dm[labelID][doc.getPointID()] * Math.exp(catValue * value);
                labelsRes[labelID] = partialRes;
            }

            return new DMPartialResult(doc.getPointID(), labelsRes);
        });

        Iterator<DMPartialResult> itResults = partialResults.toLocalIterator();
        // Collect the partial results and update the local distribution matrix, accumulating the normalization factor.
        double normalization = 0;
        while (itResults.hasNext()) {
            DMPartialResult r = itResults.next();
            for (int labelID = 0; labelID < localDM.length; labelID++) {
                localDM[labelID][r.docID] = r.labelsRes[labelID];
                normalization += localDM[labelID][r.docID];
            }
        }

        // Normalize all values.
        for (int labelID = 0; labelID < localDM.length; labelID++) {
            for (int docID = 0; docID < localDM[0].length; docID++) {
                localDM[labelID][docID] = localDM[labelID][docID] / normalization;
            }
        }
    }

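    /**
     * Create a new distribution matrix of size [numLabels][numDocs] where every entry is
     * initialized with the uniform value 1 / (numLabels * numDocs).
     *
     * @param numLabels The number of labels in the training data.
     * @param numDocs   The number of documents in the training data.
     * @return The initialized distribution matrix.
     */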
    protected double[][] initDistributionMatrix(int numLabels, int numDocs) {
        double[][] dist = new double[numLabels][numDocs];

        // Initialize matrix with uniform distribution.
        float uniformValue = 1 / ((float) numDocs * numLabels);
        for (int label = 0; label < dist.length; label++) {
            for (int doc = 0; doc < dist[0].length; doc++) {
                dist[label][doc] = uniformValue;
            }
        }
        return dist;
    }

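    /**
     * Learn the weak hypothesis of the current boosting iteration: for every feature the per-label
     * confidence values c0 (feature absent) and c1 (feature present) are computed together with the
     * corresponding normalization factor Z_s, and the feature (pivot) minimizing Z_s is selected.
     *
     * @param localDM          The current distribution matrix, indexed as [labelID][docID].
     * @param labelDocuments   The association between each label and the documents labeled with it.
     * @param featureDocuments The association between each feature and the documents containing it.
     * @return The weak hypothesis with the smallest Z_s value.
     */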
    protected WeakHypothesis learnWeakHypothesis(double[][] localDM, JavaRDD<DataUtils.LabelDocuments> labelDocuments, JavaRDD<DataUtils.FeatureDocuments> featureDocuments) {
        int labelsSize = localDM.length;
        int docsSize = localDM[0].length;

        // Total distribution weight of the positive examples (documents having the label) of each label.
        double[] local_weight_b1 = new double[labelsSize];

        // Total distribution weight of the negative examples (documents not having the label) of each label.
        double[] local_weight_bminus_1 = new double[labelsSize];


        // Initialize structures.
        for (int pos = 0; pos < labelsSize; pos++) {
            local_weight_b1[pos] = 0;
            local_weight_bminus_1[pos] = 0;
        }

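        // Sum, for every label, the distribution weight of the documents labeled with it
        // (i.e. the total weight of its positive examples).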
        Iterator<DataUtils.LabelDocuments> itlabels = labelDocuments.toLocalIterator();
        while (itlabels.hasNext()) {
            DataUtils.LabelDocuments la = itlabels.next();
            int labelID = la.getLabelID();
            assert (labelID != -1);
            for (int idx = 0; idx < la.getDocuments().length; idx++) {
                int docID = la.getDocuments()[idx];
                assert (docID != -1);
                double distValue = localDM[labelID][docID];
                local_weight_b1[labelID] += distValue;
            }
        }


        // Compute the total distribution weight of each category.
        for (int labelID = 0; labelID < labelsSize; labelID++) {
            double global = 0;

            // Iterate over the whole distribution matrix row of this label.
            for (int docID = 0; docID < docsSize; docID++) {
                double distValue = localDM[labelID][docID];
                global += distValue;
            }

            local_weight_bminus_1[labelID] = global - local_weight_b1[labelID];
        }


        Broadcast<double[][]> distDM = sc.broadcast(localDM);
        Broadcast<double[]> weight_b1 = sc.broadcast(local_weight_b1);
        Broadcast<double[]> weight_bminus_1 = sc.broadcast(local_weight_bminus_1);

        // Process all features.
        WeakHypothesisResults res = featureDocuments.map(feat -> {
            double[][] dm = distDM.getValue();
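            // Smoothing value used to avoid division by zero and log(0) when computing c0 and c1.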
            double epsilon = 1.0 / (double) (dm.length * dm[0].length);
            int numLabels = dm.length;
            double[] weight_b1_x0 = new double[numLabels];
            double[] weight_b1_x1 = new double[numLabels];
            double[] weight_bminus_1_x0 = new double[numLabels];
            double[] weight_bminus_1_x1 = new double[numLabels];
            double[] computedC0 = new double[numLabels];
            double[] computedC1 = new double[numLabels];
            double[] computedZs = new double[numLabels];
            int realFeatID = feat.getFeatureID();
            int pivot = realFeatID;
            // Initialize structures.
            for (int pos = 0; pos < numLabels; pos++) {
                weight_b1_x0[pos] = 0;
                weight_b1_x1[pos] = 0;
                weight_bminus_1_x0[pos] = 0;
                weight_bminus_1_x1[pos] = 0;
            }


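            // For every document containing this feature, accumulate the distribution weight of
            // its labels (feature present, label present) and of the remaining labels
            // (feature present, label absent).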
            for (int docIdx = 0; docIdx < feat.getDocuments().length; docIdx++) {
                int docID = feat.getDocuments()[docIdx];
                int[] labels = feat.getLabels()[docIdx];
                HashMap<Integer, Integer> catDict = new HashMap<Integer, Integer>();
                for (int labelIdx = 0; labelIdx < labels.length; labelIdx++) {
                    int currentCatID = labels[labelIdx];
                    double distValue = dm[currentCatID][docID];
                    // The feature appears in the document and the document belongs to this category.
                    weight_b1_x1[currentCatID] += distValue;
                    catDict.put(currentCatID, currentCatID);
                }
                for (int currentCatID = 0; currentCatID < numLabels; currentCatID++) {
                    if (catDict.containsKey(currentCatID))
                        continue;
                    double distValue = dm[currentCatID][docID];
                    // The feature appears in the document but the document does not belong to this category.
                    weight_bminus_1_x1[currentCatID] += distValue;
                }
            }

            // Compute the weights for the documents not containing the feature (x = 0).
            for (int catID = 0; catID < numLabels; catID++) {
                double v = weight_b1.getValue()[catID] - weight_b1_x1[catID];
                if (v < 0)
                    v = 0;

                weight_b1_x0[catID] = v;

                v = weight_bminus_1.getValue()[catID] - weight_bminus_1_x1[catID];
                // Adjust rounding errors.
                if (v < 0)
                    v = 0;
                weight_bminus_1_x0[catID] = v;
            }

            // Compute current Z_s.
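            // Z_s = 2 * sum over labels of [ sqrt(weight_b1_x0 * weight_bminus_1_x0)
            //                              + sqrt(weight_b1_x1 * weight_bminus_1_x1) ]:
            // the feature with the smallest Z_s gives the best weak hypothesis for this round.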
            double Z_s = 0;
            for (int catID = 0; catID < numLabels; catID++) {
                assert (weight_b1_x0[catID] >= 0);
                assert (weight_bminus_1_x0[catID] >= 0);
                assert (weight_b1_x1[catID] >= 0);
                assert (weight_bminus_1_x1[catID] >= 0);

                double first = Math.sqrt(weight_b1_x0[catID]
                        * weight_bminus_1_x0[catID]);
                double second = Math.sqrt(weight_b1_x1[catID]
                        * weight_bminus_1_x1[catID]);
                Z_s += (first + second);
                double c0 = Math.log((weight_b1_x0[catID] + epsilon)
                        / (weight_bminus_1_x0[catID] + epsilon)) / 2.0;
                double c1 = Math.log((weight_b1_x1[catID] + epsilon)
                        / (weight_bminus_1_x1[catID] + epsilon)) / 2.0;
                computedC0[catID] = c0;
                computedC1[catID] = c1;
            }
            Z_s = 2 * Z_s;

            return new WeakHypothesisResults(pivot, computedC0, computedC1, Z_s);
        }).reduce((ph1, ph2) -> {
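            // Keep, between the two partial results, the weak hypothesis candidate with the smaller Z_s.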
            int pivot = -1;
            double[] c0 = new double[ph1.getC0().length];
            double[] c1 = new double[ph1.getC0().length];
            double z_s = 0;
            if (ph1.getZ_s() < ph2.getZ_s()) {
                pivot = ph1.pivot;
                z_s = ph1.getZ_s();
                for (int i = 0; i < ph1.getC0().length; i++) {
                    c0[i] = ph1.getC0()[i];
                    c1[i] = ph1.getC1()[i];
                }
            } else {
                pivot = ph2.pivot;
                z_s = ph2.getZ_s();
                for (int i = 0; i < ph2.getC0().length; i++) {
                    c0[i] = ph2.getC0()[i];
                    c1[i] = ph2.getC1()[i];
                }
            }

            return new WeakHypothesisResults(pivot, c0, c1, z_s);
        });

        WeakHypothesis wh = new WeakHypothesis(labelsSize);
        for (int i = 0; i < labelsSize; i++) {
            wh.setLabelData(i, new WeakHypothesis.WeakHypothesisData(i, res.getPivot(), res.getC0()[i], res.getC1()[i]));
        }
        return wh;
    }

    /**
     * Get the number of iterations used while building the classifier.
     *
     * @return The number of iterations used while building the classifier.
     */
    public int getNumIterations() {
        return numIterations;
    }

    /**
     * Set the number of iterations to use while building a new classifier.
     *
     * @param numIterations The number of iterations to use.
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }


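    /**
     * Partial result computed for a single feature: the feature (pivot) ID, the per-label
     * confidence values c0 and c1 and the corresponding normalization factor Z_s.
     */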
    private static class WeakHypothesisResults implements Serializable {
        private final int pivot;
        private final double[] c0;
        private final double[] c1;
        private final double z_s;

        public WeakHypothesisResults(int pivot, double[] c0, double[] c1, double z_s) {
            this.pivot = pivot;
            this.c0 = c0;
            this.c1 = c1;
            this.z_s = z_s;
        }

        public int getPivot() {
            return pivot;
        }

        public double[] getC0() {
            return c0;
        }

        public double[] getC1() {
            return c1;
        }

        public double getZ_s() {
            return z_s;
        }
    }
}