/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.classifier.multiclass;

import static hivemall.HivemallConstants.BIGINT_TYPE_NAME;
import static hivemall.HivemallConstants.INT_TYPE_NAME;
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
import hivemall.LearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.Margin;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Text;

/**
 * Abstract base class for multiclass online classifiers implemented as Hive UDTFs.
 * <p>
 * One {@link PredictionModel} is kept per label in {@link #label2model}. Each input row is
 * parsed in {@link #process(Object[])} and handed to the subclass-provided
 * {@link #train(FeatureValue[], Object)}; the learned rows (label, feature, weight and, when
 * covariance is used, covar) are forwarded from {@link #close()}.
 */
public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF {
    /** Logger shared by all instances of this class. */
    private static final Log logger = LogFactory.getLog(MulticlassOnlineClassifierUDTF.class);

    /** Inspector of the feature list argument; assigned in processFeaturesOI. */
    private ListObjectInspector featureListOI;
    /** True when the feature elements are strings and are parsed with FeatureValue.parse. */
    private boolean parseFeature;
    /** Inspector of the label argument; assigned in initialize. */
    private PrimitiveObjectInspector labelInputOI;

    /** Per-label prediction models; created in initialize and set to null by close. */
    protected Map<Object, PredictionModel> label2model;
    /** Number of rows handed to train so far; reset to 0 in initialize. */
    protected int count;

    /**
     * Creates an instance by delegating to {@link #MulticlassOnlineClassifierUDTF(boolean)}
     * with {@code false}.
     */
    public MulticlassOnlineClassifierUDTF() {
        this(false);
    }

    /**
     * Creates an instance, passing the flag through to the {@link LearnerBaseUDTF} constructor.
     *
     * @param enableNewModel forwarded unchanged to the superclass, which defines its effect
     *        (NOTE(review): presumably selects a newer model implementation - confirm)
     */
    public MulticlassOnlineClassifierUDTF(boolean enableNewModel) {
        super(enableNewModel);
    }

    /**
     * Validates the argument inspectors, processes the options and prepares an empty
     * per-label model map.
     *
     * @param argOIs inspectors of the call arguments: features, label and optional options
     * @return the struct inspector of the forwarded rows, as built by
     *         {@link #getReturnOI(ObjectInspector, ObjectInspector)}
     * @throws UDFArgumentException if fewer than two arguments are given, or if the feature or
     *         label type is not one of Int, BigInt or Text
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 arguments: List<Int|BigInt|Text> features, {Int|BigInt|Text} label [, constant text options]");
        }
        PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
        this.labelInputOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]);
        String labelTypeName = labelInputOI.getTypeName();
        if (!STRING_TYPE_NAME.equals(labelTypeName) && !INT_TYPE_NAME.equals(labelTypeName)
                && !BIGINT_TYPE_NAME.equals(labelTypeName)) {
            // The label is the 2nd argument, so the zero-based argument id is 1.
            throw new UDFArgumentTypeException(1,
                "label must be a type [Int|BigInt|Text]: " + labelTypeName);
        }

        processOptions(argOIs);

        this.label2model = new HashMap<Object, PredictionModel>(64);
        this.count = 0;

        return getReturnOI(labelInputOI, getFeatureOutputOI(featureInputOI));
    }

    /**
     * Returns the initial model size for this learner, fixed at 8192.
     */
    @Override
    protected int getInitialModelSize() {
        return 8192;
    }

    /**
     * Validates the inspector of the feature argument and remembers it for later parsing.
     * <p>
     * The argument must be a list whose element type is Int, BigInt or Text. String elements
     * are flagged so that they are later parsed with {@code FeatureValue.parse}.
     *
     * @param arg inspector of the 1st UDTF argument
     * @return the primitive inspector of the list elements
     * @throws UDFArgumentException if {@code arg} is not a list inspector or its element type
     *         is not supported
     */
    @Nonnull
    protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
            throws UDFArgumentException {
        // Reject non-list arguments with a descriptive error instead of a ClassCastException.
        if (!(arg instanceof ListObjectInspector)) {
            throw new UDFArgumentTypeException(0,
                "1st argument must be a List of element type [Int|BigInt|Text]: "
                        + arg.getTypeName());
        }
        this.featureListOI = (ListObjectInspector) arg;
        ObjectInspector featureRawOI = featureListOI.getListElementObjectInspector();
        String keyTypeName = featureRawOI.getTypeName();
        if (!STRING_TYPE_NAME.equals(keyTypeName) && !INT_TYPE_NAME.equals(keyTypeName)
                && !BIGINT_TYPE_NAME.equals(keyTypeName)) {
            throw new UDFArgumentTypeException(0,
                "1st argument must be a List of element type [Int|BigInt|Text]: " + keyTypeName);
        }
        this.parseFeature = STRING_TYPE_NAME.equals(keyTypeName);
        return HiveUtils.asPrimitiveObjectInspector(featureRawOI);
    }

    /**
     * Builds the struct inspector describing the rows forwarded by this UDTF.
     * <p>
     * The columns are {@code label}, {@code feature} and {@code weight}, followed by
     * {@code covar} when {@code useCovariance()} is true.
     *
     * @param labelRawOI inspector of the label argument; its standard form is used for output
     * @param featureOutputOI inspector used for the {@code feature} column
     * @return the standard struct inspector of the output rows
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector labelRawOI,
            @Nonnull ObjectInspector featureOutputOI) {
        final List<String> columnNames = new ArrayList<String>();
        final List<ObjectInspector> columnOIs = new ArrayList<ObjectInspector>();

        // column 1: label
        final ObjectInspector stdLabelOI =
                ObjectInspectorUtils.getStandardObjectInspector(labelRawOI);
        columnNames.add("label");
        columnOIs.add(stdLabelOI);

        // column 2: feature
        columnNames.add("feature");
        columnOIs.add(featureOutputOI);

        // column 3: weight
        columnNames.add("weight");
        columnOIs.add(writableFloatObjectInspector);

        // column 4 (optional): covariance
        final boolean withCovar = useCovariance();
        if (withCovar) {
            columnNames.add("covar");
            columnOIs.add(writableFloatObjectInspector);
        }

        return ObjectInspectorFactory.getStandardStructObjectInspector(columnNames, columnOIs);
    }

    /**
     * Consumes one training row: parses the feature list, validates the label and trains.
     * <p>
     * Rows whose feature list is NULL or empty are skipped without being counted.
     *
     * @param args the row arguments; {@code args[0]} is the feature list and {@code args[1]}
     *        is the label
     * @throws HiveException a {@link UDFArgumentException} if the label value is NULL
     */
    @Override
    public void process(Object[] args) throws HiveException {
        final List<?> features = featureListOI.getList(args[0]);
        if (features == null) {
            // A NULL feature column yields a null list; skip it as an empty list is skipped.
            return;
        }
        final FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return;
        }
        final Object label = ObjectInspectorUtils.copyToStandardObject(args[1], labelInputOI);
        if (label == null) {
            throw new UDFArgumentException("label value must not be NULL");
        }

        count++;
        train(featureVector, label);
    }

    /**
     * Converts the raw feature list of a row into a feature vector.
     * <p>
     * String features are parsed with {@code FeatureValue.parse}; other features are copied
     * to their standard object and given the value {@code 1.f}. A null element leaves a null
     * entry at the same index of the result.
     *
     * @param features raw feature elements of one row
     * @return the feature vector, or {@code null} when {@code features} is empty
     */
    @Nullable
    protected final FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
        final int numFeatures = features.size();
        if (numFeatures == 0) {
            return null;
        }

        final ObjectInspector elemOI = featureListOI.getListElementObjectInspector();
        final FeatureValue[] parsed = new FeatureValue[numFeatures];
        for (int idx = 0; idx < numFeatures; idx++) {
            final Object raw = features.get(idx);
            if (raw != null) {
                if (parseFeature) {
                    parsed[idx] = FeatureValue.parse(raw);
                } else {
                    final Object key = ObjectInspectorUtils.copyToStandardObject(raw, elemOI);
                    parsed[idx] = new FeatureValue(key, 1.f);
                }
            }
        }
        return parsed;
    }

    /**
     * Trains the per-label models with a single example; implemented by each concrete learner.
     * <p>
     * {@link #process(Object[])} calls this once per row, after skipping rows with an empty
     * feature list and rejecting NULL labels.
     *
     * @param features the parsed feature vector; entries may be null for null input elements
     * @param actual_label the label of the row
     */
    protected abstract void train(@Nonnull final FeatureValue[] features,
            @Nonnull final Object actual_label);

    /**
     * Scores the features against every label's model and picks the highest-scoring label.
     * <p>
     * When no model exists yet the result carries a null label and the placeholder score.
     *
     * @param features the feature vector to classify
     * @return the best label together with its score
     */
    protected final PredictionResult classify(@Nonnull final FeatureValue[] features) {
        Object bestLabel = null;
        // Placeholder only: Float.MIN_VALUE is the smallest positive float, not the most
        // negative one. The null check below makes the first scored label always replace it.
        float bestScore = Float.MIN_VALUE;

        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
            final Object candidate = entry.getKey();
            final float candidateScore = calcScore(entry.getValue(), features);
            final boolean nothingChosenYet = (bestLabel == null);
            if (nothingChosenYet || candidateScore > bestScore) {
                bestLabel = candidate;
                bestScore = candidateScore;
            }
        }

        return new PredictionResult(bestLabel, bestScore);
    }

    /**
     * Computes the score of the actual label and the best score among all other labels.
     *
     * @param features the feature vector of the example
     * @param actual_label the true label of the example
     * @return a margin holding the correct-label score (0 if that label has no model), the
     *         best competing label (null if there is none) and its score
     */
    protected Margin getMargin(@Nonnull final FeatureValue[] features, final Object actual_label) {
        float scoreOfActual = 0.f;
        Object rivalLabel = null;
        float rivalScore = 0.f;

        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
            final Object candidate = entry.getKey();
            final float candidateScore = calcScore(entry.getValue(), features);
            if (candidate.equals(actual_label)) {
                scoreOfActual = candidateScore;
                continue;
            }
            // The first competing label always wins; afterwards only a strictly higher score does.
            if (rivalLabel == null || candidateScore > rivalScore) {
                rivalLabel = candidate;
                rivalScore = candidateScore;
            }
        }
        return new Margin(scoreOfActual, rivalLabel, rivalScore);
    }

    /**
     * Computes the margin and variance for an example, delegating to the three-argument
     * overload with {@code nonZeroVariance} set to {@code false}.
     *
     * @param features the feature vector of the example
     * @param actual_label the true label of the example
     * @return the margin with its variance set
     */
    protected Margin getMarginAndVariance(@Nonnull final FeatureValue[] features,
            final Object actual_label) {
        return getMarginAndVariance(features, actual_label, false);
    }

    /**
     * Computes the margin between the actual label and its best competitor, together with the
     * sum of the variances of those two labels.
     *
     * @param features the feature vector of the example
     * @param actual_label the true label of the example
     * @param nonZeroVariance when true and no model exists yet, the variance is set to twice
     *        {@code calcVariance(features)} instead of 0
     * @return the margin with its variance set
     */
    protected Margin getMarginAndVariance(@Nonnull final FeatureValue[] features,
            final Object actual_label, boolean nonZeroVariance) {
        float scoreOfActual = 0.f;
        float varianceOfActual = 0.f;
        Object rivalLabel = null;
        float rivalScore = 0.f;
        float rivalVariance = 0.f;

        final boolean firstCall = nonZeroVariance && label2model.isEmpty();
        if (firstCall) {
            // No model to score against yet: return zero scores with a non-zero variance.
            final float initialVariance = 2.f * calcVariance(features);
            return new Margin(scoreOfActual, rivalLabel, rivalScore).variance(initialVariance);
        }

        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
            final Object candidate = entry.getKey();
            final PredictionResult result = calcScoreAndVariance(entry.getValue(), features);
            final float candidateScore = result.getScore();

            if (candidate.equals(actual_label)) {
                scoreOfActual = candidateScore;
                varianceOfActual = result.getVariance();
                continue;
            }
            if (rivalLabel == null || candidateScore > rivalScore) {
                rivalLabel = candidate;
                rivalScore = candidateScore;
                rivalVariance = result.getVariance();
            }
        }

        final float totalVariance = varianceOfActual + rivalVariance;
        return new Margin(scoreOfActual, rivalLabel, rivalScore).variance(totalVariance);
    }

    /**
     * Returns the sum of the squared feature values, skipping null entries.
     *
     * @param features the feature vector
     * @return the squared L2 norm of the feature values
     */
    protected final float squaredNorm(@Nonnull final FeatureValue[] features) {
        float sumOfSquares = 0.f;
        for (int i = 0; i < features.length; i++) {// sum += x[i] * x[i]
            final FeatureValue fv = features[i];
            if (fv != null) {
                final float x = fv.getValueAsFloat();
                sumOfSquares += (x * x);
            }
        }
        return sumOfSquares;
    }

    /**
     * Computes the dot product of the model weights and the feature values.
     * Null feature entries and features whose weight is 0 contribute nothing.
     *
     * @param model the prediction model of one label
     * @param features the feature vector
     * @return the score of the features under {@code model}
     */
    protected final float calcScore(@Nonnull final PredictionModel model,
            @Nonnull final FeatureValue[] features) {
        float dot = 0.f;
        for (final FeatureValue fv : features) {// dot += w[i] * x[i]
            if (fv == null) {
                continue;
            }
            final Object key = fv.getFeature();
            final float x = fv.getValueAsFloat();

            final float w = model.getWeight(key);
            if (w == 0f) {
                continue;
            }
            dot += (w * x);
        }
        return dot;
    }

    /**
     * Returns the sum of the squared feature values, skipping null entries.
     *
     * @param features the feature vector
     * @return the accumulated {@code x * x} over all non-null features
     */
    protected final float calcVariance(@Nonnull final FeatureValue[] features) {
        float acc = 0.f;
        final int n = features.length;
        for (int i = 0; i < n; i++) {// acc += x[i] * x[i]
            final FeatureValue fv = features[i];
            if (fv == null) {
                continue;
            }
            final float x = fv.getValueAsFloat();
            acc += x * x;
        }
        return acc;
    }

    /**
     * Computes the score and the variance of the features under one label's model.
     * <p>
     * A feature that is absent from the model adds nothing to the score and is treated as
     * having a covariance of 1 for the variance.
     *
     * @param model the prediction model of one label
     * @param features the feature vector
     * @return a prediction result holding the score and its variance
     */
    protected final PredictionResult calcScoreAndVariance(@Nonnull final PredictionModel model,
            @Nonnull final FeatureValue[] features) {
        final float unseenCovar = 1.f; // covariance assumed for features not in the model
        float scoreSum = 0.f;
        float varianceSum = 0.f;

        for (final FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            final Object key = fv.getFeature();
            final float x = fv.getValueAsFloat();

            final IWeightValue weight = model.get(key);
            if (weight != null) {
                scoreSum += (weight.get() * x);
                varianceSum += (weight.getCovariance() * x * x);
            } else {
                varianceSum += (unseenCovar * x * x);
            }
        }

        return new PredictionResult(scoreSum).variance(varianceSum);
    }

    /**
     * Moves the weights of the actual label towards the example and, when a missed label is
     * given, the weights of the missed label away from it.
     * <p>
     * Models for either label are created on demand. Each touched weight is stored as a new
     * {@link WeightValue}.
     *
     * @param features the feature vector of the example; null entries are skipped
     * @param coeff the step applied per unit of feature value
     * @param actual_label the true label; must not be null
     * @param missed_label the wrongly predicted label, or null to update the actual label only
     * @throws IllegalArgumentException if {@code actual_label} equals {@code missed_label}
     */
    protected void update(@Nonnull final FeatureValue[] features, float coeff, Object actual_label,
            Object missed_label) {
        assert (actual_label != null);
        if (actual_label.equals(missed_label)) {
            throw new IllegalArgumentException(
                "Actual label equals to missed label: " + actual_label);
        }

        // Look up, or lazily create, the model that receives the positive update.
        PredictionModel trueModel = label2model.get(actual_label);
        if (trueModel == null) {
            trueModel = createModel();
            label2model.put(actual_label, trueModel);
        }

        // Look up, or lazily create, the model that receives the negative update (if any).
        PredictionModel falseModel = null;
        if (missed_label != null) {
            falseModel = label2model.get(missed_label);
            if (falseModel == null) {
                falseModel = createModel();
                label2model.put(missed_label, falseModel);
            }
        }

        for (final FeatureValue fv : features) {// w[f] += y * x[f]
            if (fv == null) {
                continue;
            }
            final Object key = fv.getFeature();
            final float x = fv.getValueAsFloat();
            final float delta = coeff * x;

            final float raised = trueModel.getWeight(key) + delta;
            trueModel.set(key, new WeightValue(raised));

            if (falseModel != null) {
                final float lowered = falseModel.getWeight(key) - delta;
                falseModel.set(key, new WeightValue(lowered));
            }
        }
    }

    /**
     * Forwards the trained models and releases them.
     * <p>
     * After calling {@code super.close()}, every touched weight of every label's model is
     * forwarded as a (label, feature, weight) row, with a covariance column appended when
     * {@code useCovariance()} is true. Untouched weights are not output. The model map is
     * then set to null, so a second call forwards nothing.
     *
     * @throws HiveException if the superclass close or forwarding fails
     */
    @Override
    public final void close() throws HiveException {
        super.close();
        if (label2model == null) {
            return;
        }

        long forwardedRows = 0L;
        long mixedTotal = 0L;
        if (useCovariance()) {
            final WeightValueWithCovar probe = new WeightValueWithCovar();
            final Object[] row = new Object[4];
            final FloatWritable weightOut = new FloatWritable();
            final FloatWritable covarOut = new FloatWritable();
            for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
                row[0] = entry.getKey();
                final PredictionModel model = entry.getValue();
                mixedTotal += model.getNumMixed();
                final IMapIterator<Object, IWeightValue> itor = model.entries();
                while (itor.next() != -1) {
                    itor.getValue(probe);
                    if (probe.isTouched()) {// untouched weights are not output
                        final Object feature = itor.getKey();
                        weightOut.set(probe.get());
                        covarOut.set(probe.getCovariance());
                        row[1] = feature;
                        row[2] = weightOut;
                        row[3] = covarOut;
                        forward(row);
                        forwardedRows++;
                    }
                }
            }
        } else {
            final WeightValue probe = new WeightValue();
            final Object[] row = new Object[3];
            final FloatWritable weightOut = new FloatWritable();
            for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
                row[0] = entry.getKey();
                final PredictionModel model = entry.getValue();
                mixedTotal += model.getNumMixed();
                final IMapIterator<Object, IWeightValue> itor = model.entries();
                while (itor.next() != -1) {
                    itor.getValue(probe);
                    if (probe.isTouched()) {// untouched weights are not output
                        final Object feature = itor.getKey();
                        weightOut.set(probe.get());
                        row[1] = feature;
                        row[2] = weightOut;
                        forward(row);
                        forwardedRows++;
                    }
                }
            }
        }

        this.label2model = null;
        logger.info("Trained a prediction model using " + count + " training examples"
                + (mixedTotal > 0 ? "( numMixed: " + mixedTotal + " )" : ""));
        logger.info("Forwarded the prediction model of " + forwardedRows + " rows");
    }

    /**
     * Loads previously trained per-label models from a file or directory and logs statistics.
     * <p>
     * The covariance-aware loader is used when {@code useCovariance()} is true, otherwise the
     * weight-only loader. Nothing is logged when no model was loaded.
     *
     * @param label2model the map receiving the loaded models, keyed by label
     * @param filename path of the model file or directory
     * @param labelOI inspector used to read the label column
     * @param featureOI inspector used to read the feature column
     * @throws RuntimeException wrapping the cause if reading or deserializing fails
     */
    protected void loadPredictionModel(Map<Object, PredictionModel> label2model, String filename,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI) {
        final StopWatch elapsed = new StopWatch();
        final long lines;
        try {
            final boolean withCovar = useCovariance();
            final File modelFile = new File(filename);
            lines = withCovar
                    ? loadPredictionModel(label2model, modelFile, labelOI, featureOI,
                        writableFloatObjectInspector, writableFloatObjectInspector)
                    : loadPredictionModel(label2model, modelFile, labelOI, featureOI,
                        writableFloatObjectInspector);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load a model: " + filename, e);
        } catch (SerDeException e) {
            throw new RuntimeException("Failed to load a model: " + filename, e);
        }

        if (label2model.isEmpty()) {
            return;
        }

        // Collect one statistics line per label plus the overall feature count.
        long totalFeatures = 0L;
        final StringBuilder statsBuf = new StringBuilder(256);
        for (Map.Entry<Object, PredictionModel> entry : label2model.entrySet()) {
            final Object label = entry.getKey();
            final int numFeatures = entry.getValue().size();
            statsBuf.append('\n');
            statsBuf.append("Label: ").append(label);
            statsBuf.append(", Number of Features: ").append(numFeatures);
            totalFeatures += numFeatures;
        }
        logger.info("Loaded total " + totalFeatures + " features from distributed cache '"
                + filename + "' (" + lines + " lines) in " + elapsed + statsBuf);
    }

    /**
     * Loads (label, feature, weight) rows from a model file, or recursively from every entry
     * of a directory, into the per-label models.
     * <p>
     * Missing files and files whose name ends with {@code .crc} are ignored. Lines with a null
     * label, feature or weight are counted but not loaded. Models are created on demand for
     * labels that are not yet in the map.
     *
     * @param label2model the map receiving the loaded models, keyed by label
     * @param file the model file or directory
     * @param labelOI inspector for the label column
     * @param featureOI inspector for the feature column
     * @param weightOI inspector for the weight column
     * @return the number of lines read
     * @throws IOException if a directory cannot be listed or a file cannot be read
     * @throws SerDeException if a line cannot be deserialized
     */
    private long loadPredictionModel(Map<Object, PredictionModel> label2model, File file,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
            WritableFloatObjectInspector weightOI) throws IOException, SerDeException {
        if (!file.exists() || file.getName().endsWith(".crc")) {
            return 0L;
        }

        if (file.isDirectory()) {
            // listFiles() returns null on an I/O error; fail with a clear message, not an NPE.
            final File[] children = file.listFiles();
            if (children == null) {
                throw new IOException("Failed to list files in directory: " + file.getPath());
            }
            long lines = 0L;
            for (File child : children) {
                lines += loadPredictionModel(label2model, child, labelOI, featureOI, weightOI);
            }
            return lines;
        }

        LazySimpleSerDe serde = HiveUtils.getLineSerde(labelOI, featureOI, weightOI);
        StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
        StructField c1ref = lineOI.getStructFieldRef("c1");
        StructField c2ref = lineOI.getStructFieldRef("c2");
        StructField c3ref = lineOI.getStructFieldRef("c3");
        PrimitiveObjectInspector c1refOI =
                (PrimitiveObjectInspector) c1ref.getFieldObjectInspector();
        PrimitiveObjectInspector c2refOI =
                (PrimitiveObjectInspector) c2ref.getFieldObjectInspector();
        FloatObjectInspector c3refOI = (FloatObjectInspector) c3ref.getFieldObjectInspector();

        long lines = 0L;
        BufferedReader reader = null;
        try {
            reader = HadoopUtils.getBufferedReader(file);
            String line;
            while ((line = reader.readLine()) != null) {
                lines++;
                Text lineText = new Text(line);
                Object lineObj = serde.deserialize(lineText);
                List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
                Object f0 = fields.get(0);
                Object f1 = fields.get(1);
                Object f2 = fields.get(2);
                if (f0 == null || f1 == null || f2 == null) {
                    continue; // avoid the case that key or value is null
                }
                Object label = c1refOI.getPrimitiveWritableObject(c1refOI.copyObject(f0));
                PredictionModel model = label2model.get(label);
                if (model == null) {
                    model = createModel();
                    label2model.put(label, model);
                }
                Object k = c2refOI.getPrimitiveWritableObject(c2refOI.copyObject(f1));
                float v = c3refOI.get(f2);
                // NOTE(review): the 'false' flag presumably marks the weight as untouched,
                // so that close() does not re-emit it - confirm against WeightValue.
                model.set(k, new WeightValue(v, false));
            }
        } finally {
            IOUtils.closeQuietly(reader);
        }
        return lines;
    }

    /**
     * Loads (label, feature, weight, covar) rows from a model file, or recursively from every
     * entry of a directory, into the per-label models.
     * <p>
     * Missing files and files whose name ends with {@code .crc} are ignored. Lines with a null
     * label, feature or weight are counted but not loaded. A null covariance falls back to
     * {@code WeightValueWithCovar.DEFAULT_COVAR}. Models are created on demand for labels that
     * are not yet in the map.
     *
     * @param label2model the map receiving the loaded models, keyed by label
     * @param file the model file or directory
     * @param labelOI inspector for the label column
     * @param featureOI inspector for the feature column
     * @param weightOI inspector for the weight column
     * @param covarOI inspector for the covariance column
     * @return the number of lines read
     * @throws IOException if a directory cannot be listed or a file cannot be read
     * @throws SerDeException if a line cannot be deserialized
     */
    private long loadPredictionModel(Map<Object, PredictionModel> label2model, File file,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
            WritableFloatObjectInspector weightOI, WritableFloatObjectInspector covarOI)
            throws IOException, SerDeException {
        if (!file.exists() || file.getName().endsWith(".crc")) {
            return 0L;
        }

        if (file.isDirectory()) {
            // listFiles() returns null on an I/O error; fail with a clear message, not an NPE.
            final File[] children = file.listFiles();
            if (children == null) {
                throw new IOException("Failed to list files in directory: " + file.getPath());
            }
            long lines = 0L;
            for (File child : children) {
                lines += loadPredictionModel(label2model, child, labelOI, featureOI, weightOI,
                    covarOI);
            }
            return lines;
        }

        LazySimpleSerDe serde = HiveUtils.getLineSerde(labelOI, featureOI, weightOI, covarOI);
        StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
        StructField c1ref = lineOI.getStructFieldRef("c1");
        StructField c2ref = lineOI.getStructFieldRef("c2");
        StructField c3ref = lineOI.getStructFieldRef("c3");
        StructField c4ref = lineOI.getStructFieldRef("c4");
        PrimitiveObjectInspector c1refOI =
                (PrimitiveObjectInspector) c1ref.getFieldObjectInspector();
        PrimitiveObjectInspector c2refOI =
                (PrimitiveObjectInspector) c2ref.getFieldObjectInspector();
        FloatObjectInspector c3refOI = (FloatObjectInspector) c3ref.getFieldObjectInspector();
        FloatObjectInspector c4refOI = (FloatObjectInspector) c4ref.getFieldObjectInspector();

        long lines = 0L;
        BufferedReader reader = null;
        try {
            reader = HadoopUtils.getBufferedReader(file);
            String line;
            while ((line = reader.readLine()) != null) {
                lines++;
                Text lineText = new Text(line);
                Object lineObj = serde.deserialize(lineText);
                List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
                Object f0 = fields.get(0);
                Object f1 = fields.get(1);
                Object f2 = fields.get(2);
                Object f3 = fields.get(3);
                if (f0 == null || f1 == null || f2 == null) {
                    continue; // avoid unexpected case
                }
                Object label = c1refOI.getPrimitiveWritableObject(c1refOI.copyObject(f0));
                PredictionModel model = label2model.get(label);
                if (model == null) {
                    model = createModel();
                    label2model.put(label, model);
                }
                Object k = c2refOI.getPrimitiveWritableObject(c2refOI.copyObject(f1));
                float v = c3refOI.get(f2);
                float cov = (f3 == null) ? WeightValueWithCovar.DEFAULT_COVAR : c4refOI.get(f3);
                // NOTE(review): the 'false' flag presumably marks the weight as untouched,
                // so that close() does not re-emit it - confirm against WeightValueWithCovar.
                model.set(k, new WeightValueWithCovar(v, cov, false));
            }
        } finally {
            IOUtils.closeQuietly(reader);
        }
        return lines;
    }
}