/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.classification;

import static org.junit.Assert.assertEquals;

import matrix4j.matrix.Matrix;
import matrix4j.matrix.builders.CSRMatrixBuilder;
import matrix4j.matrix.dense.RowMajorDenseMatrix2d;
import matrix4j.vector.DenseVector;
import hivemall.smile.classification.DecisionTree.Node;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.tools.TreeExportUDF.Evaluator;
import hivemall.smile.tools.TreeExportUDF.OutputType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.lang.ArrayUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.math.MathUtils;
import hivemall.utils.random.PRNG;
import hivemall.utils.random.RandomNumberGeneratorFactory;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
import smile.data.parser.DelimitedTextParser;
import smile.math.Math;
import smile.validation.LOOCV;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.text.ParseException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Random;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;
import org.roaringbitmap.RoaringBitmap;

public class DecisionTreeTest {
    /** When {@code true}, {@link #debugPrint(String)} writes its messages to standard output. */
    private static final boolean DEBUG = false;

    /**
     * Runs the leave-one-out evaluation in {@code run} on the nominal weather dataset, once with
     * a dense and once with a sparse training matrix, and expects 5 misclassifications in both
     * cases.
     */
    @Test
    public void testWeather() throws IOException, ParseException {
        final String weatherArff =
                "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff";
        final int responseIndex = 4;
        final int numLeafs = 3;

        // dense matrix
        final int denseError = run(weatherArff, responseIndex, numLeafs, true);
        assertEquals(5, denseError);

        // sparse matrix
        final int sparseError = run(weatherArff, responseIndex, numLeafs, false);
        assertEquals(5, sparseError);
    }

    /**
     * Runs the leave-one-out evaluation in {@code run} on the iris dataset without a practical
     * leaf limit ({@code Integer.MAX_VALUE}) and expects 8 misclassifications for both the dense
     * and the sparse training matrix.
     */
    @Test
    public void testIris() throws IOException, ParseException {
        final String irisArff =
                "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
        final int responseIndex = 4;
        final int numLeafs = Integer.MAX_VALUE;

        // dense
        final int denseError = run(irisArff, responseIndex, numLeafs, true);
        assertEquals(8, denseError);

        // sparse
        final int sparseError = run(irisArff, responseIndex, numLeafs, false);
        assertEquals(8, sparseError);
    }

    /**
     * Checks on the iris dataset that trees trained from dense and sparse matrices agree, by
     * delegating to {@code runAndCompareSparseAndDense}.
     */
    @Test
    public void testIrisSparseDenseEquals() throws IOException, ParseException {
        final String irisArff =
                "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
        final int responseIndex = 4;
        final int numLeafs = Integer.MAX_VALUE;

        runAndCompareSparseAndDense(irisArff, responseIndex, numLeafs);
    }

    /**
     * Exercises traced prediction on the iris dataset by delegating to {@code runTracePredict}.
     */
    @Test
    public void testIrisTracePredict() throws IOException, ParseException {
        final String irisArff =
                "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
        final int responseIndex = 4;
        final int numLeafs = Integer.MAX_VALUE;

        runTracePredict(irisArff, responseIndex, numLeafs);
    }

    /**
     * Runs the leave-one-out evaluation in {@code run} on the iris dataset with {@code numLeafs}
     * set to 4 and expects 7 misclassifications for both the dense and the sparse training
     * matrix.
     */
    @Test
    public void testIrisDepth4() throws IOException, ParseException {
        final String irisArff =
                "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
        final int responseIndex = 4;
        final int numLeafs = 4;

        // dense
        final int denseError = run(irisArff, responseIndex, numLeafs, true);
        assertEquals(7, denseError);

        // sparse
        final int sparseError = run(irisArff, responseIndex, numLeafs, false);
        assertEquals(7, sparseError);
    }

    /**
     * Exports a tree trained on the iris dataset via {@code graphvizOutput}, first with explicit
     * feature names, class names and output name, then with all three set to {@code null}. The
     * exported text is only passed to {@code debugPrint}; nothing is asserted about its content.
     */
    @Test
    public void testGraphvizOutputIris() throws IOException, ParseException, HiveException {
        final String irisArff =
                "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff";
        final int responseIndex = 4;
        final int numLeafs = 4;
        final boolean dense = true;

        // export with explicit names
        final String[] featureNames = {"sepallength", "sepalwidth", "petallength", "petalwidth"};
        final String[] classNames = {"setosa", "versicolor", "virginica"};
        final String named = graphvizOutput(irisArff, responseIndex, numLeafs, dense, featureNames,
            classNames, "class");
        debugPrint(named);

        // export without any names
        final String unnamed =
                graphvizOutput(irisArff, responseIndex, numLeafs, dense, null, null, null);
        debugPrint(unnamed);
    }

