/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.classifier;

import static hivemall.utils.hadoop.HiveUtils.lazyInteger;
import static hivemall.utils.hadoop.HiveUtils.lazyLong;
import static hivemall.utils.hadoop.HiveUtils.lazyString;

import hivemall.TestUtils;
import hivemall.utils.math.MathUtils;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringTokenizer;
import java.util.zip.GZIPInputStream;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.Collector;
import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
import org.apache.hadoop.hive.serde2.lazy.LazyLong;
import org.apache.hadoop.hive.serde2.lazy.LazyString;
import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyStringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.junit.Assert;
import org.junit.Test;

public class GeneralClassifierUDTFTest {
    // Verbose-output switch for this test class. NOTE(review): presumably consulted by the
    // println(...) helper defined outside this chunk — confirm before relying on it.
    private static final boolean DEBUG = false;

    /**
     * initialize() is expected to throw a UDFArgumentException when {@code -opt} names an
     * optimizer that is not recognized.
     */
    @Test(expected = UDFArgumentException.class)
    public void testUnsupportedOptimizer() throws Exception {
        GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();

        ListObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-opt UnsupportedOpt");

        ObjectInspector[] argOIs = new ObjectInspector[] {featuresOI, labelOI, optionOI};
        classifier.initialize(argOIs);
    }

    /**
     * With {@code -inspect_opts}, initialize() is expected to abort with a UDFArgumentException
     * whose message mentions the selected optimizer ("adam").
     */
    @Test
    public void testInspectOptimizerOptions() throws Exception {
        GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();

        ListObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            "-opt adam -reg l1 -inspect_opts");

