/**************************************************************************
 * Developed by Language Technologies Institute, Carnegie Mellon University
 * Written by Richard Wang (rcwang#cs,cmu,edu)
 **************************************************************************/
package com.rcwang.seal.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import org.apache.log4j.Logger;

import com.rcwang.seal.util.StringFactory.SID;

public class VotedPerceptron {
  
  public static class WeightedPerceptron {
    public double weight = 0;
    public double[] perceptron;
    public WeightedPerceptron(int size) { perceptron = new double[size]; }
    public WeightedPerceptron(double[] perceptron) { this.perceptron = perceptron; }
  }

  public static int MIN_PERCEPTRON_WEIGHT = 1;
  public static boolean FAIL_EARLY = false;
  public static Logger log = Logger.getLogger(VotedPerceptron.class);
  
  private List<WeightedPerceptron> weightedPerceptrons;
  private Map<SID, Double> trueLabels;
  private SparseMatrix matrix;
  private boolean moreInstancesAdded;
  private int numEpochs;

  public static void main(String[] args) {
    VotedPerceptron vp = new VotedPerceptron(10);
    
    // training
    double[][] train = new double[][] {
        {1, 0, 1, 1, 0, 1},
        {1, 0, 0, 0, 1, 0},
        {1, 1, 1, 1, 1, 1},
        {0, 1, 0, 1, 0, 1},
        {0, 0, 0, 0, 0, 0},
    };
    double[] labels = new double[] {
        1, 1, 1, -1, -1
    };
    for (int i = 0; i < train.length; i++)
      for (int j = 0; j < train[i].length; j++)
        vp.addExample(String.valueOf(i), String.valueOf(j), train[i][j], labels[i]);
    
    // testing
    double[][] test = new double[][] {
        {1, 0, 0, 0, 0, 0},
        {0, 1, 1, 1, 1, 1},
        {0, 0, 1, 0, 0, 0},
        {0, 0, 0, 0, 0, 0},
    };
    SparseVector instance = new SparseVector();
    for (int i = 0; i < test.length; i++) {
      instance.clear();
      for (int j = 0; j < test[i].length; j++)
        instance.put(String.valueOf(j), test[i][j]);
      double votes = vp.vote(instance);
      log.info("Prediction for " + instance + " is " + votes + " (" + (sign(votes) > 0 ? "TRUE" : "FALSE") + ")");
    }
  }
  
  public static int sign(double d) {
//    return (d >= 0) ? 1 : -1;
    return (d > 0) ? 1 : (d < 0 ? -1 : 0);
  }
  
  public VotedPerceptron(int numEpochs) {
    this.numEpochs = numEpochs;
    moreInstancesAdded = true;
    matrix = new SparseMatrix();
    trueLabels = new HashMap<SID, Double>();
    weightedPerceptrons = new ArrayList<WeightedPerceptron>();
  }
  
  public void addExample(SID instanceID, SID featureID, double weight, double label) {
    matrix.add(featureID, instanceID, weight);
    trueLabels.put(instanceID, label);
    moreInstancesAdded = true;
  }
  
  public void addExample(SID instanceID, SparseVector featureVector, double label) {
    for (Entry<SID, Cell> entry : featureVector) {
      SID feautreID = entry.getKey();
      double value = entry.getValue().value;
      addExample(instanceID, feautreID, value, label);
    }
  }
  
  public void addExample(String instanceName, String featureName, double weight, double label) {
    SID instanceID = StringFactory.toID(instanceName);
    SID featureID = StringFactory.toID(featureName);
    addExample(instanceID, featureID, weight, label);
  }
  
  public double dotProduct(double[] perceptron, SparseVector instance) {
    int sum = 0, index = 0;
    for (SID featureID : matrix.getColumnIDs()) {
      Cell cell = instance.get(featureID);
      double x = (cell == null) ? 0 : cell.value;
      sum += perceptron[index++] * x;
    }
    return sum;
  }
  
  public int getNumExamples() {
    return matrix.getRowIDs().size();
  }
  
  public int getNumFeatures() {
    return matrix.getColumnIDs().size();
  }
  
  public void reset() {
    moreInstancesAdded = true;
    matrix.clear();
    trueLabels.clear();
    weightedPerceptrons.clear();
  }
  
  public double vote(SparseVector instance) {
    if (moreInstancesAdded) {
      batchTrain();
      moreInstancesAdded = false;
    }
    double votes = 0;
    for (WeightedPerceptron weightedPerceptron : weightedPerceptrons)
      votes += weightedPerceptron.weight * sign(dotProduct(weightedPerceptron.perceptron, instance));
    return votes;
  }
  
  private void batchTrain() {
    log.info("Training on " + getNumExamples() + " examples...");
    long startTime = System.currentTimeMillis();
    WeightedPerceptron weightedPerceptron = new WeightedPerceptron(matrix.getNumColumns());
    weightedPerceptrons.add(weightedPerceptron);
    
    for (int i = 0; i < numEpochs; i++) {
      String epochID = "[" + (i+1) + "/" + numEpochs + "]";
      log.info(epochID + " Training (constructed " + weightedPerceptrons.size() + " perceptrons so far)...");
      
      for (SparseVector instance : matrix.getRows()) {
        int binaryTrueLabel = sign(trueLabels.get(instance.id));
        int binaryPredictedLabel = sign(dotProduct(weightedPerceptron.perceptron, instance));
//        log.info(epochID + instance.id + " Instance: " + instance + ", Prediction: " + binaryPredictedLabel + ", True: " + binaryTrueLabel);
        
        if (FAIL_EARLY && weightedPerceptron.weight == matrix.getNumRows()) {
          // current perceptron has survived through an epoch; so terminate training process
          i = numEpochs;
          break;
        } else if (binaryTrueLabel != binaryPredictedLabel) {
          // incorrect prediction; so update and create a new perceptron
          weightedPerceptron = update(weightedPerceptron, instance, binaryTrueLabel);
//          log.info(epochID + instance.id + " New Perceptrion: " + weightedPerceptron + ", Weight: " + weightedPerceptron.weight);
        } else {
          // correct prediction; so increment current perceptron's weight
          weightedPerceptron.weight++;
        }
      }
    }
    Helper.printElapsedTime(startTime);
    Helper.printMemoryUsed();
  }
  
  private WeightedPerceptron update(WeightedPerceptron weightedPerceptron, SparseVector instance, double trueLabel) {
    boolean makeNewPerceptron = (weightedPerceptron.weight >= MIN_PERCEPTRON_WEIGHT);
    double[] oldPerceptron = weightedPerceptron.perceptron;
    double[] newPerceptron = makeNewPerceptron ? new double[oldPerceptron.length] : oldPerceptron;

    int index = 0;
    for (SID featureID : matrix.getColumnIDs()) {
      double v = oldPerceptron[index];
      Cell cell = instance.get(featureID);
      double x = (cell == null) ? 0 : cell.value;
      newPerceptron[index++] = v + trueLabel * x;
    }
    
    if (makeNewPerceptron) {
      weightedPerceptron = new WeightedPerceptron(newPerceptron);
      weightedPerceptrons.add(weightedPerceptron);
    }
    weightedPerceptron.weight = 1;
    return weightedPerceptron;
  }
}