    /**
     * Exports a tree trained on the nominal weather dataset via {@code graphvizOutput}, first
     * with explicit feature and class names, then with both set to {@code null}; the output name
     * stays {@code "play"} in both calls. The exported text is only passed to
     * {@code debugPrint}; nothing is asserted about its content.
     */
    @Test
    public void testGraphvizOutputWeather() throws IOException, ParseException, HiveException {
        final String weatherArff =
                "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff";
        final int responseIndex = 4;
        final int numLeafs = 3;
        final boolean dense = true;
        final String outputName = "play";

        // export with explicit names
        final String[] featureNames = {"outlook", "temperature", "humidity", "windy"};
        final String[] classNames = {"yes", "no"};
        final String named = graphvizOutput(weatherArff, responseIndex, numLeafs, dense,
            featureNames, classNames, outputName);
        debugPrint(named);

        // export without feature/class names
        final String unnamed =
                graphvizOutput(weatherArff, responseIndex, numLeafs, dense, null, null, outputName);
        debugPrint(unnamed);
    }

    /**
     * Trains a {@link DecisionTree} on the ARFF dataset at the given URL and exports the
     * serialized model through the tree export {@link Evaluator} using
     * {@link OutputType#graphviz}.
     *
     * @param datasetUrl URL of the ARFF file to download
     * @param responseIndex index of the response attribute, passed to the ARFF parser
     * @param numLeafs value forwarded to the {@link DecisionTree} constructor
     * @param dense {@code true} to train on a dense matrix, {@code false} for a sparse one
     * @param featureNames feature names handed to the exporter; the callers also pass
     *        {@code null}
     * @param classNames class names handed to the exporter; the callers also pass {@code null}
     * @param outputName output name handed to the exporter; the callers also pass {@code null}
     * @return the exported model as a string
     */
    private static String graphvizOutput(String datasetUrl, int responseIndex, int numLeafs,
            boolean dense, String[] featureNames, String[] classNames, String outputName)
            throws IOException, HiveException, ParseException {
        final AttributeDataset ds;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(new URL(datasetUrl).openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(responseIndex);
            ds = arffParser.parse(is);
        }

        double[][] x = ds.toArray(new double[ds.size()][]);
        int[] y = ds.toArray(new int[ds.size()]);

        RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
        DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs,
            RandomNumberGeneratorFactory.createPRNG(31));

        // serialize(true): presumably the compressed form (see testIrisSerializeObjCompressed),
        // Base91-encoded into the Text the exporter consumes
        Text model = new Text(Base91.encode(tree.serialize(true)));

        Evaluator eval = new Evaluator(OutputType.graphviz, outputName, false);
        Text exported = eval.export(model, featureNames, classNames);