        UDFArgumentException raised = null;
        try {
            classifier.initialize(new ObjectInspector[] {featuresOI, labelOI, optionOI});
        } catch (UDFArgumentException e) {
            raised = e;
        }
        if (raised == null) {
            Assert.fail("should not come here");
        }
        Assert.assertTrue(raised.getMessage().contains("adam"));
    }

    /**
     * initialize() is expected to throw a UDFArgumentException when {@code -loss} names a loss
     * function that is not recognized.
     */
    @Test(expected = UDFArgumentException.class)
    public void testUnsupportedLossFunction() throws Exception {
        GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();

        ListObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-loss UnsupportedLoss");

        ObjectInspector[] argOIs = new ObjectInspector[] {featuresOI, labelOI, optionOI};
        classifier.initialize(argOIs);
    }

    /**
     * initialize() is expected to throw a UDFArgumentException when {@code -reg} names a
     * regularization scheme that is not recognized.
     */
    @Test(expected = UDFArgumentException.class)
    public void testUnsupportedRegularization() throws Exception {
        GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();

        ListObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector optionOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reg UnsupportedReg");

        ObjectInspector[] argOIs = new ObjectInspector[] {featuresOI, labelOI, optionOI};
        classifier.initialize(argOIs);
    }

    /**
     * Without an options argument, one training example with label 0 followed by
     * finalizeTraining() must yield a model that predicts 0 for that same example.
     */
    @Test
    public void testNoOptions() throws Exception {
        List<String> x = Arrays.asList("1:-2", "2:-1");
        int y = 0;

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI =
                ObjectInspectorFactory.getStandardListObjectInspector(stringOI);

        udtf.initialize(new ObjectInspector[] {stringListOI, intOI});

        udtf.process(new Object[] {x, y});

        udtf.finalizeTraining();

        float score = udtf.predict(udtf.parseFeatures(x));
        // a positive score is mapped to class 1, anything else to class 0
        int predicted = score > 0.f ? 1 : 0;
        // assertEquals reports expected/actual (and the raw score) on failure
        Assert.assertEquals("score: " + score, y, predicted);
    }

    /**
     * Trains on one example (label 0) whose features have element type {@code T}. Then checks
     * that every model row forwarded to the collector carries a feature of the expected class.
     *
     * @param x feature values of the single training example
     * @param featureOI object inspector for one element of {@code x}
     * @param featureClass element class of {@code x}; only pins the type parameter {@code T} and is
     *        not otherwise read
     * @param modelFeatureClass the class that the first column of every collected row must have
     */
    private <T> void testFeature(@Nonnull List<T> x, @Nonnull ObjectInspector featureOI,
            @Nonnull Class<T> featureClass, @Nonnull Class<?> modelFeatureClass) throws Exception {
        final int label = 0;

        GeneralClassifierUDTF classifier = new GeneralClassifierUDTF();
        ListObjectInspector featuresOI =
                ObjectInspectorFactory.getStandardListObjectInspector(featureOI);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        classifier.initialize(new ObjectInspector[] {featuresOI, labelOI});

        final List<Object> collectedFeatures = new ArrayList<Object>();
        classifier.setCollector(new Collector() {
            @Override
            public void collect(Object input) throws HiveException {
                // each forwarded row is an Object[]; keep only its first (feature) column
                Object[] row = (Object[]) input;
                collectedFeatures.add(row[0]);
            }
        });

        classifier.process(new Object[] {x, label});

        // the assertions below rely on rows having been forwarded by the time close() returns
        classifier.close();

        Assert.assertFalse(collectedFeatures.isEmpty());
        for (int i = 0; i < collectedFeatures.size(); i++) {
            Object feature = collectedFeatures.get(i);
            Assert.assertEquals("All model features must have same type", modelFeatureClass,
                feature.getClass());
        }
    }

    /** Java String features are forwarded as String model features. */
    @Test
    public void testStringFeature() throws Exception {
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        testFeature(Arrays.asList("1:-2", "2:-1"), stringOI, String.class, String.class);
    }

    /** A feature whose value part is not numeric ("-2jjjjj") must raise IllegalArgumentException. */
    @Test(expected = IllegalArgumentException.class)
    public void testIllegalStringFeature() throws Exception {
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        testFeature(Arrays.asList("1:-2jjjjj", "2:-1"), stringOI, String.class, String.class);
    }

    /** LazyString features (including non-ASCII names) are forwarded as String model features. */
    @Test
    public void testLazyStringFeature() throws Exception {
        LazyStringObjectInspector lazyOI =
                LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector(false, (byte) 0);
        LazyString katakana = lazyString("テスト:-2", lazyOI);
        LazyString kanji = lazyString("漢字:-333.0", lazyOI);
        LazyString ascii = lazyString("test:-1");
        List<LazyString> features = Arrays.asList(katakana, kanji, ascii);
        testFeature(features, lazyOI, LazyString.class, String.class);
    }

    /** Hadoop Text features are forwarded as String model features. */
    @Test
    public void testTextFeature() throws Exception {
        ObjectInspector textOI = PrimitiveObjectInspectorFactory.writableStringObjectInspector;
        testFeature(Arrays.asList(new Text("1:-2"), new Text("2:-1")), textOI, Text.class,
            String.class);
    }

    /** Java Integer features are forwarded as Integer model features. */
    @Test
    public void testIntegerFeature() throws Exception {
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        testFeature(Arrays.asList(111, 222), intOI, Integer.class, Integer.class);
    }

    /** LazyInteger features are forwarded as Integer model features. */
    @Test
    public void testLazyIntegerFeature() throws Exception {
        ObjectInspector lazyIntOI = LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR;
        LazyInteger first = lazyInteger(111);
        LazyInteger second = lazyInteger(222);
        List<LazyInteger> features = Arrays.asList(first, second);
        testFeature(features, lazyIntOI, LazyInteger.class, Integer.class);
    }

    /** IntWritable features are forwarded as Integer model features. */
    @Test
    public void testWritableIntFeature() throws Exception {
        ObjectInspector writableIntOI = PrimitiveObjectInspectorFactory.writableIntObjectInspector;
        testFeature(Arrays.asList(new IntWritable(111), new IntWritable(222)), writableIntOI,
            IntWritable.class, Integer.class);
    }

    /** Java Long features are forwarded as Long model features. */
    @Test
    public void testLongFeature() throws Exception {
        ObjectInspector longOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
        testFeature(Arrays.asList(111L, 222L), longOI, Long.class, Long.class);
    }

    /** LazyLong features are forwarded as Long model features. */
    @Test
    public void testLazyLongFeature() throws Exception {
        ObjectInspector lazyLongOI = LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR;
        LazyLong first = lazyLong(111);
        LazyLong second = lazyLong(222);
        List<LazyLong> features = Arrays.asList(first, second);
        testFeature(features, lazyLongOI, LazyLong.class, Long.class);
    }

    /** LongWritable features are forwarded as Long model features. */
    @Test
    public void testWritableLongFeature() throws Exception {
        ObjectInspector writableLongOI =
                PrimitiveObjectInspectorFactory.writableLongObjectInspector;
        testFeature(Arrays.asList(new LongWritable(111L), new LongWritable(222L)), writableLongOI,
            LongWritable.class, Long.class);
    }

    /**
     * Trains a GeneralClassifierUDTF with the given option string on a tiny, linearly separable
     * data set. Three label-0 points lie in the negative quadrant and three label-1 points lie in
     * the positive quadrant. Then asserts two things:
     * <ul>
     * <li>the cumulative loss divided by the number of samples is below 0.5;</li>
     * <li>every training example is classified correctly.</li>
     * </ul>
     *
     * @param options option string passed as the constant third argument to initialize()
     */
    private void run(@Nonnull String options) throws Exception {
        println(options);

        ArrayList<List<String>> samplesList = new ArrayList<List<String>>();
        samplesList.add(Arrays.asList("1:-2", "2:-1"));
        samplesList.add(Arrays.asList("1:-1", "2:-1"));
        samplesList.add(Arrays.asList("1:-1", "2:-2"));
        samplesList.add(Arrays.asList("1:1", "2:1"));
        samplesList.add(Arrays.asList("1:1", "2:2"));
        samplesList.add(Arrays.asList("1:2", "2:1"));

        int[] labels = new int[] {0, 0, 0, 1, 1, 1};

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI =
                ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});

        for (int i = 0, size = samplesList.size(); i < size; i++) {
            udtf.process(new Object[] {samplesList.get(i), labels[i]});
        }

        udtf.finalizeTraining();

        double cumLoss = udtf.getCumulativeLoss();
        println("Cumulative loss: " + cumLoss);
        double normalizedLoss = cumLoss / samplesList.size();
        Assert.assertTrue(
            "cumLoss: " + cumLoss + ", normalizedLoss: " + normalizedLoss + "\noptions: " + options,
            normalizedLoss < 0.5d);

        int numTests = 0;
        int numCorrect = 0;

        for (int i = 0, size = samplesList.size(); i < size; i++) {
            int label = labels[i];

            float score = udtf.predict(udtf.parseFeatures(samplesList.get(i)));
            // a positive score is mapped to class 1, anything else to class 0
            int predicted = score > 0.f ? 1 : 0;

            println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label);

            if (predicted == label) {
                ++numCorrect;
            }
            ++numTests;
        }

        float accuracy = numCorrect / (float) numTests;
        println("Accuracy: " + accuracy);
        // This helper is invoked for many option combinations, so the failure message must say
        // which options failed and by how much; delta 0 keeps the original exact-equality check.
        Assert.assertEquals(
            "accuracy " + numCorrect + "/" + numTests + "\noptions: " + options, 1.d, accuracy,
            0.d);
    }

    /**
     * Runs {@link #run(String)} over combinations of the optimizers, regularizations and loss
     * functions listed below, each with {@code -cv_rate 0.001 -iter 512}. RDA regularization is
     * only combined with AdaGrad.
     *
     * <p>Each combination is run up to three times:
     * <ol>
     * <li>sparse;</li>
     * <li>with {@code -mini_batch 2} appended (skipped for AdaGrad);</li>
     * <li>with {@code -dense} appended to the options of the preceding run.</li>
     * </ol>
     */
    @Test
    public void test() throws Exception {
        String[] optimizers = new String[] {"SGD", "AdaDelta", "AdaGrad", "Adam"};
        String[] regularizations = new String[] {"NO", "L1", "L2", "ElasticNet", "RDA"};
        String[] lossFunctions = new String[] {"HingeLoss", "LogLoss", "SquaredHingeLoss",
                "ModifiedHuberLoss", "SquaredLoss", "QuantileLoss", "EpsilonInsensitiveLoss",
                "SquaredEpsilonInsensitiveLoss", "HuberLoss"};

        for (String opt : optimizers) {
            // compare string contents with equals(), not references with ==
            boolean isAdaGrad = "AdaGrad".equals(opt);
            for (String reg : regularizations) {
                if ("RDA".equals(reg) && !isAdaGrad) {
                    continue;
                }

                for (String loss : lossFunctions) {
                    String options = "-opt " + opt + " -reg " + reg + " -loss " + loss
                            + " -cv_rate 0.001 -iter 512";

                    // sparse
                    run(options);

                    if (!isAdaGrad) {
                        options += " -mini_batch 2";
                        run(options);
                    }

                    // dense (keeps any -mini_batch option added above)
                    options += " -dense";
                    run(options);
                }
            }
        }
    }

    /**
     * Trains with SGD / logloss / L2 on {@code news20-small.binary.gz} and asserts that accuracy
     * on the training data is above 0.8.
     *
     * <p>Each input line is parsed as {@code <int label> <feature> <feature> ...}, split on
     * spaces. The predicted label is {@code MathUtils.sign(score)}. NOTE(review): this assumes
     * the file's labels use the same encoding sign() returns (presumably -1/+1), which is not
     * visible here.
     */
    @Test
    public void testNews20() throws IOException, ParseException, HiveException {
        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();
        ObjectInspector intOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        ObjectInspector stringOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        ListObjectInspector stringListOI =
                ObjectInspectorFactory.getStandardListObjectInspector(stringOI);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            "-opt SGD -loss logloss -reg L2 -lambda 0.1 -cv_rate 0.005");

        udtf.initialize(new ObjectInspector[] {stringListOI, intOI, params});

        ArrayList<Integer> labels = new ArrayList<Integer>();
        ArrayList<ArrayList<String>> wordsList = new ArrayList<ArrayList<String>>();

        BufferedReader news20 = readFile("news20-small.binary.gz");
        // close the reader even if parsing or training fails part-way
        try {
            ArrayList<String> words = new ArrayList<String>();
            for (String line = news20.readLine(); line != null; line = news20.readLine()) {
                StringTokenizer tokens = new StringTokenizer(line, " ");
                int label = Integer.parseInt(tokens.nextToken());
                while (tokens.hasMoreTokens()) {
                    words.add(tokens.nextToken());
                }
                Assert.assertFalse(words.isEmpty());
                udtf.process(new Object[] {words, label});

                labels.add(label);
                // typed copy instead of an unchecked clone() cast
                wordsList.add(new ArrayList<String>(words));

                words.clear();
            }
        } finally {
            news20.close();
        }

        // perform SGD iterations
        udtf.finalizeTraining();

        int numTests = 0;
        int numCorrect = 0;

        for (int i = 0, size = wordsList.size(); i < size; i++) {
            ArrayList<String> words = wordsList.get(i);
            int label = labels.get(i);

            float score = udtf.predict(udtf.parseFeatures(words));
            int predicted = MathUtils.sign(score);

            println("Score: " + score + ", Predicted: " + predicted + ", Actual: " + label);

            if (predicted == label) {
                ++numCorrect;
            }
            ++numTests;
        }

        float accuracy = numCorrect / (float) numTests;
        println("Accuracy: " + accuracy);
        Assert.assertTrue(accuracy > 0.8f);
    }

    /**
     * Runs TestUtils' serialization check for GeneralClassifierUDTF with the argument inspectors
     * (list of string features, int label) and one input row.
     */
    @Test
    public void testSerialization() throws HiveException {
        ObjectInspector featuresOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector labelOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        Object[] row = new Object[] {Arrays.asList("1:-2", "2:-1"), 0};

        TestUtils.testGenericUDTFSerialization(GeneralClassifierUDTF.class,
            new ObjectInspector[] {featuresOI, labelOI}, new Object[][] {row});
    }

    /**
     * Trains with {@code -opt sgd} (logloss, L1) on {@code adam_test_10000.tsv.gz} and asserts
     * that the cumulative loss after finalizeTraining() is below 1300.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testSGD() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt sgd -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1300: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1300);
    }

    /**
     * Trains with {@code -opt momentum} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1200.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testMomentum() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt momentum -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1200);
    }

    /**
     * Trains with {@code -opt nesterov} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1100.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testNesterov() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt nesterov -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1100: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1100);
    }

    /**
     * Trains with {@code -opt adagrad} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1400.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testAdagradL1() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt adagrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1400: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1400);
    }

    /**
     * Trains with {@code -opt rmsprop} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1300.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testRMSprop() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt rmsprop -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1300: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1300);
    }

    /**
     * Trains with {@code -opt RMSpropGraves} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1200.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testRMSpropGraves() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt RMSpropGraves -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1200);
    }


    /**
     * Trains with {@code -opt adadelta} (logloss, L1) on {@code adam_test_10000.tsv.gz} and
     * asserts that the cumulative loss after finalizeTraining() is below 1500.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testAdaDeltaL1() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt adadelta -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1500: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1500);
    }

    /**
     * Trains with {@code -opt Adam} (logloss, L1) on {@code adam_test_10000.tsv.gz} and asserts
     * that the cumulative loss after finalizeTraining() is below 800.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testAdam() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Adam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 800: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 800);
    }

    /**
     * Trains with {@code -opt Nadam} (logloss, L1) on {@code adam_test_10000.tsv.gz} and asserts
     * that the cumulative loss after finalizeTraining() is below 800.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testNadam() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Nadam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 800: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 800);
    }

    /**
     * Trains with {@code -opt Eve} (logloss, L1) on {@code adam_test_10000.tsv.gz} and asserts
     * that the cumulative loss after finalizeTraining() is below 800.
     *
     * <p>Each input line is parsed as {@code <comma-separated features> <int label>}.
     */
    @Test
    public void testEve() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Eve -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        BufferedReader reader = readFile(filePath);
        // the reader was previously never closed; release it even if parsing/training fails
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> X = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer y = Integer.valueOf(labelLine);

                udtf.process(new Object[] {X, y});
            }
        } finally {
            reader.close();
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 800: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 800);
    }

    /**
     * Trains {@link GeneralClassifierUDTF} on {@code adam_test_10000.tsv.gz} using
     * {@code -loss logloss -opt Adam -amsgrad} with L1 regularization, then asserts that the
     * cumulative loss reported after training is below 1200.
     */
    @Test
    public void testAdamAmsgrad() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Adam -amsgrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        // UDTF arguments: features (list of strings), label (int), constant option string.
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        // try-with-resources so the gzip-backed reader is released even if parsing or
        // training throws.
        try (BufferedReader reader = readFile(filePath)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Expected line layout: "<comma-separated features> <integer label>"
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> features = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer label = Integer.valueOf(labelLine);

                udtf.process(new Object[] {features, label});
            }
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1200);
    }

    /**
     * Trains {@link GeneralClassifierUDTF} on {@code adam_test_10000.tsv.gz} using
     * {@code -loss logloss -opt Eve -amsgrad} with L1 regularization, then asserts that the
     * cumulative loss reported after training is below 1200.
     */
    @Test
    public void testEveAmsgrad() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Eve -amsgrad -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        // UDTF arguments: features (list of strings), label (int), constant option string.
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        // try-with-resources so the gzip-backed reader is released even if parsing or
        // training throws.
        try (BufferedReader reader = readFile(filePath)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Expected line layout: "<comma-separated features> <integer label>"
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> features = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer label = Integer.valueOf(labelLine);

                udtf.process(new Object[] {features, label});
            }
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 1200: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 1200);
    }


    /**
     * Trains {@link GeneralClassifierUDTF} on {@code adam_test_10000.tsv.gz} using
     * {@code -loss logloss -opt AdamHD} with L1 regularization, then asserts that the cumulative
     * loss reported after training is below 800.
     */
    @Test
    public void testAdamHD() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt AdamHD -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        // UDTF arguments: features (list of strings), label (int), constant option string.
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        // try-with-resources so the gzip-backed reader is released even if parsing or
        // training throws.
        try (BufferedReader reader = readFile(filePath)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Expected line layout: "<comma-separated features> <integer label>"
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> features = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer label = Integer.valueOf(labelLine);

                udtf.process(new Object[] {features, label});
            }
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 800: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 800);
    }

    /**
     * Trains {@link GeneralClassifierUDTF} on {@code adam_test_10000.tsv.gz} using
     * {@code -loss logloss -opt Adam -decay 0.001} with L1 regularization, then asserts that the
     * cumulative loss reported after training is below 900.
     */
    @Test
    public void testAdamDecay() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-loss logloss -opt Adam -decay 0.001 -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        // UDTF arguments: features (list of strings), label (int), constant option string.
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        // try-with-resources so the gzip-backed reader is released even if parsing or
        // training throws.
        try (BufferedReader reader = readFile(filePath)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Expected line layout: "<comma-separated features> <integer label>"
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> features = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer label = Integer.valueOf(labelLine);

                udtf.process(new Object[] {features, label});
            }
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 900: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 900);
    }


    /**
     * Trains {@link GeneralClassifierUDTF} on {@code adam_test_10000.tsv.gz} using
     * {@code -opt Adam -loss logloss} with the {@code inv} eta schedule ({@code -eta0 0.1}) and L1
     * regularization, then asserts that the cumulative loss reported after training is below 900.
     */
    @Test
    public void testAdamInvScaleEta() throws IOException, HiveException {
        String filePath = "adam_test_10000.tsv.gz";
        String options =
                "-eta inv -eta0 0.1 -loss logloss -opt Adam -reg l1 -lambda 0.0001 -iter 10 -mini_batch 1 -cv_rate 0.00005";

        GeneralClassifierUDTF udtf = new GeneralClassifierUDTF();

        // UDTF arguments: features (list of strings), label (int), constant option string.
        ListObjectInspector stringListOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        ObjectInspector params = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector, options);

        udtf.initialize(new ObjectInspector[] {stringListOI,
                PrimitiveObjectInspectorFactory.javaIntObjectInspector, params});

        // try-with-resources so the gzip-backed reader is released even if parsing or
        // training throws.
        try (BufferedReader reader = readFile(filePath)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                // Expected line layout: "<comma-separated features> <integer label>"
                StringTokenizer tokenizer = new StringTokenizer(line, " ");

                String featureLine = tokenizer.nextToken();
                List<String> features = Arrays.asList(featureLine.split(","));

                String labelLine = tokenizer.nextToken();
                Integer label = Integer.valueOf(labelLine);

                udtf.process(new Object[] {features, label});
            }
        }

        udtf.finalizeTraining();

        Assert.assertTrue(
            "CumulativeLoss is expected to be less than 900: " + udtf.getCumulativeLoss(),
            udtf.getCumulativeLoss() < 900);
    }

    /**
     * Writes {@code msg} to standard output, but only while the class-level {@code DEBUG} switch
     * is on; otherwise does nothing.
     *
     * @param msg the text to print
     */
    private static void println(String msg) {
        if (!DEBUG) {
            return;
        }
        System.out.println(msg);
    }

    /**
     * Opens a test resource for line-oriented reading and transparently gunzips it when the name
     * ends with {@code .gz}.
     *
     * @param fileName resource name, resolved by {@link Class#getResourceAsStream(String)} relative
     *        to {@code GeneralClassifierUDTFTest}
     * @return a reader over the (decompressed) resource contents, decoded as UTF-8; the caller is
     *         responsible for closing it
     * @throws IOException if the resource does not exist, or if its gzip header cannot be read
     */
    @Nonnull
    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
        final InputStream raw = GeneralClassifierUDTFTest.class.getResourceAsStream(fileName);
        if (raw == null) {
            // Without this check a missing resource surfaces as an opaque NPE further down.
            throw new IOException("Test resource not found: " + fileName);
        }

        InputStream is = raw;
        if (fileName.endsWith(".gz")) {
            try {
                is = new GZIPInputStream(raw);
            } catch (IOException e) {
                // The GZIPInputStream constructor reads the header eagerly; don't leak the
                // underlying stream when that fails.
                try {
                    raw.close();
                } catch (IOException suppressed) {
                    e.addSuppressed(suppressed);
                }
                throw e;
            }
        }
        // Explicit charset instead of the platform default so results don't depend on the JVM.
        return new BufferedReader(
            new InputStreamReader(is, java.nio.charset.StandardCharsets.UTF_8));
    }
}