package org.sgdtk.struct;

import org.sgdtk.CollectionsManip;

import java.io.InputStream;
import java.io.OutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.IOException;
import java.util.Arrays;

/**
 * CRF implementation of a {@link org.sgdtk.struct.SequentialModel}.
 *
 * Provides the representation of just the CRF model itself, including the weight vector,
 * and the methods for persisting and restoring its weights, along with predicting a y for an x (fv).
 *
 * The model keeps a raw weight array plus a separate scalar {@code wscale}; {@link #mag()} treats the
 * effective weights as {@code wscale * weights}, and {@link #rescale()} folds the scalar back into the array.
 *
 * @author dpressel
 */
public class CRFModel implements SequentialModel
{
    /** Raw weight vector; the effective weights are these values multiplied by {@link #wscale}. */
    private double[] weights;

    /** Scalar applied to every entry of {@link #weights}. */
    private double wscale;

    /** Number of sequence labels (y values) this model distinguishes. */
    private int numLabels;

    /**
     * Copy constructor used by {@link #prototype()}. The weight array is copied so the new model
     * does not share storage with the source model.
     *
     * @param weights weight vector to copy
     * @param wscale scaling
     * @param numLabels number of labels
     */
    private CRFModel(double[] weights, double wscale, int numLabels)
    {
        this.weights = Arrays.copyOf(weights, weights.length);
        this.wscale = wscale;
        this.numLabels = numLabels;
    }

    /**
     * Default constructor.  This is usually only going to be called prior to a {@link #load(java.io.InputStream)} call.
     * Until then the weight vector is {@code null}.
     */
    public CRFModel()
    {
    }

    /**
     * Construct a model prior to training. This just establishes the extent of the weight vector and the number of
     * sequence y values (labels).  Don't use this unless you understand what you are doing.
     *
     * @param wlength This is the weight vector's width
     * @param wscale scaling
     * @param numLabels number of labels
     */
    public CRFModel(int wlength, double wscale, int numLabels)
    {
        // New Java arrays are already zero-filled, so no explicit initialization is needed
        this.weights = new double[wlength];
        this.wscale = wscale;
        this.numLabels = numLabels;
    }

    /**
     * Load from a stream.  This is how you would load a model to use the classifier.
     *
     * The expected layout is the one written by {@link #save(java.io.OutputStream)}: wscale (double),
     * numLabels (long), weight count (long), then that many doubles. The model's state is only replaced
     * once the whole record has been read, so a failed load leaves the previous state intact.
     * The supplied stream is closed when this method returns, whether or not it succeeds.
     *
     * @param inputStream stream to read the model from
     * @throws IOException on read failure, or if a stored count is negative or does not fit in an int
     */
    @Override
    public void load(InputStream inputStream) throws IOException
    {
        try (ObjectInputStream objectInputStream = new ObjectInputStream(inputStream))
        {
            double loadedScale = objectInputStream.readDouble();
            int loadedLabels = toNonNegativeInt(objectInputStream.readLong(), "numLabels");
            int length = toNonNegativeInt(objectInputStream.readLong(), "weight vector length");
            double[] loadedWeights = new double[length];
            for (int i = 0; i < length; ++i)
            {
                loadedWeights[i] = objectInputStream.readDouble();
            }
            // Publish the new state only after everything has been read successfully
            this.wscale = loadedScale;
            this.numLabels = loadedLabels;
            this.weights = loadedWeights;
        }
    }

    /**
     * Narrow a count read from a model stream to an int, rejecting values that cannot be valid.
     *
     * @param value the raw value read from the stream
     * @param what a description of the value, used in the error message
     * @return the value as an int
     * @throws IOException if the value is negative or larger than {@link Integer#MAX_VALUE}
     */
    private static int toNonNegativeInt(long value, String what) throws IOException
    {
        if (value < 0 || value > Integer.MAX_VALUE)
        {
            throw new IOException("Corrupt model stream: invalid " + what + " " + value);
        }
        return (int) value;
    }

    /**
     * Save the weight vector, etc. to a stream.
     *
     * Writes wscale (double), numLabels (long), weight count (long), then each weight (double).
     * The supplied stream is closed when this method returns, whether or not it succeeds.
     *
     * @param outputStream Stream to save to
     * @throws IOException on write failure
     */
    @Override
    public void save(OutputStream outputStream) throws IOException
    {
        try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(outputStream))
        {
            objectOutputStream.writeDouble(wscale);
            objectOutputStream.writeLong(numLabels);
            objectOutputStream.writeLong(weights.length);
            for (double w : weights)
            {
                objectOutputStream.writeDouble(w);
            }
        }
    }

    /**
     * Use the Viterbi algorithm to find the best path.  Note that this method does not hydrate the results; that
     * would violate the intended separation of concerns.  However, hydrating the results is shown in demonstration
     * code.
     * @see org.sgdtk.exec.EvalStruct#evalOneMaybePrint(SequentialLearner, SequentialModel, FeatureVectorSequence, org.sgdtk.FeatureNameEncoder, org.sgdtk.Metrics)
     * @param sequence The sequence to predict.
     * @return the best path found by the scorer
     */
    @Override
    public Path predict(FeatureVectorSequence sequence)
    {
        Scorer scorer = new Scorer(this, sequence);
        return scorer.viterbi();
    }

    /**
     * Create a deep copy of this exact model.
     *
     * @return A clone with its own copy of the weight array
     */
    @Override
    public SequentialModel prototype()
    {
        return new CRFModel(weights, wscale, numLabels);
    }

    /**
     * Get the raw weights. This returns the live backing array, not a copy, so changes made through it
     * are visible to this model. The values are not multiplied by {@link #getWscale()}.
     *
     * @return weights
     */
    public double[] getWeights()
    {
        return weights;
    }

    /**
     * Get wscale.
     *
     * @return wscale
     */
    public double getWscale()
    {
        return wscale;
    }

    /**
     * Get w' w scaled, i.e. {@code wscale^2 * dot(weights, weights)}. This is the squared magnitude of the
     * effective weight vector, assuming {@code CollectionsManip.dot} is the standard dot product.
     *
     * @return mag
     */
    public double mag()
    {
        double dotProd = CollectionsManip.dot(weights, weights);
        return dotProd * wscale * wscale;
    }

    /**
     * Multiply the scale factor into every weight and reset wscale to 1.  This is a no-op when wscale is already 1.
     */
    public void rescale()
    {
        if (wscale != 1.0)
        {
            for (int i = 0; i < weights.length; ++i)
            {
                weights[i] *= wscale;
            }
            wscale = 1;
        }
    }

    /**
     * Set the wscale.
     *
     * @param wscale the scaling
     */
    public void setWscale(double wscale)
    {
        this.wscale = wscale;
    }

    /**
     * Get the number of labels (or classes) in this model.
     *
     * @return the number of labels
     */
    public int getNumLabels()
    {
        return numLabels;
    }
}