package edu.neu.ccs.pyramid.classification.logistic_regression;

import edu.neu.ccs.pyramid.dataset.SerializableVector;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorView;

import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Weights of a multi-class linear model, stored in one flat vector.
 *
 * <p>The vector has {@code (numFeatures + 1) * numClasses} entries. Class {@code k}
 * owns the contiguous segment starting at {@code (numFeatures + 1) * k}; the first
 * entry of a segment is the bias of that class and the remaining {@code numFeatures}
 * entries are its feature weights.
 *
 * <p>Created by chengli on 12/7/14.
 */
public class Weights implements Serializable {
    private static final long serialVersionUID = 2L;
    private int numClasses;
    private int numFeatures;
    /**
     * vector is not serializable, hence transient; it is written and restored
     * manually by {@link #writeObject} and {@link #readObject}.
     */
    private transient Vector weightVector;


    /**
     * Creates weights that are either all zero or pseudo-randomly initialized.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features (bias excluded)
     * @param random if true, every entry (biases included) is set to
     *               {@code nextDouble() - 0.5} drawn from a generator with the
     *               fixed seed 0, so the result is reproducible; if false, all
     *               entries are zero
     */
    public Weights(int numClasses, int numFeatures, boolean random) {
        this(numClasses, numFeatures);
        if (random) {
            // fixed seed: repeated constructions yield the same initial weights
            Random randomGenerator = new Random(0L);
            for (int i = 0; i < weightVector.size(); i++) {
                weightVector.set(i, randomGenerator.nextDouble() - 0.5);
            }
        }
    }