        return exported.toString();
    }

    /**
     * Evaluates {@link DecisionTree} on an ARFF dataset with leave-one-out cross validation.
     *
     * <p>For every sample a tree is trained on all other samples (seeded with the fold index)
     * and used to predict the held-out sample.
     *
     * @param datasetUrl URL of the ARFF file to download
     * @param responseIndex index of the response attribute, passed to the ARFF parser
     * @param numLeafs value forwarded to the {@link DecisionTree} constructor
     * @param dense {@code true} to train on a dense matrix, {@code false} for a sparse one
     * @return the number of held-out samples whose prediction differs from the actual label
     */
    private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense)
            throws IOException, ParseException {
        final AttributeDataset ds;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(new URL(datasetUrl).openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(responseIndex);
            ds = arffParser.parse(is);
        }

        double[][] x = ds.toArray(new double[ds.size()][]);
        int[] y = ds.toArray(new int[ds.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        int error = 0;
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
            DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), trainy, numLeafs,
                RandomNumberGeneratorFactory.createPRNG(i));
            if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
                error++;
            }
        }

        debugPrint("Decision Tree error = " + error);
        return error;
    }

    /**
     * For every leave-one-out fold of an ARFF dataset, trains one tree from a dense matrix and
     * one from a sparse matrix with the same seed, then asserts that both give the same
     * prediction for the held-out sample and the same {@code toString()} output.
     *
     * @param datasetUrl URL of the ARFF file to download
     * @param responseIndex index of the response attribute, passed to the ARFF parser
     * @param numLeafs value forwarded to the {@link DecisionTree} constructor
     */
    private static void runAndCompareSparseAndDense(String datasetUrl, int responseIndex,
            int numLeafs) throws IOException, ParseException {
        final AttributeDataset ds;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(new URL(datasetUrl).openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(responseIndex);
            ds = arffParser.parse(is);
        }

        double[][] x = ds.toArray(new double[ds.size()][]);
        int[] y = ds.toArray(new int[ds.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(ds.attributes());
            // identical seed for both trees so that only the matrix representation differs
            DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), trainy, numLeafs,
                RandomNumberGeneratorFactory.createPRNG(i));
            DecisionTree stree = new DecisionTree(attrs, matrix(trainx, false), trainy, numLeafs,
                RandomNumberGeneratorFactory.createPRNG(i));
            Assert.assertEquals(dtree.predict(x[loocv.test[i]]), stree.predict(x[loocv.test[i]]));
            Assert.assertEquals(dtree.toString(), stree.toString());
        }
    }

    /**
     * Trains a tree on a shuffled 70% slice of an ARFF dataset (sparse matrix, seed 43) and
     * predicts each of the remaining samples with a {@link PredictionHandler} that records every
     * visited branch and the final leaf.
     *
     * <p>For each test sample it asserts that the handler wrote something to the trace buffer
     * and recorded at least one branch.
     *
     * @param datasetUrl URL of the ARFF file to download
     * @param responseIndex index of the response attribute, passed to the ARFF parser
     * @param numLeafs value forwarded to the {@link DecisionTree} constructor
     */
    private static void runTracePredict(String datasetUrl, int responseIndex, int numLeafs)
            throws IOException, ParseException {
        final AttributeDataset ds;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(new URL(datasetUrl).openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(responseIndex);
            ds = arffParser.parse(is);
        }

        final Attribute[] attrs = ds.attributes();
        final Attribute targetAttr = ds.response();

        double[][] x = ds.toArray(new double[ds.size()][]);
        int[] y = ds.toArray(new int[ds.size()]);

        // deterministic 70/30 train/test split over a shuffled index permutation
        Random rnd = new Random(43L);
        int numTrain = (int) (x.length * 0.7);
        int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
        int[] cvTrain = Arrays.copyOf(index, numTrain);
        int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);

        double[][] trainx = Math.slice(x, cvTrain);
        int[] trainy = Math.slice(y, cvTrain);
        double[][] testx = Math.slice(x, cvTest);

        DecisionTree tree = new DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
            matrix(trainx, false), trainy, numLeafs, RandomNumberGeneratorFactory.createPRNG(43L));

        final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
        final StringBuilder buf = new StringBuilder();
        for (int i = 0; i < testx.length; i++) {
            final DenseVector test = new DenseVector(testx[i]);
            tree.predict(test, new PredictionHandler() {

                @Override
                public void visitBranch(Operator op, int splitFeatureIndex, double splitFeature,
                        double splitValue) {
                    buf.append(attrs[splitFeatureIndex].name);
                    buf.append(" [" + splitFeature + "] ");
                    buf.append(op);
                    buf.append(' ');
                    buf.append(splitValue);
                    buf.append('\n');

                    map.put(attrs[splitFeatureIndex].name + " [" + splitFeature + "] " + op,
                        splitValue);
                }

                @Override
                public void visitLeaf(int output, double[] posteriori) {
                    buf.append(targetAttr.toString(output));
                }
            });

            Assert.assertTrue(buf.length() > 0);
            Assert.assertFalse(map.isEmpty());

            // reset the trace state before the next sample
            StringUtils.clear(buf);
            map.clear();
        }
    }

    /**
     * For every leave-one-out fold of the iris dataset, serializes the trained tree with
     * {@code serialize(false)}, deserializes it again and asserts that the restored node
     * predicts the same label for the held-out sample as the original tree.
     */
    @Test
    public void testIrisSerializedObj() throws IOException, ParseException, HiveException {
        URL url = new URL(
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

        final AttributeDataset iris;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(url.openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(4);
            iris = arffParser.parse(is);
        }
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
            DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

            // round trip through the serialized form (flag false on both sides)
            byte[] b = tree.serialize(false);
            Node node = DecisionTree.deserialize(b, b.length, false);
            assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
        }
    }

    /**
     * For every leave-one-out fold of the iris dataset, asserts that {@code serialize(true)}
     * yields fewer bytes than {@code serialize(false)}, and that the node deserialized from the
     * {@code serialize(true)} bytes predicts the same label for the held-out sample as the
     * original tree.
     */
    @Test
    public void testIrisSerializeObjCompressed() throws IOException, ParseException, HiveException {
        URL url = new URL(
            "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff");

        final AttributeDataset iris;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(url.openStream())) {
            ArffParser arffParser = new ArffParser();
            arffParser.setResponseIndex(4);
            iris = arffParser.parse(is);
        }
        double[][] x = iris.toArray(new double[iris.size()][]);
        int[] y = iris.toArray(new int[iris.size()]);

        int n = x.length;
        LOOCV loocv = new LOOCV(n);
        for (int i = 0; i < n; i++) {
            double[][] trainx = Math.slice(x, loocv.train[i]);
            int[] trainy = Math.slice(y, loocv.train[i]);

            RoaringBitmap attrs = SmileExtUtils.convertAttributeTypes(iris.attributes());
            DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4);

            // b1: serialize(true), expected to be the smaller form; b2: serialize(false)
            byte[] b1 = tree.serialize(true);
            byte[] b2 = tree.serialize(false);
            Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length,
                b1.length < b2.length);
            Node node = DecisionTree.deserialize(b1, b1.length, true);
            assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]]));
        }
    }

    /**
     * Smoke test on the titanic training data: builds a GINI-split tree without depth or leaf
     * limits and runs {@code predictJsCodegen} on it through an overridden {@code toString()}.
     * Nothing is asserted about the generated code; the test passes when no exception is thrown.
     */
    @Test
    public void testTitanicPruning() throws IOException, ParseException {
        String datasetUrl =
                "https://gist.githubusercontent.com/myui/7cd82c443db84ba7e7add1523d0247a9/raw/f2d3e3051b0292577e8c01a1759edabaa95c5781/titanic_train.tsv";

        DelimitedTextParser parser = new DelimitedTextParser();
        parser.setColumnNames(true);
        // NOTE(review): the file is named *.tsv but is parsed with ',' as delimiter - confirm
        parser.setDelimiter(",");
        parser.setResponseIndex(new NominalAttribute("survived"), 0);

        final AttributeDataset train;
        // try-with-resources guarantees the network stream is released even when parsing fails
        try (InputStream is = new BufferedInputStream(new URL(datasetUrl).openStream())) {
            train = parser.parse("titanic train", is);
        }
        double[][] x_ = train.toArray(new double[train.size()][]);
        int[] y = train.toArray(new int[train.size()]);

        // pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked
        // C,C,C,Q,Q,Q,C,Q,C,C
        // the columns marked C above (0, 1, 2, 6, 8, 9) are registered as nominal attributes
        RoaringBitmap nominalAttrs = new RoaringBitmap();
        nominalAttrs.add(0);
        nominalAttrs.add(1);
        nominalAttrs.add(2);
        nominalAttrs.add(6);
        nominalAttrs.add(8);
        nominalAttrs.add(9);

        int columns = x_[0].length;
        Matrix x = new RowMajorDenseMatrix2d(x_, columns);
        int numVars = (int) Math.ceil(Math.sqrt(columns));
        int maxDepth = Integer.MAX_VALUE;
        int maxLeafs = Integer.MAX_VALUE;
        int minSplits = 2;
        int minLeafSize = 1;
        int[] samples = null;
        PRNG rand = RandomNumberGeneratorFactory.createPRNG(43L);

        final String[] featureNames = new String[] {"pclass", "name", "sex", "age", "sibsp",
                "parch", "ticket", "fare", "cabin", "embarked"};
        final String[] classNames = new String[] {"yes", "no"};
        DecisionTree tree = new DecisionTree(nominalAttrs, x, y, numVars, maxDepth, maxLeafs,
            minSplits, minLeafSize, samples, SplitRule.GINI, rand) {
            @Override
            public String toString() {
                return predictJsCodegen(featureNames, classNames);
            }
        };
        // still only a smoke test, but the generated code is now visible when DEBUG is enabled
        debugPrint(tree.toString());
    }

    /**
     * Wraps a row-major {@code double[][]} in a {@link Matrix}.
     *
     * @param x rows of feature values; the dense path reads {@code x[0].length} as the column
     *        count, so it needs at least one row
     * @param dense {@code true} for a {@link RowMajorDenseMatrix2d}, {@code false} for a matrix
     *        built row by row with {@link CSRMatrixBuilder}
     * @return the matrix backing the given rows
     */
    @Nonnull
    private static Matrix matrix(@Nonnull final double[][] x, boolean dense) {
        if (!dense) {
            // 1024 is the builder's constructor argument, presumably an initial size hint
            final CSRMatrixBuilder csr = new CSRMatrixBuilder(1024);
            for (final double[] row : x) {
                csr.nextRow(row);
            }
            return csr.buildMatrix();
        }
        return new RowMajorDenseMatrix2d(x, x[0].length);
    }

    /**
     * Prints the message to standard output, but only when {@code DEBUG} is enabled.
     *
     * @param msg text to print
     */
    private static void debugPrint(String msg) {
        if (!DEBUG) {
            return;
        }
        System.out.println(msg);
    }

}