/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.smile.regression;

import hivemall.TestUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.mutable.MutableInt;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.Nonnull;
import java.io.IOException;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;

public class RandomForestRegressionUDTFTest {

    @Test
    public void testDense() throws IOException, ParseException, HiveException {
        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});

        final List<Double> xi = new ArrayList<Double>(x[0].length);
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                xi.add(j, x[i][j]);
            }
            udtf.process(new Object[] {xi, y[i]});
            xi.clear();
        }

        final MutableInt count = new MutableInt(0);
        Collector collector = new Collector() {
            public void collect(Object input) throws HiveException {
                count.addValue(1);
            }
        };

        udtf.setCollector(collector);
        udtf.close();

        Assert.assertEquals(49, count.getValue());
    }

    @Test
    public void testSparse() throws IOException, ParseException, HiveException {
        String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};

        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49");
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});

        final List<String> xi = new ArrayList<String>(x[0].length);
        for (int i = 0; i < x.length; i++) {
            double[] row = x[i];
            for (int j = 0; j < row.length; j++) {
                xi.add(mhash(featureNames[j]) + ":" + row[j]);
            }
            udtf.process(new Object[] {xi, y[i]});
            xi.clear();
        }

        final MutableInt count = new MutableInt(0);
        Collector collector = new Collector() {
            public void collect(Object input) throws HiveException {
                count.addValue(1);
            }
        };

        udtf.setCollector(collector);
        udtf.close();

        Assert.assertEquals(49, count.getValue());
    }

    @Test
    public void testSparseDenseEquals() throws IOException, ParseException, HiveException {
        RegressionTree.Node denseNode = getRegressionTreeFromDenseInput();
        RegressionTree.Node sparseNode = getRegressionTreeFromSparseInput();

        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        int diff = 0;
        for (int i = 0; i < x.length; i++) {
            if (denseNode.predict(x[i]) != sparseNode.predict(x[i])) {
                diff++;
            }
        }

        Assert.assertTrue("large diff " + diff + " between two predictions", diff < 10);
    }

    private static RegressionTree.Node getRegressionTreeFromDenseInput()
            throws IOException, ParseException, HiveException {
        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector),
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});

        final List<Double> xi = new ArrayList<Double>(x[0].length);
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < x[i].length; j++) {
                xi.add(j, x[i][j]);
            }
            udtf.process(new Object[] {xi, y[i]});
            xi.clear();
        }

        final Text[] placeholder = new Text[1];
        Collector collector = new Collector() {
            public void collect(Object input) throws HiveException {
                Object[] forward = (Object[]) input;
                placeholder[0] = (Text) forward[2];
            }
        };

        udtf.setCollector(collector);
        udtf.close();

        Text modelTxt = placeholder[0];
        Assert.assertNotNull(modelTxt);

        byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
        RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true);
        return node;
    }

    private static RegressionTree.Node getRegressionTreeFromSparseInput()
            throws IOException, ParseException, HiveException {
        String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};

        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

        RandomForestRegressionUDTF udtf = new RandomForestRegressionUDTF();
        ObjectInspector param = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 1 -seed 71");
        udtf.initialize(new ObjectInspector[] {
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, param});

        final List<String> xi = new ArrayList<String>(x[0].length);
        for (int i = 0; i < x.length; i++) {
            final double[] row = x[i];
            for (int j = 0; j < row.length; j++) {
                xi.add(mhash(featureNames[j]) + ":" + row[j]);
            }
            udtf.process(new Object[] {xi, y[i]});
            xi.clear();
        }

        final Text[] placeholder = new Text[1];
        Collector collector = new Collector() {
            public void collect(Object input) throws HiveException {
                Object[] forward = (Object[]) input;
                placeholder[0] = (Text) forward[2];
            }
        };

        udtf.setCollector(collector);
        udtf.close();

        Text modelTxt = placeholder[0];
        Assert.assertNotNull(modelTxt);

        byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength());
        RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true);
        return node;
    }

    @Test
    public void testSerialization() throws HiveException, IOException, ParseException {
        String[] featureNames = {"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9"};

        double[][] x = {{234.289, 235.6, 159.0, 107.608, 1947, 60.323},
                {259.426, 232.5, 145.6, 108.632, 1948, 61.122},
                {258.054, 368.2, 161.6, 109.773, 1949, 60.171},
                {284.599, 335.1, 165.0, 110.929, 1950, 61.187},
                {328.975, 209.9, 309.9, 112.075, 1951, 63.221},
                {346.999, 193.2, 359.4, 113.270, 1952, 63.639},
                {365.385, 187.0, 354.7, 115.094, 1953, 64.989},
                {363.112, 357.8, 335.0, 116.219, 1954, 63.761},
                {397.469, 290.4, 304.8, 117.388, 1955, 66.019}};

        double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2};

        final Object[][] rows = new Object[x.length][2];
        for (int i = 0; i < x.length; i++) {
            double[] row = x[i];
            final List<String> xi = new ArrayList<String>(x[0].length);
            for (int j = 0; j < row.length; j++) {
                xi.add(mhash(featureNames[j]) + ":" + row[j]);
            }
            rows[i][0] = xi;
            rows[i][1] = y[i];
        }

        TestUtils.testGenericUDTFSerialization(RandomForestRegressionUDTF.class,
            new ObjectInspector[] {
                    ObjectInspectorFactory.getStandardListObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector),
                    PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
                    ObjectInspectorUtils.getConstantObjectInspector(
                        PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-trees 49")},
            rows);
    }

    private static int mhash(@Nonnull final String word) {
        final int n = 16777217; // 2^24
        int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % n;
        if (r < 0) {
            r += n;
        }
        return r + 1;
    }

}