package edu.neu.ccs.pyramid.multilabel_classification.hmlgb;


import com.fasterxml.jackson.databind.ObjectMapper;
import edu.neu.ccs.pyramid.configuration.Config;
import edu.neu.ccs.pyramid.dataset.*;

import edu.neu.ccs.pyramid.eval.Accuracy;
import edu.neu.ccs.pyramid.eval.MacroAveragedMeasures;
import edu.neu.ccs.pyramid.eval.Overlap;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis;
import org.apache.commons.lang3.time.StopWatch;
import org.apache.mahout.math.Vector;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


/**
 * Manual experiment driver for {@link HMLGradientBoosting}.
 *
 * <p>Data set and output locations are read from {@code config/local.properties}
 * (keys {@code input.datasets} and {@code output.tmp}). Each {@code *_build}
 * method trains a model and serializes it under the output directory; the
 * matching {@code *_load} method deserializes that model and evaluates it on
 * the corresponding test split. All build/load pairs share the same model file,
 * so a load always evaluates whichever model was built last.
 */
public class HMLGradientBoostingTest {
    private static final Config config = new Config("config/local.properties");
    private static final String DATASETS = config.getString("input.datasets");
    private static final String TMP = config.getString("output.tmp");
    /** Location of the serialized model, resolved against {@link #TMP}. */
    private static final String MODEL_PATH = "/hmlgb/boosting.ser";

    public static void main(String[] args) throws Exception{
        // Swap the call below to run another experiment:
        // spam_all(), test2_all(), test3_all(), test4(), test3_load().
        test5();
    }

    /** Trains on the plain spam data, then evaluates the saved model. */
    static void spam_all() throws Exception{
        spam_build();
        spam_load();
    }

    /** Trains on spam data with one synthetic label, then evaluates the saved model. */
    static void test2_all() throws Exception{
        test2_build();
        test2_load();
    }

    /** Trains on spam data with two synthetic labels, then evaluates the saved model. */
    static void test3_all() throws Exception{
        test3_build();
        test3_load();
    }

    /**
     * Loads the saved model and reports accuracy and macro-averaged measures
     * on the spam test split (2 classes, no synthetic labels).
     * @throws Exception
     */
    static void spam_load() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("/spam/trec_data/test.trec", 0);

