/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall;

import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionModel;
import hivemall.model.WeightValue;
import hivemall.model.WeightValue.WeightValueWithCovar;
import hivemall.optimizer.LossFunctions;
import hivemall.optimizer.LossFunctions.LossFunction;
import hivemall.optimizer.LossFunctions.LossType;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.OptimizerOptions;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.FileUtils;
import hivemall.utils.io.NIOUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.FloatAccumulator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;

/**
 * Base UDTF for general-purpose learners that train a {@link PredictionModel} from
 * {@code (features, target)} rows.
 *
 * It uses a pluggable {@link Optimizer} and {@link LossFunction}. Updates can be applied online
 * (per example) or in mini-batches. When more than one iteration is requested, training
 * examples are recorded to a direct buffer or a temporary file and replayed in
 * {@link #close()}.
 *
 * Subclasses provide the loss-function policy, target validation, and the per-example
 * {@link #train(FeatureValue[], float)} step.
 */
public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
    private static final Log logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class);
    /** Upper clamp applied to the loss derivative before updating weights. */
    private static final float MAX_DLOSS = 1e+12f;
    /** Lower clamp applied to the loss derivative before updating weights. */
    private static final float MIN_DLOSS = -1e+12f;

    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector targetOI;
    private FeatureType featureType;

    // -----------------------------------------
    // hyperparameters

    @Nonnull
    private final Map<String, String> optimizerOptions;
    private Optimizer optimizer;
    private LossFunction lossFunction;

    // -----------------------------------------

    private PredictionModel model;
    /** Number of training examples processed in the first pass. */
    private long count;

    // -----------------------------------------
    // for mini-batch

    /** The accumulated delta of each weight values. */
    @Nullable
    private transient Map<Object, FloatAccumulator> accumulated;
    private int sampled;

    // -----------------------------------------
    // for iterations

    @Nullable
    protected transient NioStatefulSegment fileIO;
    @Nullable
    protected transient ByteBuffer inputBuf;
    private int iterations;
    protected ConversionState cvState;

    // -----------------------------------------

    public GeneralLearnerBaseUDTF() {
        this(true);
    }

    public GeneralLearnerBaseUDTF(boolean enableNewModel) {
        super(enableNewModel);
        this.optimizerOptions = OptimizerOptions.create();
    }

    /** Returns the help text for the {@code -loss} option. */
    @Nonnull
    protected abstract String getLossOptionDescription();

    /** Returns the loss type used when {@code -loss} is not given. */
    @Nonnull
    protected abstract LossType getDefaultLossType();

    /**
     * Validates that the given loss function is acceptable for this learner.
     *
     * @throws UDFArgumentException if the loss function is not supported
     */
    protected abstract void checkLossFunction(@Nonnull LossFunction lossFunction)
            throws UDFArgumentException;

    /**
     * Validates a target value of an input row.
     *
     * @throws UDFArgumentException if the target is not acceptable
     */
    protected abstract void checkTargetValue(float target) throws UDFArgumentException;

    /**
     * Trains the model with a single example.
     *
     * @param features the feature vector; entries may be {@code null} in the first pass
     * @param target the target value
     */
    protected abstract void train(@Nonnull final FeatureValue[] features, final float target);

    /**
     * Prepares the UDTF for processing.
     *
     * Resolves the feature-list and target inspectors, parses options, and creates the model and
     * optimizer. Returns a struct inspector for {@code (feature, weight [, covar])}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            showHelp(
                "_FUNC_ takes two or three arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }
        this.featureListOI = HiveUtils.asListOI(argOIs, 0);
        this.featureType = getFeatureType(featureListOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs, 1);

        processOptions(argOIs);

        this.model = createModel();

        try {
            this.optimizer = createOptimizer(optimizerOptions);
        } catch (Throwable e) {
            throw new UDFArgumentException(e);
        }

        this.count = 0L;
        this.sampled = 0;

        return getReturnOI(getFeatureOutputOI(featureType));
    }

    @Override
    protected Options getOptions() {
        Options opts = super.getOptions();
        opts.addOption("inspect_opts", false, "Inspect Optimizer options");
        opts.addOption("loss", "loss_function", true, getLossOptionDescription());
        opts.addOption("iter", "iterations", true,
            "The maximum number of iterations [default: 10]");
        opts.addOption("iters", "iterations", true,
            "The maximum number of iterations [default: 10]");
        // conversion check
        opts.addOption("disable_cv", "disable_cvtest", false,
            "Whether to disable convergence check [default: OFF]");
        opts.addOption("cv_rate", "convergence_rate", true,
            "Threshold to determine convergence [default: 0.005]");
        OptimizerOptions.setup(opts);
        return opts;
    }

    /**
     * Parses learner options into fields.
     *
     * Sets the loss function, the number of iterations (default 10), and the convergence state
     * (check enabled by default, rate 0.005). Optimizer options are also parsed. With
     * {@code -inspect_opts}, throws an exception that carries the resolved hyperparameters.
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = super.processOptions(argOIs);

        LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType());
        int iterations = 10;
        boolean conversionCheck = true;
        double convergenceRate = 0.005d;

        if (cl != null) {
            if (cl.hasOption("loss_function")) {
                try {
                    lossFunction =
                            LossFunctions.getLossFunction(cl.getOptionValue("loss_function"));
                } catch (Throwable e) {
                    throw new UDFArgumentException(e.getMessage());
                }
            }
            checkLossFunction(lossFunction);

            iterations = Primitives.parseInt(cl.getOptionValue("iterations"), iterations);
            if (iterations < 1) {
                throw new UDFArgumentException(
                    "'-iterations' must be greater than or equals to 1: " + iterations);
            }

            conversionCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate);
        }

        this.lossFunction = lossFunction;
        this.iterations = iterations;
        this.cvState = new ConversionState(conversionCheck, convergenceRate);

        OptimizerOptions.processOptions(cl, optimizerOptions);

        if (cl != null && cl.hasOption("inspect_opts")) {
            Optimizer optimizer = createOptimizer(optimizerOptions);
            Map<String, Object> params = optimizer.getHyperParameters();
            params.put("loss_function", lossFunction.getType().toString());
            params.put("iterations", iterations);
            params.put("disable_cvtest", !conversionCheck);
            params.put("cv_rate", convergenceRate);
            throw new UDFArgumentException(
                String.format("Inspected Optimizer options ...\n%s", params.toString()));
        }

        return cl;
    }

    /** Java type of the elements of the feature-list argument. */
    public enum FeatureType {
        STRING, INT, LONG
    }

    /**
     * Determines the element type of the feature list.
     *
     * @throws UDFArgumentException if the element inspector is not string, int or long
     */
    @Nonnull
    private static FeatureType getFeatureType(@Nonnull ListObjectInspector featureListOI)
            throws UDFArgumentException {
        final ObjectInspector featureOI = featureListOI.getListElementObjectInspector();
        if (featureOI instanceof StringObjectInspector) {
            return FeatureType.STRING;
        } else if (featureOI instanceof IntObjectInspector) {
            return FeatureType.INT;
        } else if (featureOI instanceof LongObjectInspector) {
            return FeatureType.LONG;
        } else {
            throw new UDFArgumentException("Feature object inspector must be one of "
                    + "[StringObjectInspector, IntObjectInspector, LongObjectInspector]: "
                    + featureOI.toString());
        }
    }

    /**
     * Returns the inspector used for the output {@code feature} column.
     *
     * A dense model always emits int features. Otherwise the output type mirrors the input
     * feature type.
     */
    @Nonnull
    protected final ObjectInspector getFeatureOutputOI(@Nonnull final FeatureType featureType)
            throws UDFArgumentException {
        final PrimitiveObjectInspector outputOI;
        if (dense_model) {
            // TODO validation
            outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector; // see DenseModel (long/string is also parsed as int)
        } else {
            switch (featureType) {
                case STRING:
                    outputOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                    break;
                case INT:
                    outputOI = PrimitiveObjectInspectorFactory.javaIntObjectInspector;
                    break;
                case LONG:
                    outputOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
                    break;
                default:
                    throw new IllegalStateException("Unexpected feature type: " + featureType);
            }
        }
        return outputOI;
    }

    /**
     * Builds the output struct inspector.
     *
     * The struct is {@code (feature, weight)}, plus {@code covar} when covariance is used.
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

        fieldNames.add("feature");
        fieldOIs.add(featureOutputOI);
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Processes one input row.
     *
     * Parses features and target, trains on the example, and records it for later iterations.
     * Rows with a null or empty feature list are skipped.
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (is_mini_batch && accumulated == null) {
            this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
        }

        List<?> features = (List<?>) featureListOI.getList(args[0]);
        if (features == null) {
            return; // treat a NULL feature list like an empty one
        }
        FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return;
        }
        float target = PrimitiveObjectInspectorUtils.getFloat(args[1], targetOI);
        checkTargetValue(target);

        count++;
        train(featureVector, target);

        recordTrainSampleToTempFile(featureVector, target);
    }

    /**
     * Serializes a training example so that it can be replayed in later iterations.
     *
     * Nothing is recorded when only one iteration is requested. Examples go into a 2 MB direct
     * buffer, which is spilled to a temporary file when full. Record layout:
     * {@code [int recordBytes][int numFeatures]{[string feature][double value]}*[float target]}.
     * {@code recordBytes} counts everything after itself. {@code null} entries of the feature
     * vector are not serialized.
     *
     * @throws HiveException if the temporary file cannot be created or written, or a single
     *         record does not fit in the buffer
     */
    protected void recordTrainSampleToTempFile(@Nonnull final FeatureValue[] featureVector,
            final float target) throws HiveException {
        if (iterations == 1) {
            return;
        }

        ByteBuffer buf = inputBuf;
        NioStatefulSegment dst = fileIO;

        if (buf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_general_learner", ".sgmt");
                file.deleteOnExit();
            } catch (IOException ioe) {
                throw new UDFArgumentException(ioe);
            } catch (Throwable e) {
                throw new UDFArgumentException(e);
            }
            // checked outside the try block so the exception is not wrapped a second time
            if (!file.canWrite()) {
                throw new UDFArgumentException(
                    "Cannot write a temporary file: " + file.getAbsolutePath());
            }
            logger.info("Record training samples to a file: " + file.getAbsolutePath());
            this.inputBuf = buf = ByteBuffer.allocateDirect(2 * 1024 * 1024); // 2 MB
            this.fileIO = dst = new NioStatefulSegment(file, false);
        }

        int numFeatures = 0;
        int featureVectorBytes = 0;
        for (FeatureValue f : featureVector) {
            if (f == null) {
                continue; // null entries are neither counted nor serialized
            }
            numFeatures++;
            int featureLength = f.getFeatureAsString().length();

            // feature as String (even if it is Text or Integer)
            featureVectorBytes += SizeOf.CHAR * featureLength;

            // NIOUtils.putString() first puts the length of string before string itself
            featureVectorBytes += SizeOf.INT;

            // value
            featureVectorBytes += SizeOf.DOUBLE;
        }

        // feature length, feature 1, feature 2, ..., feature n, target
        int recordBytes = SizeOf.INT + featureVectorBytes + SizeOf.FLOAT;
        int requiredBytes = SizeOf.INT + recordBytes; // need to allocate space for "recordBytes" itself

        int remain = buf.remaining();
        if (remain < requiredBytes) {
            writeBuffer(buf, dst);
        }
        if (requiredBytes > buf.remaining()) {
            throw new HiveException("Buffer size (2MB) for writing training example is not enough: "
                    + NumberUtils.prettySize(requiredBytes));
        }

        buf.putInt(recordBytes);
        // must match the number of features actually written below
        buf.putInt(numFeatures);
        for (FeatureValue f : featureVector) {
            if (f != null) {
                writeFeatureValue(buf, f);
            }
        }
        buf.putFloat(target);
    }

    /** Writes a feature (as a string) followed by its value as a double. */
    private static void writeFeatureValue(@Nonnull final ByteBuffer buf,
            @Nonnull final FeatureValue f) {
        NIOUtils.putString(f.getFeatureAsString(), buf);
        buf.putDouble(f.getValue());
    }

    /**
     * Reads a feature written by {@link #writeFeatureValue(ByteBuffer, FeatureValue)}.
     *
     * The feature string is converted back to the input feature type.
     */
    @Nonnull
    private static FeatureValue readFeatureValue(@Nonnull final ByteBuffer buf,
            @Nonnull final FeatureType featureType) {
        final String featureStr = NIOUtils.getString(buf);
        final Object feature;
        switch (featureType) {
            case STRING:
                feature = featureStr;
                break;
            case INT:
                feature = Integer.valueOf(featureStr);
                break;
            case LONG:
                feature = Long.valueOf(featureStr);
                break;
            default:
                throw new IllegalStateException(
                    "Unexpected feature type " + featureType + " for feature: " + featureStr);
        }
        double value = buf.getDouble();
        return new FeatureValue(feature, value);
    }

    /**
     * Reads one serialized example from {@code buf} and trains on it.
     *
     * The buffer must be positioned just after the example's {@code recordBytes} header.
     */
    private void trainOnSerializedExample(@Nonnull final ByteBuffer buf) {
        final int featureVectorLength = buf.getInt();
        final FeatureValue[] featureVector = new FeatureValue[featureVectorLength];
        for (int j = 0; j < featureVectorLength; j++) {
            featureVector[j] = readFeatureValue(buf, featureType);
        }
        final float target = buf.getFloat();
        train(featureVector, target);
    }

    /**
     * Converts the raw feature list into {@link FeatureValue}s.
     *
     * String features are parsed with {@link FeatureValue#parseFeatureAsString(String)}. Int and
     * long features get a value of 1.
     *
     * @return {@code null} for an empty list; otherwise an array of the same size in which
     *         {@code null} list elements remain {@code null}
     */
    @Nullable
    public final FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
        final int size = features.size();
        if (size == 0) {
            return null;
        }

        final ObjectInspector featureInspector = featureListOI.getListElementObjectInspector();
        final FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; i++) {
            Object f = features.get(i);
            if (f == null) {
                continue;
            }
            final FeatureValue fv;
            if (featureType == FeatureType.STRING) {
                String s = f.toString();
                fv = FeatureValue.parseFeatureAsString(s);
            } else {
                Object k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector,
                    ObjectInspectorCopyOption.JAVA); // should be Integer or Long
                fv = new FeatureValue(k, 1.f);
            }
            featureVector[i] = fv;
        }
        return featureVector;
    }

    /**
     * Writes the buffered bytes to {@code dst}.
     *
     * The buffer is flipped before writing and cleared afterwards, leaving it in write mode.
     */
    private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefulSegment dst)
            throws HiveException {
        srcBuf.flip();
        try {
            dst.write(srcBuf);
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
        srcBuf.clear();
    }

    /** Returns the dot product of the current weights and the feature values; skips nulls. */
    public float predict(@Nonnull final FeatureValue[] features) {
        float score = 0.f;
        for (FeatureValue f : features) {// a += w[i] * x[i]
            if (f == null) {
                continue;
            }
            final Object k = f.getFeature();
            final float v = f.getValueAsFloat();

            float old_w = model.getWeight(k);
            if (old_w != 0.f) {
                score += (old_w * v);
            }
        }
        return score;
    }

    /**
     * Applies one optimization step for an example.
     *
     * Adds the loss to the convergence state and clamps the loss derivative to
     * {@code [MIN_DLOSS, MAX_DLOSS]}. The update is then applied online or accumulated for a
     * mini-batch. A zero derivative skips the weight update.
     */
    protected void update(@Nonnull final FeatureValue[] features, final float target,
            final float predicted) {
        optimizer.proceedStep();

        float loss = lossFunction.loss(predicted, target);
        cvState.incrLoss(loss); // retain cumulative loss to check convergence

        float dloss = lossFunction.dloss(predicted, target);
        if (dloss == 0.f) {
            return;
        }
        if (dloss < MIN_DLOSS) {
            dloss = MIN_DLOSS;
        } else if (dloss > MAX_DLOSS) {
            dloss = MAX_DLOSS;
        }

        if (is_mini_batch) {
            accumulateUpdate(features, loss, dloss);
            if (sampled >= mini_batch_size) {
                batchUpdate();
            }
        } else {
            onlineUpdate(features, loss, dloss);
        }
    }

    /**
     * Computes new weights for the example's features without setting them on the model.
     *
     * The new weights are accumulated per feature for the next {@link #batchUpdate()}.
     * {@code null} features are skipped.
     */
    protected void accumulateUpdate(@Nonnull final FeatureValue[] features, final float loss,
            final float dloss) {
        for (FeatureValue f : features) {
            if (f == null) {
                continue; // consistent with predict(); parseFeatures() may leave null entries
            }
            Object feature = f.getFeature();
            float xi = f.getValueAsFloat();
            float weight = model.getWeight(feature);

            // compute new weight, but still not set to the model
            float gradient = dloss * xi;
            float new_weight = optimizer.update(feature, weight, loss, gradient);

            // (w_i - eta * delta_1) + (w_i - eta * delta_2) + ... + (w_i - eta * delta_M)
            FloatAccumulator acc = accumulated.get(feature);
            if (acc == null) {
                acc = new FloatAccumulator(new_weight);
                accumulated.put(feature, acc);
            } else {
                acc.add(new_weight);
            }
        }
        sampled++;
    }

    /**
     * Applies the accumulated weights to the model and resets the mini-batch state.
     *
     * Features whose resulting weight is zero are deleted from the model.
     */
    protected void batchUpdate() {
        if (accumulated.isEmpty()) {
            this.sampled = 0;
            return;
        }

        for (Map.Entry<Object, FloatAccumulator> e : accumulated.entrySet()) {
            Object feature = e.getKey();
            FloatAccumulator v = e.getValue();
            // presumably the mean of the accumulated weights; see FloatAccumulator
            final float new_weight = v.get(); // w_i - (eta / M) * (delta_1 + delta_2 + ... + delta_M)
            if (new_weight == 0.f) {
                model.delete(feature);
                continue;
            }
            model.setWeight(feature, new_weight);
        }

        accumulated.clear();
        this.sampled = 0;
    }

    /**
     * Updates the model immediately for each feature of the example.
     *
     * {@code null} features are skipped, and features whose new weight is zero are deleted.
     */
    protected void onlineUpdate(@Nonnull final FeatureValue[] features, final float loss,
            final float dloss) {
        for (FeatureValue f : features) {
            if (f == null) {
                continue; // consistent with predict(); parseFeatures() may leave null entries
            }
            Object feature = f.getFeature();
            float xi = f.getValueAsFloat();
            float weight = model.getWeight(feature);
            float gradient = dloss * xi;
            final float new_weight = optimizer.update(feature, weight, loss, gradient);
            if (new_weight == 0.f) {
                model.delete(feature);
                continue;
            }
            model.setWeight(feature, new_weight);
        }
    }

    /**
     * Finishes training and forwards the model rows.
     *
     * Nothing is forwarded when no training example was processed, because the model is then
     * released by {@link #finalizeTraining()}.
     */
    @Override
    public final void close() throws HiveException {
        super.close();
        finalizeTraining();
        if (model != null) { // null when no training example was processed
            forwardModel();
        }
        this.accumulated = null;
        this.model = null;
    }

    /**
     * Flushes pending mini-batch updates and runs the remaining iterations.
     *
     * Releases the model when no training example was processed.
     */
    @VisibleForTesting
    public void finalizeTraining() throws HiveException {
        if (count == 0L) {
            this.model = null;
            return;
        }
        if (is_mini_batch) { // Update model with accumulated delta
            batchUpdate();
        }
        if (iterations > 1) {
            runIterativeTraining(iterations);
        }
    }

    /**
     * Replays the recorded training examples for iterations 2..{@code iterations}.
     *
     * Replay stops early when the convergence state reports convergence. Examples are replayed
     * from the in-memory buffer when nothing was spilled to the temporary file; otherwise the
     * file is replayed. The temporary file is deleted and the buffers are released in all
     * cases.
     */
    protected final void runIterativeTraining(@Nonnegative final int iterations)
            throws HiveException {
        final ByteBuffer buf = this.inputBuf;
        final NioStatefulSegment dst = this.fileIO;
        assert (buf != null);
        assert (dst != null);
        final long numTrainingExamples = count;

        final Reporter reporter = getReporter();
        final Counters.Counter iterCounter = (reporter == null) ? null
                : reporter.getCounter("hivemall.GeneralLearnerBase$Counter", "iteration");

        try {
            if (dst.getPosition() == 0L) {// run iterations w/o temporary file
                if (buf.position() == 0) {
                    return; // no training example
                }
                buf.flip();

                for (int iter = 2; iter <= iterations; iter++) {
                    cvState.next();
                    reportProgress(reporter);
                    setCounterValue(iterCounter, iter);

                    while (buf.remaining() > 0) {
                        int recordBytes = buf.getInt();
                        assert (recordBytes > 0) : recordBytes;
                        trainOnSerializedExample(buf);
                    }
                    buf.rewind();

                    if (is_mini_batch) { // Update model with accumulated delta
                        batchUpdate();
                    }

                    if (cvState.isConverged(numTrainingExamples)) {
                        break;
                    }
                }
                logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples)
                        + " training examples on memory (thus "
                        + NumberUtils.formatNumber(
                            numTrainingExamples * cvState.getCurrentIteration())
                        + " training updates in total) ");
            } else {// read training examples in the temporary file and invoke train for each example
                // Write any training examples still pending in the buffer to the temporary file.
                // The buffer is in write mode here, so position() (not remaining()) is the number
                // of pending bytes; a completely full buffer has remaining() == 0.
                if (buf.position() > 0) {
                    writeBuffer(buf, dst);
                }
                try {
                    dst.flush();
                } catch (IOException e) {
                    throw new HiveException(
                        "Failed to flush a file: " + dst.getFile().getAbsolutePath(), e);
                }
                if (logger.isInfoEnabled()) {
                    File tmpFile = dst.getFile();
                    logger.info("Wrote " + numTrainingExamples
                            + " records to a temporary file for iterative training: "
                            + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile)
                            + ")");
                }

                // run iterations
                for (int iter = 2; iter <= iterations; iter++) {
                    cvState.next();
                    setCounterValue(iterCounter, iter);

                    buf.clear();
                    dst.resetPosition();
                    while (true) {
                        reportProgress(reporter);
                        // TODO prefetch
                        // writes training examples to a buffer in the temporary file
                        final int bytesRead;
                        try {
                            bytesRead = dst.read(buf);
                        } catch (IOException e) {
                            throw new HiveException(
                                "Failed to read a file: " + dst.getFile().getAbsolutePath(), e);
                        }
                        if (bytesRead == 0) { // reached file EOF
                            break;
                        }
                        assert (bytesRead > 0) : bytesRead;

                        // reads training examples from a buffer
                        buf.flip();
                        int remain = buf.remaining();
                        if (remain < SizeOf.INT) {
                            throw new HiveException("Illegal file format was detected");
                        }
                        while (remain >= SizeOf.INT) {
                            int pos = buf.position();
                            int recordBytes = buf.getInt();
                            remain -= SizeOf.INT;

                            if (remain < recordBytes) {
                                // partial record: rewind to its header and read more bytes
                                buf.position(pos);
                                break;
                            }

                            trainOnSerializedExample(buf);

                            remain -= recordBytes;
                        }
                        // keep the unread partial record at the head of the buffer
                        buf.compact();
                    }

                    if (is_mini_batch) { // Update model with accumulated delta
                        batchUpdate();
                    }

                    if (cvState.isConverged(numTrainingExamples)) {
                        break;
                    }
                }
                logger.info("Performed " + cvState.getCurrentIteration() + " iterations of "
                        + NumberUtils.formatNumber(numTrainingExamples)
                        + " training examples on a secondary storage (thus "
                        + NumberUtils.formatNumber(
                            numTrainingExamples * cvState.getCurrentIteration())
                        + " training updates in total)");
            }
        } catch (Throwable e) {
            throw new HiveException("Exception caused in the iterative training", e);
        } finally {
            // delete the temporary file and release resources
            try {
                dst.close(true);
            } catch (IOException e) {
                throw new HiveException(
                    "Failed to close a file: " + dst.getFile().getAbsolutePath(), e);
            }
            this.inputBuf = null;
            this.fileIO = null;
        }
    }

    /**
     * Forwards one row per touched, non-zero model entry.
     *
     * Rows are {@code (feature, weight)}, or {@code (feature, weight, covar)} when covariance is
     * used. With covariance, an entry is skipped only when both weight and covariance are zero.
     * Does nothing when the model has already been released.
     */
    protected void forwardModel() throws HiveException {
        if (model == null) {
            return; // no training example was processed
        }
        int numForwarded = 0;
        if (useCovariance()) {
            final WeightValueWithCovar probe = new WeightValueWithCovar();
            final Object[] forwardMapObj = new Object[3];
            final FloatWritable fv = new FloatWritable();
            final FloatWritable cov = new FloatWritable();
            final IMapIterator<Object, IWeightValue> itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue; // skip outputting untouched weights
                }
                final float v = probe.get();
                final float cv = probe.getCovariance();
                if (v == 0.f && cv == 0.f) {
                    continue;
                }
                fv.set(v);
                cov.set(cv);
                Object k = itor.getKey();
                forwardMapObj[0] = k;
                forwardMapObj[1] = fv;
                forwardMapObj[2] = cov;
                forward(forwardMapObj);
                numForwarded++;
            }
        } else {
            final WeightValue probe = new WeightValue();
            final Object[] forwardMapObj = new Object[2];
            final FloatWritable fv = new FloatWritable();
            final IMapIterator<Object, IWeightValue> itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue; // skip outputting untouched weights
                }
                final float v = probe.get();
                if (v == 0.f) {
                    continue;
                }
                fv.set(v);
                Object k = itor.getKey();
                forwardMapObj[0] = k;
                forwardMapObj[1] = fv;
                forward(forwardMapObj);
                numForwarded++;
            }
        }
        long numMixed = model.getNumMixed();
        logger.info("Trained a prediction model using " + count + " training examples"
                + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /** Returns the cumulative loss tracked for convergence, or NaN before options are processed. */
    @VisibleForTesting
    public double getCumulativeLoss() {
        return (cvState == null) ? Double.NaN : cvState.getCumulativeLoss();
    }
}