    /**
     * Creates all-zero weights.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features (bias excluded)
     */
    public Weights(int numClasses, int numFeatures) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        this.weightVector = new DenseVector((numFeatures + 1) * numClasses);
    }

    /**
     * Wraps an existing vector. The vector is stored by reference, not copied.
     *
     * @param numClasses number of classes
     * @param numFeatures number of features (bias excluded)
     * @param weightVector vector of size {@code (numFeatures + 1) * numClasses}
     * @throws IllegalArgumentException if the vector has a different size
     */
    public Weights(int numClasses, int numFeatures, Vector weightVector) {
        this.numClasses = numClasses;
        this.numFeatures = numFeatures;
        if (weightVector.size() != (numFeatures + 1) * numClasses) {
            throw new IllegalArgumentException("weightVector.size()!=(numFeatures + 1)*numClasses");
        }
        this.weightVector = weightVector;
    }

    /**
     * Replaces the underlying vector. The vector is stored by reference, not copied.
     *
     * @param weightVector vector of size {@code (numFeatures + 1) * numClasses}
     * @throws IllegalArgumentException if the vector has a different size
     */
    public void setWeightVector(Vector weightVector) {
        if (weightVector.size() != (numFeatures + 1) * numClasses) {
            throw new IllegalArgumentException("weightVector.size()!=(numFeatures + 1)*numClasses");
        }
        this.weightVector = weightVector;
    }

    /**
     * @return a new instance with the same dimensions whose vector is a dense
     * copy of this instance's vector
     */
    public Weights deepCopy() {
        // build the copy directly instead of allocating a zero vector that is thrown away
        return new Weights(this.numClasses, this.numFeatures, new DenseVector(this.weightVector));
    }

    /**
     * @param parameterIndex position in the flat weight vector
     * @return index of the class whose segment contains that position
     */
    public int getClassIndex(int parameterIndex) {
        return parameterIndex / (numFeatures + 1);
    }

    /**
     *
     * @param parameterIndex position in the flat weight vector
     * @return feature index
     * -1 means bias
     */
    public int getFeatureIndex(int parameterIndex) {
        return parameterIndex - getClassIndex(parameterIndex) * (numFeatures + 1) - 1;
    }

    /**
     * @return positions of the bias entries in the flat weight vector, one per
     * class, in class order
     */
    public List<Integer> getAllBiasPositions() {
        List<Integer> list = new ArrayList<>();
        for (int k = 0; k < numClasses; k++) {
            list.add(segmentStart(k));
        }
        return list;
    }

    /**
     *
     * @return weights for all classes; this is the internal vector itself, not a copy
     */
    public Vector getAllWeights() {
        return weightVector;
    }

    /**
     * @return size of the flat weight vector
     */
    public int totalSize() {
        return weightVector.size();
    }

    /**
     * truncate the weights below the threshold to 0;
     * a weight is zeroed when its absolute value is strictly smaller than the
     * threshold. Biases are left untouched.
     * @param threshold magnitude below which a weight is set to 0
     */
    void truncateByThreshold(double threshold) {
        for (int k = 0; k < numClasses; k++) {
            // the view writes through to weightVector
            Vector vector = getWeightsWithoutBiasForClass(k);
            for (int d = 0; d < vector.size(); d++) {
                if (Math.abs(vector.get(d)) < threshold) {
                    vector.set(d, 0);
                }
            }
        }
    }

    /**
     *
     * @param k class index
     * @return weights for class k, including bias at the beginning, as a view
     * over the underlying vector
     * @throws IllegalArgumentException if k is negative or not smaller than the number of classes
     */
    public Vector getWeightsForClass(int k) {
        checkClassIndex(k);
        return new VectorView(this.weightVector, segmentStart(k), this.numFeatures + 1);
    }

    /**
     *
     * @param k class index
     * @return weights for class k, no bias, as a view over the underlying vector
     * @throws IllegalArgumentException if k is negative or not smaller than the number of classes
     */
    public Vector getWeightsWithoutBiasForClass(int k) {
        checkClassIndex(k);
        // skip the bias stored at the beginning of the segment
        return new VectorView(this.weightVector, segmentStart(k) + 1, this.numFeatures);
    }

    /**
     *
     * @param k class index
     * @return bias
     * @throws IllegalArgumentException if k is negative or not smaller than the number of classes
     */
    public double getBiasForClass(int k) {
        checkClassIndex(k);
        return this.weightVector.get(segmentStart(k));
    }

    /**
     * Sets the bias of one class.
     *
     * @param bias new bias value
     * @param k class index
     * @throws IllegalArgumentException if k is negative or not smaller than the number of classes
     */
    public void setBiasForClass(double bias, int k) {
        checkClassIndex(k);
        this.weightVector.set(segmentStart(k), bias);
    }

    /** Rejects class indices outside {@code [0, numClasses)}. */
    private void checkClassIndex(int k) {
        if (k < 0 || k >= numClasses) {
            throw new IllegalArgumentException("out of bound");
        }
    }

    /** Position in the flat vector where the segment (bias first) of class k begins. */
    private int segmentStart(int k) {
        return (this.numFeatures + 1) * k;
    }

    /**
     * Writes the dimensions followed by the non-zero entries as two parallel
     * arrays (indices, values).
     */
    private void writeObject(java.io.ObjectOutputStream out)
            throws IOException {
        out.writeInt(numClasses);
        out.writeInt(numFeatures);
        // Scan by index rather than pairing getNumNonZeroElements() with nonZeroes():
        // setWeightVector accepts any Vector implementation, and if the two ever
        // disagreed the arrays would be overrun or padded with bogus (0, 0.0) pairs.
        int size = weightVector.size();
        int numNonZeros = 0;
        for (int i = 0; i < size; i++) {
            if (weightVector.get(i) != 0.0) {
                numNonZeros += 1;
            }
        }
        int[] indices = new int[numNonZeros];
        double[] values = new double[numNonZeros];
        int next = 0;
        for (int i = 0; i < size; i++) {
            double v = weightVector.get(i);
            if (v != 0.0) {
                indices[next] = i;
                values[next] = v;
                next += 1;
            }
        }
        out.writeObject(indices);
        out.writeObject(values);

    }

    /**
     * Restores the state written by {@link #writeObject} into a fresh dense vector.
     *
     * @throws InvalidObjectException if the stream content is inconsistent
     */
    private void readObject(java.io.ObjectInputStream in)
            throws IOException, ClassNotFoundException {
        numClasses = in.readInt();
        numFeatures = in.readInt();
        int[] indices = (int[]) in.readObject();
        double[] values = (double[]) in.readObject();
        if (numClasses < 0 || numFeatures < 0) {
            throw new InvalidObjectException("negative dimensions: numClasses=" + numClasses
                    + ", numFeatures=" + numFeatures);
        }
        if (indices == null || values == null || indices.length != values.length) {
            throw new InvalidObjectException("indices and values must be non-null and of equal length");
        }
        weightVector = new DenseVector((numFeatures + 1) * numClasses);
        int size = weightVector.size();
        for (int i = 0; i < indices.length; i++) {
            if (indices[i] < 0 || indices[i] >= size) {
                throw new InvalidObjectException("weight index " + indices[i]
                        + " outside [0, " + size + ")");
            }
            weightVector.set(indices[i], values[i]);
        }
    }

    /**
     * Writes this object to the file with Java serialization, creating missing
     * parent directories first.
     *
     * @param file destination file
     * @throws Exception if the file cannot be written
     */
    void serialize(File file) throws Exception {
        File parent = file.getParentFile();
        // getParentFile() is null for a path without a parent component, e.g. "weights.ser"
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        try (
                FileOutputStream fileOutputStream = new FileOutputStream(file);
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(bufferedOutputStream);
        ) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Reads an instance previously written by {@link #serialize(File)}.
     * NOTE(review): Java deserialization should only be applied to trusted files.
     *
     * @param file source file
     * @return the deserialized weights
     * @throws Exception if the file cannot be read or does not contain a Weights object
     */
    public static Weights deserialize(File file) throws Exception {
        try (
                FileInputStream fileInputStream = new FileInputStream(file);
                BufferedInputStream bufferedInputStream = new BufferedInputStream(fileInputStream);
                ObjectInputStream objectInputStream = new ObjectInputStream(bufferedInputStream);
        ) {
            return (Weights) objectInputStream.readObject();
        }
    }

    /**
     * @return a multi-line listing of the bias and feature weights of every class
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("Weights{");
        for (int k = 0; k < numClasses; k++) {
            sb.append("for class ").append(k).append(":").append("\n");
            sb.append("bias = ").append(getBiasForClass(k)).append(",");
            sb.append("weights = ").append(getWeightsWithoutBiasForClass(k)).append("\n");
        }
        sb.append('}');
        return sb.toString();
    }
}