        HMLGradientBoosting boosting = loadModel();
        System.out.println("accuracy="+Accuracy.accuracy(boosting,dataSet));
        printMacroAveraged(boosting, dataSet);
    }

    /**
     * Trains for 200 rounds on the spam training split (2 classes, no synthetic
     * labels), prints the model and its training metrics, and saves the model.
     * @throws Exception
     */
    static void spam_build() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("/spam/trec_data/train.trec", 0);
        HMLGradientBoosting boosting = new HMLGradientBoosting(2, spamAssignments(0));

        HMLGBConfig trainConfig = buildTrainConfig(dataSet, 7, 50, 1);
        System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

        trainWithProgress(trainConfig, boosting, dataSet, 200);
        System.out.println(boosting);
        printAccuracy(boosting, dataSet);
        printMacroAveraged(boosting, dataSet);

        saveModel(boosting);
    }

    /**
     * add a fake label in spam data set, if x=spam and x_0<0.1, also label it as 2
     * @throws Exception
     */
    static void test2_build() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("/spam/trec_data/train.trec", 1);
        HMLGradientBoosting boosting = new HMLGradientBoosting(3, spamAssignments(1));

        HMLGBConfig trainConfig = buildTrainConfig(dataSet, 10, 1000, 2);
        System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

        trainWithProgress(trainConfig, boosting, dataSet, 30);
        System.out.println(boosting);
        printAccuracy(boosting, dataSet);

        saveModel(boosting);
    }

    /**
     * Loads the saved model and prints its accuracy on the spam test split
     * extended with the same synthetic label as {@link #test2_build()}.
     * @throws Exception
     */
    static void test2_load() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("/spam/trec_data/test.trec", 1);

        HMLGradientBoosting boosting = loadModel();
        System.out.println(Accuracy.accuracy(boosting,dataSet));
    }


    /**
     * add 2 fake labels in spam data set,
     * if x=spam and x_0<0.1, also label it as 2
     * if x=spam and x_1<0.1, also label it as 3
     * @throws Exception
     */
    static void test3_build() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("spam/trec_data/train.trec", 2);
        attachFourLabelTranslator(dataSet);

        HMLGradientBoosting boosting = new HMLGradientBoosting(4, spamAssignments(2));

        HMLGBConfig trainConfig = buildTrainConfig(dataSet, 2, 1000, 2);
        System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

        trainWithProgress(trainConfig, boosting, dataSet, 100);

        // Report every training point whose predicted label set differs from the truth.
        int numDataPoints = dataSet.getNumDataPoints();
        for (int i=0;i<numDataPoints;i++){
            Vector featureRow = dataSet.getRow(i);
            MultiLabel label = dataSet.getMultiLabels()[i];
            MultiLabel prediction = boosting.predict(featureRow);
            if (!label.equals(prediction)){
                System.out.println(i);
                System.out.println("label="+label);
                System.out.println("prediction="+prediction);
            }
        }
        printAccuracy(boosting, dataSet);

        saveModel(boosting);
    }

    /**
     * Loads the saved model, prints its accuracy on the spam test split extended
     * with the same two synthetic labels as {@link #test3_build()}, and writes a
     * prediction analysis to {@code prediction_analysis.json} in the output directory.
     * @throws Exception
     */
    static void test3_load() throws Exception{
        MultiLabelClfDataSet dataSet = loadSpamAsMultiLabel("/spam/trec_data/test.trec", 2);
        attachFourLabelTranslator(dataSet);

        HMLGradientBoosting boosting = loadModel();
        System.out.println(Accuracy.accuracy(boosting,dataSet));

        MultiLabelPredictionAnalysis analysis = HMLGBInspector.analyzePrediction(boosting, dataSet, 0, 10);
        ObjectMapper mapper1 = new ObjectMapper();
        mapper1.writeValue(new File(TMP,"prediction_analysis.json"), analysis);
    }

    private static void test4() throws Exception{
        test4_build();
        test4_load();
    }

    /**
     * same as test3, the only difference is we now load data directly
     * @throws Exception
     */
    static void test4_build() throws Exception{
        MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS,"spam/4labels/train.trec"),
                DataSetType.ML_CLF_DENSE,true);

        HMLGradientBoosting boosting = new HMLGradientBoosting(4, spamAssignments(2));

        HMLGBConfig trainConfig = buildTrainConfig(dataSet, 100, 1000, 2);
        System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));

        trainWithProgress(trainConfig, boosting, dataSet, 10);
        printAccuracy(boosting, dataSet);
        printMacroAveraged(boosting, dataSet);

        saveModel(boosting);
    }

    /**
     * Loads the saved model and reports accuracy and macro-averaged measures
     * on the pre-built 4-label spam test split.
     * @throws Exception
     */
    static void test4_load() throws Exception{
        MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS,"spam/4labels/test.trec"),
                DataSetType.ML_CLF_DENSE,true);

        HMLGradientBoosting boosting = loadModel();
        System.out.println("accuracy="+Accuracy.accuracy(boosting,dataSet));
        printMacroAveraged(boosting, dataSet);
    }


    /**
     * Trains for 100 rounds on the ohsumed/3 training split, using the label
     * combinations gathered from the training data as candidate assignments,
     * then prints accuracy and overlap on both the training and test splits.
     * The model is not serialized.
     * @throws Exception
     */
    static void test5() throws Exception{
        MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
        MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
        List<MultiLabel> assignments = DataSetUtil.gatherMultiLabels(dataSet);
        HMLGradientBoosting boosting = new HMLGradientBoosting(dataSet.getNumClasses(),assignments);

        HMLGBConfig trainConfig = buildTrainConfig(dataSet, 2, 1000, 2);

        HMLGBTrainer trainer = new HMLGBTrainer(trainConfig,boosting);
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();

        for (int round =0;round<100;round++){
            System.out.println("round="+round);
            trainer.iterate();
            // elapsed time so far; per-round accuracy is skipped here
            System.out.println(stopWatch);
        }
        // evaluation below is not part of the timed training phase
        stopWatch.stop();

        System.out.println("training accuracy="+ Accuracy.accuracy(boosting, dataSet));
        System.out.println("training overlap = "+ Overlap.overlap(boosting, dataSet));
        System.out.println("test accuracy="+ Accuracy.accuracy(boosting, testSet));
        System.out.println("test overlap = "+ Overlap.overlap(boosting,testSet));
    }

    /**
     * Loads a single-label spam data set and copies it into a dense multi-label
     * data set with {@code 2 + numFakeLabels} classes.
     *
     * <p>Every point keeps its original label. When {@code numFakeLabels >= 1},
     * points with label 1 whose feature 0 is below 0.1 also get label 2; when
     * {@code numFakeLabels >= 2}, points with label 1 whose feature 1 is below
     * 0.1 also get label 3.
     *
     * @param path location of the TREC file, resolved against {@link #DATASETS}
     * @param numFakeLabels number of synthetic labels to add: 0, 1 or 2
     * @return the populated multi-label data set
     * @throws IllegalArgumentException if {@code numFakeLabels} is not 0, 1 or 2
     * @throws Exception if loading the data set fails
     */
    private static MultiLabelClfDataSet loadSpamAsMultiLabel(String path, int numFakeLabels) throws Exception{
        if (numFakeLabels < 0 || numFakeLabels > 2){
            throw new IllegalArgumentException("numFakeLabels must be 0, 1 or 2, got " + numFakeLabels);
        }
        ClfDataSet singleLabelDataSet = TRECFormat.loadClfDataSet(new File(DATASETS,path),
                DataSetType.CLF_DENSE, true);
        int numDataPoints = singleLabelDataSet.getNumDataPoints();
        int numFeatures = singleLabelDataSet.getNumFeatures();
        MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
                .numDataPoints(numDataPoints).numFeatures(numFeatures)
                .numClasses(2 + numFakeLabels).build();
        int[] labels = singleLabelDataSet.getLabels();
        for (int i=0;i<numDataPoints;i++){
            // fetch the source row once per data point rather than once per feature
            Vector row = singleLabelDataSet.getRow(i);
            dataSet.addLabel(i,labels[i]);
            if (numFakeLabels>=1 && labels[i]==1 && row.get(0)<0.1){
                dataSet.addLabel(i,2);
            }
            if (numFakeLabels>=2 && labels[i]==1 && row.get(1)<0.1){
                dataSet.addLabel(i,3);
            }
            for (int j=0;j<numFeatures;j++){
                dataSet.setFeatureValue(i,j,row.get(j));
            }
        }
        return dataSet;
    }

    /**
     * Builds the candidate label combinations for the spam experiments:
     * {0} and {1} always; {1,2} when there is at least one synthetic label;
     * {1,3} and {1,2,3} when there are two.
     *
     * @param numFakeLabels number of synthetic labels in the data set: 0, 1 or 2
     * @return the assignments, in the order listed above
     */
    private static List<MultiLabel> spamAssignments(int numFakeLabels){
        List<MultiLabel> assignments = new ArrayList<>();
        assignments.add(new MultiLabel().addLabel(0));
        assignments.add(new MultiLabel().addLabel(1));
        if (numFakeLabels>=1){
            assignments.add(new MultiLabel().addLabel(1).addLabel(2));
        }
        if (numFakeLabels>=2){
            assignments.add(new MultiLabel().addLabel(1).addLabel(3));
            assignments.add(new MultiLabel().addLabel(1).addLabel(2).addLabel(3));
        }
        return assignments;
    }

    /** Attaches external names for the four labels (non_spam, spam, fake2, fake3) to the data set. */
    private static void attachFourLabelTranslator(MultiLabelClfDataSet dataSet){
        List<String> extLabels = new ArrayList<>();
        extLabels.add("non_spam");
        extLabels.add("spam");
        extLabels.add("fake2");
        extLabels.add("fake3");
        LabelTranslator labelTranslator = new LabelTranslator(extLabels);
        dataSet.setLabelTranslator(labelTranslator);
    }

    /**
     * Builds a training configuration with the settings shared by every
     * experiment in this class: learning rate 0.1 and data/feature sampling rate 1.
     */
    private static HMLGBConfig buildTrainConfig(MultiLabelClfDataSet dataSet, int numLeaves,
                                                int numSplitIntervals, int minDataPerLeaf){
        return new HMLGBConfig.Builder(dataSet)
                .numLeaves(numLeaves).learningRate(0.1).numSplitIntervals(numSplitIntervals)
                .minDataPerLeaf(minDataPerLeaf)
                .dataSamplingRate(1).featureSamplingRate(1).build();
    }

    /**
     * Runs {@code numRounds} boosting iterations, printing the round number and
     * the accuracy on {@code dataSet} after each one, then prints the elapsed time.
     */
    private static void trainWithProgress(HMLGBConfig trainConfig, HMLGradientBoosting boosting,
                                          MultiLabelClfDataSet dataSet, int numRounds) throws Exception{
        HMLGBTrainer trainer = new HMLGBTrainer(trainConfig,boosting);

        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        for (int round =0;round<numRounds;round++){
            System.out.println("round="+round);
            trainer.iterate();
            System.out.println("accuracy="+Accuracy.accuracy(boosting,dataSet));
        }
        stopWatch.stop();
        System.out.println(stopWatch);
    }

    /** Prints an "accuracy" header line followed by the model's accuracy on the data set. */
    private static void printAccuracy(HMLGradientBoosting boosting, MultiLabelClfDataSet dataSet){
        System.out.println("accuracy");
        System.out.println(Accuracy.accuracy(boosting,dataSet));
    }

    /** Prints a "macro-averaged:" header line followed by the macro-averaged measures. */
    private static void printMacroAveraged(HMLGradientBoosting boosting, MultiLabelClfDataSet dataSet){
        System.out.println("macro-averaged:");
        System.out.println(new MacroAveragedMeasures(boosting,dataSet));
    }

    /** Serializes the model to the shared model file under {@link #TMP}. */
    private static void saveModel(HMLGradientBoosting boosting) throws Exception{
        boosting.serialize(new File(TMP,MODEL_PATH));
    }

    /** Deserializes the model from the shared model file under {@link #TMP}. */
    private static HMLGradientBoosting loadModel() throws Exception{
        return HMLGradientBoosting.deserialize(new File(TMP,MODEL_PATH));
    }

}