/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.factorization.fm;

import hivemall.factorization.fm.FMArrayModel;
import hivemall.factorization.fm.FactorizationMachineModel;
import hivemall.factorization.fm.FactorizationMachineUDTF;
import hivemall.factorization.fm.IntFeature;

import java.util.ArrayList;
import java.util.Random;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.junit.Assert;
import org.junit.Test;

public class ArrayModelTest {
    private static final boolean DEBUG_PRINT = false;

    private static void println(String msg) {
        if (DEBUG_PRINT)
            System.out.println(msg);
    }

    @Test
    public void testClassification() throws HiveException {
        final int ROW = 10, COL = 40;

        FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
        ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            "-adareg -factors 20 -classification -seed 31 -iters 10 -int_feature -p " + COL);
        udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
        FactorizationMachineModel model = udtf.initModel(udtf._params);
        Assert.assertTrue("Actual class: " + model.getClass().getName(),
            model instanceof FMArrayModel);

        float accuracy = 0.f;
        final Random rnd = new Random(201L);
        for (int numberOfIteration = 0; numberOfIteration < 10000; numberOfIteration++) {
            ArrayList<IntFeature[]> fArrayList = new ArrayList<IntFeature[]>();
            ArrayList<Double> ans = new ArrayList<Double>();
            for (int i = 0; i < ROW; i++) {
                ArrayList<IntFeature> feature = new ArrayList<IntFeature>();
                for (int j = 1; j <= COL; j++) {
                    if (i < (0.5f * ROW)) {
                        if (j == 1) {
                            feature.add(new IntFeature(j, 1.d));
                        } else if (j < 0.5 * COL) {
                            if (rnd.nextFloat() < 0.2f) {
                                feature.add(new IntFeature(j, rnd.nextDouble()));
                            }
                        }
                    } else {
                        if (j > 0.5f * COL) {
                            if (rnd.nextFloat() < 0.2f) {
                                feature.add(new IntFeature(j, rnd.nextDouble()));
                            }
                        }
                    }
                }
                IntFeature[] x = new IntFeature[feature.size()];
                feature.toArray(x);
                fArrayList.add(x);

                final double y;
                if (i < ROW * 0.5f) {
                    y = -1.0d;
                } else {
                    y = 1.0d;
                }
                ans.add(y);

                udtf.process(new Object[] {toStringArray(x), y});
            }
            int bingo = 0;
            int total = fArrayList.size();
            for (int i = 0; i < total; i++) {
                double tmpAns = ans.get(i);
                if (tmpAns < 0) {
                    tmpAns = 0;
                } else {
                    tmpAns = 1;
                }
                double p = model.predict(fArrayList.get(i));
                int predicted = p > 0.5 ? 1 : 0;
                if (predicted == tmpAns) {
                    bingo++;
                }
            }
            accuracy = bingo / (float) total;
            println("Accuracy = " + accuracy);
        }
        udtf.runTrainingIteration(10);
        Assert.assertTrue("Expected accuracy greater than 0.95f: " + accuracy, accuracy > 0.95f);
    }

    @Test
    public void testRegression() throws HiveException {
        final int ROW = 1000, COL = 80;

        FactorizationMachineUDTF udtf = new FactorizationMachineUDTF();
        ListObjectInspector xOI = ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector,
            "-factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1 -int_feature -p " + COL);
        udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
        FactorizationMachineModel model = udtf.initModel(udtf._params);
        Assert.assertTrue("Actual class: " + model.getClass().getName(),
            model instanceof FMArrayModel);

        double diff = 0.d;
        final Random rnd = new Random(201L);
        for (int numberOfIteration = 0; numberOfIteration < 100; numberOfIteration++) {
            ArrayList<IntFeature[]> fArrayList = new ArrayList<IntFeature[]>();
            ArrayList<Double> ans = new ArrayList<Double>();
            for (int i = 0; i < ROW; i++) {
                ArrayList<IntFeature> feature = new ArrayList<IntFeature>();
                for (int j = 1; j <= COL; j++) {
                    if (i < (0.5f * ROW)) {
                        if (j == 1) {
                            feature.add(new IntFeature(j, 1.d));
                        } else if (j < 0.5 * COL) {
                            if (rnd.nextFloat() < 0.2f) {
                                feature.add(new IntFeature(j, rnd.nextDouble()));
                            }
                        }
                    } else {
                        if (j > (0.5f * COL)) {
                            if (rnd.nextFloat() < 0.2f) {
                                feature.add(new IntFeature(j, rnd.nextDouble()));
                            }
                        }
                    }
                }
                IntFeature[] x = new IntFeature[feature.size()];
                feature.toArray(x);
                fArrayList.add(x);

                final double y;
                if (i < ROW * 0.5f) {
                    y = 0.1d;
                } else {
                    y = 0.4d;
                }
                ans.add(y);

                udtf.process(new Object[] {toStringArray(x), y});
            }

            diff = 0.d;
            for (int i = 0; i < fArrayList.size(); i++) {
                double predicted = model.predict(fArrayList.get(i));
                double actual = ans.get(i);
                double tmpDiff = predicted - actual;
                diff += tmpDiff * tmpDiff;
            }
            println("diff = " + diff);
        }
        Assert.assertTrue("diff = " + diff, diff < 5.d);
    }

    private static String[] toStringArray(IntFeature[] x) {
        final int size = x.length;
        final String[] ret = new String[size];
        for (int i = 0; i < size; i++) {
            ret[i] = x[i].toString();
        }
        return ret;
    }
}