/*******************************************************************************
 * Copyright (c) 2011 Dipanjan Das 
 * Language Technologies Institute, 
 * Carnegie Mellon University, 
 * All Rights Reserved.
 * 
 * LogLogisticRegressionModel.java is part of SEMAFOR 2.0.
 * 
 * SEMAFOR 2.0 is free software: you can redistribute it and/or modify  it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation, either version 3 of the License, or 
 * (at your option) any later version.
 * 
 * SEMAFOR 2.0 is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU General Public License for more details. 
 * 
 * You should have received a copy of the GNU General Public License along
 * with SEMAFOR 2.0.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package edu.cmu.cs.lti.ark.fn.optimization;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Date;
import java.util.ArrayList;
import java.util.Map;
import java.util.Random;

import edu.cmu.cs.lti.ark.fn.optimization.LDouble.IdentityElement;

import riso.numerical.LBFGS;

import gnu.trove.THashMap;
import gnu.trove.TObjectDoubleHashMap;
import gnu.trove.TIntArrayList;
import gnu.trove.TDoubleArrayList;

/**
 * Binary logistic regression model built from {@link LogFormula} trees on top of
 * {@link LogModel}.
 * <p>
 * Each data line holds an integer label followed by real-valued features; a constant
 * bias feature of 1.0 is appended to every example. The label {@code 1} is the positive
 * class and every other label is treated as negative, both when building the training
 * objective and when measuring accuracy. Parameters can be fit with LBFGS
 * ({@link #runCustomLBFGS(String)}) or with stochastic gradient updates
 * ({@link #runBatchSGA(String)}, {@link #runTotallyRandomSGA(String)}).
 * <p>
 * The original header describes the intended use as deciding whether two sentences from
 * Wikipedia are parallel or not; nothing in this class is specific to that task.
 * 
 * @author Kevin Gimpel
 * @date 8/17/2007
 */
public class LogLogisticRegressionModel extends LogModel{

	/** Feature vectors (with trailing bias feature) and labels of the training set. */
	ArrayList<TDoubleArrayList> m_trainingData;
	TIntArrayList m_trainingLabels;
	/** Feature vectors (with trailing bias feature) and labels of the test set. */
	ArrayList<TDoubleArrayList> m_testData;
	TIntArrayList m_testLabels;
	/** Feature vectors (with trailing bias feature) and labels of the development set. */
	ArrayList<TDoubleArrayList> m_devData;
	TIntArrayList m_devLabels;
	/** Cursor used by {@link #getNextFormula()}: index of the next training example to hand out. */
	int m_currentTrainingExample = 0;
	/** Number of parameters registered so far (features plus bias). */
	int numParams = 0;
	/** Regularization strength used by {@link #getRegularizationTerm()}. */
	double lambda = 0.001;
	/** Cache of one lazy-lookup formula per parameter id. */
	Map<Integer, LogFormula> mLookupChart;
	/** When true, the regularization term is added to every per-example formula. */
	boolean mReg = false;

	/*
	 * for LBFGS
	 */
	protected int m_max_its = 2000;
	//not sure how this parameter comes into play
	protected double m_eps = 1.0e-5;
	protected double xtol = 1.0e-10; //estimate of machine precision.  get this right
	//number of corrections, between 3 and 7
	//a higher number means more computation and time, but more accuracy, i guess
	protected int m_num_corrections = 3; 
	protected boolean m_debug = true;
	
	/**
	 * Trains a model with {@link #runTotallyRandomSGA(String)}.
	 * 
	 * @param args optional: {@code args[0]} is the training data file and {@code args[1]}
	 *             the file the parameters are written to; when absent, the original
	 *             hard-coded paths are used
	 */
	public static void main(String[] args)
	{
		String trainFile = args.length > 0
				? args[0]
				: "/Users/dipanjand/work/summer2009/FramenetParsing/BPData/traindata.txt";
		String paramFile = args.length > 1
				? args[1]
				: "/Users/dipanjand/work/summer2009/FramenetParsing/BPData/testModelSGA.txt";
		LogLogisticRegressionModel m = new LogLogisticRegressionModel(trainFile,"train");
		try {
			// Alternatives: m.runCustomLBFGS(paramFile) or m.runBatchSGA(paramFile).
			m.runTotallyRandomSGA(paramFile);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}	
	
	/**
	 * Prints and returns the accuracy of the current parameters on the training set.
	 * 
	 * @return fraction of training examples classified correctly (NaN if the set is empty)
	 */
	protected double classify()
	{
		return reportAccuracy("Train", m_trainingData, m_trainingLabels);
	}	
	

	/**
	 * Prints and returns the accuracy of the current parameters on the development set.
	 * 
	 * @return fraction of development examples classified correctly (NaN if the set is empty)
	 */
	protected double classifyDev()
	{
		return reportAccuracy("Dev", m_devData, m_devLabels);
	}	


	/**
	 * Prints and returns the accuracy of the current parameters on the test set.
	 * 
	 * @return fraction of test examples classified correctly (NaN if the set is empty)
	 */
	protected double classifyTest()
	{
		return reportAccuracy("Test", m_testData, m_testLabels);
	}	

	/**
	 * Loads parameters from a file written by {@link #saveParameters(String)}, prints the
	 * predicted positive-class value of every test example and the resulting accuracy.
	 * 
	 * @param outputFile parameter file to load before classifying
	 * @return accuracy on the test set (the original always returned 0 here, discarding
	 *         the value it had just computed)
	 */
	public double classifyRavine(String outputFile)
	{
		setParametersWhileTest(outputFile);
		int correct = 0;
		for (int j = 0; j < m_testData.size(); j++) {
			double pos = positiveProbability(m_testData.get(j));
			System.out.println(pos);
			if (isCorrectPrediction(pos, m_testLabels.get(j))) {
				correct++;
			}
		}
		double acc = (double)correct/m_testData.size();
		System.out.println("Accuracy="+acc);
		return acc;
	}	
	
	/**
	 * Counts correct predictions over one data set, prints a summary line and returns the accuracy.
	 */
	private double reportAccuracy(String tag, ArrayList<TDoubleArrayList> data, TIntArrayList labels)
	{
		int total = data.size();
		int numCorrect = 0;
		for (int j = 0; j < total; j++) {
			// classify using current weights
			if (isCorrectPrediction(positiveProbability(data.get(j)), labels.get(j))) {
				numCorrect++;
			}
		}
		double acc = ((double)numCorrect) / ((double)total);
		System.out.println(tag + ": " + numCorrect + " / " + total + " = " + acc);
		return acc;
	}

	/**
	 * Applies the same two-step conversion the original classifiers used: the value from
	 * {@link #generateTestVal(TDoubleArrayList)} is wrapped in an LDouble and exponentiated.
	 */
	private double positiveProbability(TDoubleArrayList datum)
	{
		return new LDouble(generateTestVal(datum)).exponentiate();
	}

	/**
	 * Decision rule shared by all classify methods: predict positive when the value is at
	 * least 0.5. Label 1 is positive; any other label is negative, mirroring the branch in
	 * generateFormulaForTrainingExample. A NaN value is never counted as correct.
	 */
	private static boolean isCorrectPrediction(double pos, int label)
	{
		return label == 1 ? pos >= 0.5 : pos < 0.5;
	}
	

	/**
	 * Returns the formula of the next training example and advances the cursor, or
	 * {@code null} once every training example has been handed out. The cursor is
	 * rewound by assigning 0 to {@code m_currentTrainingExample}.
	 */
	protected LogFormula getNextFormula() {
		// Check before advancing so that the last example (index size-1) is also returned.
		if(m_currentTrainingExample>=m_trainingData.size())
			return null;
		int index = m_currentTrainingExample++;
		return generateFormulaForTrainingExample(m_trainingData.get(index), m_trainingLabels.get(index));
	}

	/**
	 * Builds the formula of the training example at {@code index}.
	 */
	protected LogFormula getFormula(int index) {
		return generateFormulaForTrainingExample(m_trainingData.get(index), m_trainingLabels.get(index));		
	}

	/**
	 * @return number of loaded training examples
	 */
	public int getNumTrainingExamples() {
		return m_trainingData.size();
	}

	/**
	 * Builds the per-example objective. With s = sum_i(param_i * x_i):
	 * label 1 gives s - log(1 + exp(s)); any other label gives -log(1 + exp(s)).
	 * When {@code mReg} is set, the regularization term is added on top.
	 */
	private LogFormula generateFormulaForTrainingExample(TDoubleArrayList datum, int label)
	{
		LogFormula[] parts = buildScoreAndLogPartition(datum);
		LogFormula score = parts[0];
		LogFormula logPartition = parts[1];

		LogFormula logLikelihood;
		if (label == 1) {
			logLikelihood = buildPositiveLogProb(score, logPartition);
		} else {
			logLikelihood = getFormulaObject(LogFormula.Op.TIMES);
			logLikelihood.add_arg(getFormulaObject(LDouble.convertToLogDomain(-1.0)));
			logLikelihood.add_arg(logPartition);
		}
		if (!mReg) {
			return logLikelihood;
		}
		LogFormula regularized = getFormulaObject(LogFormula.Op.PLUS);
		LogFormula regTerm = getRegularizationTerm();
		regularized.add_arg(logLikelihood);
		regularized.add_arg(regTerm);
		return regularized;
	}
	
	/**
	 * Evaluates s - log(1 + exp(s)) for one example under the current parameters and
	 * returns {@code exponentiate()} of the evaluated LDouble. Callers in this class wrap
	 * the result in a new LDouble and exponentiate again before thresholding at 0.5.
	 */
	private double generateTestVal(TDoubleArrayList datum)
	{
		LogFormula[] parts = buildScoreAndLogPartition(datum);
		return buildPositiveLogProb(parts[0], parts[1]).evaluate(this).exponentiate();
	}

	/**
	 * Shared front half of the per-example formula. Resets the formula-pool cursors and
	 * builds, in the same allocation order as the original code, the weighted feature sum
	 * s and log(1 + exp(s)).
	 * 
	 * @return two-element array: {@code [0]} is s, {@code [1]} is log(1 + exp(s))
	 */
	private LogFormula[] buildScoreAndLogPartition(TDoubleArrayList datum)
	{
		m_current = 0;
		m_llcurrent = 0;	
		LogFormula expOfScore = getFormulaObject(LogFormula.Op.EXP);
		LogFormula score = getFormulaObject(LogFormula.Op.PLUS);
		for (int i = 0; i < datum.size(); i++) {
			String paramName = "param_"+i;
			LogFormula weightedFeature = getFormulaObject(LogFormula.Op.TIMES);
			int paramId = A.getInt(paramName);
			weightedFeature.add_arg(getLazyLookupFormulaObjectCustom(paramId,paramName));
			weightedFeature.add_arg(getFormulaObject(LDouble.convertToLogDomain(datum.get(i))));
			score.add_arg(weightedFeature);
		}
		expOfScore.add_arg(score);
		LogFormula logPartition = getFormulaObject(LogFormula.Op.LOG); 
		LogFormula onePlusExp = getFormulaObject(LogFormula.Op.PLUS);
		onePlusExp.add_arg(getFormulaObject(IdentityElement.TIMES_IDENTITY));
		onePlusExp.add_arg(expOfScore);
		logPartition.add_arg(onePlusExp);
		return new LogFormula[] { score, logPartition };
	}

	/**
	 * Builds s + (-1.0 * log(1 + exp(s))) from the two parts produced by
	 * buildScoreAndLogPartition.
	 */
	private LogFormula buildPositiveLogProb(LogFormula score, LogFormula logPartition)
	{
		LogFormula logProb = getFormulaObject(LogFormula.Op.PLUS);
		LogFormula negLogPartition = getFormulaObject(LogFormula.Op.TIMES);
		negLogPartition.add_arg(getFormulaObject(LDouble.convertToLogDomain(-1.0)));			
		negLogPartition.add_arg(logPartition);
		logProb.add_arg(score);
		logProb.add_arg(negLogPartition);
		return logProb;
	}


	/**
	 * Builds the regularization term (* -1.0 lambda (w . w)) over all registered parameters.
	 * NOTE(review): the original comments said -0.5, but the constant in the code is -1.0;
	 * the code is left unchanged and the comments now match it.
	 */
	private LogFormula getRegularizationTerm() {
		LogFormula ret = getFormulaObject(LogFormula.Op.TIMES);

		// -1.0
		LogFormula term1 = getFormulaObject(LDouble.convertToLogDomain(-1.0));

		// lambda
		LogFormula term2 = getFormulaObject(LDouble.convertToLogDomain(lambda));

		// w . w
		LogFormula featweightsum = getFormulaObject(LogFormula.Op.PLUS);
		for (int i = 0; i < numParams; i++) {
			String paramName = "param_"+i;
			LogFormula featweight = getFormulaObject(LogFormula.Op.TIMES);
			int paramId = A.getInt(paramName);
			LogFormula formula = getLazyLookupFormulaObjectCustom(paramId,paramName);
			featweight.add_arg(formula);
			featweight.add_arg(formula);
			featweightsum.add_arg(featweight);
		}
		ret.add_arg(term1);
		ret.add_arg(term2);
		ret.add_arg(featweightsum);
		return ret;
	}


	/**
	 * Returns the cached lazy-lookup formula for a parameter id, creating and caching it
	 * on first use.
	 * 
	 * @param ind  parameter id
	 * @param name parameter name passed to the LazyLookupLogFormula constructor on creation
	 */
	public LogFormula getLazyLookupFormulaObjectCustom(int ind, String name) {
		LogFormula f = checkLookupChart(ind);
		if(f == null)
			f = addToLookUpChart(ind,name);
		return f;
	}
	
	/**
	 * @return the cached formula for {@code ind}, or {@code null} if none has been created yet
	 */
	public LogFormula checkLookupChart(Integer ind)
	{
		return mLookupChart.get(ind);
	}

	/**
	 * Creates a new lazy-lookup formula for the parameter, stores it in the cache
	 * (replacing any previous entry) and returns it.
	 */
	public LogFormula addToLookUpChart(int ind, String name)
	{
		LogFormula f = new LazyLookupLogFormula(ind, name);
		mLookupChart.put(ind, f);
		return f;
	}	

	/***********************************************************************************************/
	/***********************************************************************************************/
	/***************************** Constructors and Initialization Code ****************************/
	/***********************************************************************************************/
	/***********************************************************************************************/
	/**
	 * Creates a model and loads training, development and test data.
	 * 
	 * @param xfile     training data file
	 * @param xdevfile  development data file
	 * @param xtestfile test data file
	 */
	public LogLogisticRegressionModel(String xfile, String xdevfile, String xtestfile) {
		initializeParameterIndexes();
		loadTrainingData(xfile, m_trainingData, m_trainingLabels);
		loadTrainingData(xtestfile, m_testData, m_testLabels);
		loadTrainingData(xdevfile, m_devData, m_devLabels);
	}

	/**
	 * Creates a model and loads a single data file.
	 * 
	 * @param xtestfile   data file to load
	 * @param testOrTrain {@code "test"} loads the file as test data; any other value
	 *                    loads it as training data
	 */
	public LogLogisticRegressionModel(String xtestfile, String testOrTrain) {
		initializeParameterIndexes();
		if(testOrTrain.equals("test"))
			loadTrainingData(xtestfile, m_testData, m_testLabels);
		else
			loadTrainingData(xtestfile, m_trainingData, m_trainingLabels);
	}
	
	
	/**
	 * Prints the model name, the first file name and the number of alphabet entries.
	 * {@code file2} and {@code labels} are not used.
	 */
	protected void printInfo(String modelName, String file1, String file2, String labels) {
		System.out.println("Initialized " + modelName + " from files: " + file1);
		System.out.println("Total number of strings stored: " + A.getNumEntries());
	}

	/**
	 * Allocates the alphabet, the value/gradient tables, the data containers, the formula
	 * pools and the lookup chart. Called by both constructors before any data is loaded.
	 */
	protected void initializeParameterIndexes() {
		A = new Alphabet();
		V = new LDouble[PARAMETER_TABLE_INITIAL_CAPACITY];
		G = new LDouble[PARAMETER_TABLE_INITIAL_CAPACITY];
		m_trainingData = new ArrayList<TDoubleArrayList>(1000);
		m_trainingLabels = new TIntArrayList(1000);
		m_testData = new ArrayList<TDoubleArrayList>(100);
		m_testLabels = new TIntArrayList(100);
		m_devData = new ArrayList<TDoubleArrayList>(100);
		m_devLabels = new TIntArrayList(100);
		savedValues = new TObjectDoubleHashMap<String>(1000);
		m_savedFormulas = new ArrayList<LogFormula>(FORMULA_LIST_INITIAL_CAPACITY);
		m_current = 0;
		m_savedLLFormulas = new ArrayList<LazyLookupLogFormula>(LLFORMULA_LIST_INITIAL_CAPACITY);
		m_llcurrent = 0;
		mLookupChart = new THashMap<Integer,LogFormula>(PARAMETER_TABLE_INITIAL_CAPACITY);
	}

	/**
	 * Sets the initial value ({@code new LDouble(1.0)}) and gradient (the PLUS identity)
	 * of a newly registered parameter.
	 */
	protected void initializeParameter(int paramIndex)
	{
		setValue(paramIndex, new LDouble(1.0));
		setGradient(paramIndex, new LDouble(LDouble.IdentityElement.PLUS_IDENTITY));		
	}

	/**
	 * Reads a whitespace-separated data file: the first token of each line is the integer
	 * label, the remaining tokens are feature values. A bias feature of 1.0 is appended to
	 * every example and parameters "param_0", "param_1", ... are registered as wider
	 * examples are encountered. Blank lines are skipped. Failures are reported on the
	 * console and loading stops at the failing line; examples read before it are kept.
	 * 
	 * @param datafile file to read
	 * @param input    receives one feature vector per example
	 * @param output   receives one label per example
	 */
	private void loadTrainingData(String datafile, ArrayList<TDoubleArrayList> input, TIntArrayList output) {
		BufferedReader dataFileReader = null;
		try {
			dataFileReader = new BufferedReader(new FileReader(datafile));
			String line;
			/*
			 * read in each line of the file and process it
			 */
			while ((line = dataFileReader.readLine()) != null) {
				String trimmed = line.trim();
				if (trimmed.length() == 0) {
					continue;
				}
				/*
				 * tokenize the line into features; "\\s+" tolerates repeated separators
				 */
				String[] toks = trimmed.split("\\s+");
				int label = Integer.parseInt(toks[0]);
				TDoubleArrayList lineList = new TDoubleArrayList();
				for (int i = 1; i < toks.length; i++) {
					lineList.add(Double.parseDouble(toks[i]));
				}
				// bias feature
				lineList.add(1.0);
				// Register one parameter per position not seen before (features, then bias).
				while (numParams < lineList.size()) {
					int paramId = A.getInt("param_"+numParams);
					initializeParameter(paramId);
					numParams++;
				}
				// Label and features are stored together only after the whole line parsed.
				output.add(label);
				input.add(lineList);				
			}
			if (input.size() != output.size()) {
				System.out.println("Differing numbers of input and output lines.");
				System.out.println("Input: " + input.size() + " lines");
				System.out.println("Output: " + output.size() + " lines");
			}
			System.out.println("Num params: " + numParams);
		} catch (Exception exc) {
			System.err.println("Failed to load data from " + datafile);
			exc.printStackTrace();			
		} finally {
			closeStream(dataFileReader);
		}
	}

	/**
	 * Closes a stream if it was opened; a failure to close is reported, not propagated.
	 */
	private static void closeStream(java.io.Closeable stream)
	{
		if (stream == null) {
			return;
		}
		try {
			stream.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Writes {@code getValue(i).getValue()} for parameter ids 1..numParams, one per line.
	 * Failures are reported on the console.
	 * 
	 * @param modelFile file to write
	 */
	public void saveModel(String modelFile) {
		PrintWriter outputWriter = null;
		try {
			outputWriter = new PrintWriter(new FileOutputStream(modelFile));
			for (int i = 1; i <= numParams; i++) {
				outputWriter.println(getValue(i).getValue());
			}
			outputWriter.flush();
		} catch (Exception e) {
			System.out.println(e.toString());
			e.printStackTrace();			
		} finally {
			closeStream(outputWriter);
		}
	}

	/**
	 * Prints {@code exponentiate()} of every parameter gradient, followed by a blank line.
	 */
	public void printGradients()
	{
		for(int i = 0; i < numParams; i ++)
		{
			System.out.println(G[i+1].exponentiate());
		}
		System.out.println();
	}


	/**
	 * Prints {@code exponentiate()} of every parameter value, followed by a blank line.
	 */
	public void printValues()
	{
		for(int i = 0; i < numParams; i ++)
		{
			System.out.println(V[i+1].exponentiate());
		}
		System.out.println();
	}


	/**
	 * Recomputes the gradient over all training examples.
	 * 
	 * @param op       {@code false} (plus) back-propagates through every training example;
	 *                 {@code true} (times) is not implemented and yields the reset gradients
	 * @param maximize when true every component is negated (see {@link #transformGradient})
	 * @param l_value  not used
	 * @return gradient with one entry per parameter
	 */
	public double[] getGradient(boolean op,boolean maximize,LDouble l_value)
	{
		int numTrainingExamples = getNumTrainingExamples();
		zeroParamGradients();
		//plus; the times case was left empty in the original
		if(!op)
		{
			for(int i = 0; i < numTrainingExamples; i ++)
			{
				LogFormula f = getFormula(i);
				f.backprop(this, new LDouble(LDouble.IdentityElement.TIMES_IDENTITY));
			}
		}
		return gatherTransformedGradient(maximize);
	}
	
	/**
	 * @return {@code exponentiate()} of the gradient of the i-th parameter (stored at
	 *         {@code G[i+1]}), negated when {@code maximize} is true
	 */
	public double transformGradient(int i,boolean maximize)
	{
		double factor = 1.0;
		if(maximize)
			factor = -1.0;
		return G[i+1].exponentiate()*factor;
	}

	/** Resets the gradient of every registered parameter to the PLUS identity. */
	private void zeroParamGradients()
	{
		for(int i = 0; i < numParams; i ++)
		{
			G[i+1].reset(IdentityElement.PLUS_IDENTITY);
		}
	}

	/** Copies the transformed gradient of every registered parameter into a new array. */
	private double[] gatherTransformedGradient(boolean maximize)
	{
		double[] g = new double[numParams];
		for(int i = 0; i < numParams; i ++)
		{
			g[i] = transformGradient(i,maximize);
		}
		return g;
	}

	/**
	 * One stochastic update: back-propagates through a single training example, passes the
	 * current values and the gradient to SGA.updateGradient and stores the result.
	 * 
	 * @return the updated parameter values
	 */
	private double[] applySgaUpdate(int index, boolean maximize)
	{
		zeroParamGradients();
		LogFormula f = getFormula(index);
		f.backprop(this, new LDouble(LDouble.IdentityElement.TIMES_IDENTITY));
		double[] g = gatherTransformedGradient(maximize);
		double[] v = SGA.updateGradient(getValues(), g);
		setValues(v);
		return v;
	}

	/**
	 * @return {@code exponentiate()} of every parameter value, in parameter order
	 */
	public double[] getValues()
	{
		double[] values = new double[numParams];
		for(int i = 0; i < numParams; i ++)
		{
			values[i] = V[i+1].exponentiate();
		}
		return values;
	}
	
	/**
	 * Stores the given values, converting each with {@code LDouble.convertToLogDomain}.
	 * 
	 * @param values one value per parameter; must have at least numParams entries
	 */
	public void setValues(double[] values)
	{
		for(int i = 0; i < numParams; i ++)
		{
			V[i+1] = LDouble.convertToLogDomain(values[i]);
		}
	}	

	/**
	 * Stores the given values by wrapping each in {@code new LDouble(value)}, i.e. without
	 * the {@code convertToLogDomain} step used by {@link #setValues(double[])}.
	 * 
	 * @param values one value per parameter; must have at least numParams entries
	 */
	public void setValuesDashed(double[] values)
	{
		for(int i = 0; i < numParams; i ++)
		{
			V[i+1] = new LDouble(values[i]);
		}
	}

	/**
	 * Writes one line per parameter: name, {@code getValue()} and {@code isPositive()} of
	 * its LDouble, separated by tabs. Failures are reported on the console.
	 * 
	 * @param outputFile file to write
	 */
	public void saveParameters(String outputFile)
	{
		BufferedWriter bWriter = null;
		try
		{
			bWriter = new BufferedWriter(new FileWriter(outputFile));
			for(int i = 0; i < numParams; i ++)
			{
				String paramName = "param_"+i;
				bWriter.write(paramName+"\t"+V[i+1].getValue()+"\t"+V[i+1].isPositive()+"\n");
			}			
		}
		catch (IOException e)
		{
			e.printStackTrace();
		}
		finally
		{
			closeStream(bWriter);
		}
	}
	
	/**
	 * Loads parameters from a file in the format written by {@link #saveParameters(String)}.
	 * The n-th non-blank line sets the n-th parameter; the name column is ignored.
	 * I/O failures are reported on the console instead of being silently dropped.
	 * 
	 * @param paramFile file to read
	 */
	public void setParametersWhileTest(String paramFile)
	{
		BufferedReader bReader = null;
		try
		{
			bReader = new BufferedReader(new FileReader(paramFile));
			String line = null;
			int count = 0;
			while((line=bReader.readLine())!=null)
			{
				String trimmed = line.trim();
				if(trimmed.length()==0)
					continue;
				String[] arr = trimmed.split("\t");
				V[count+1]=new LDouble(Double.parseDouble(arr[1]),Boolean.parseBoolean(arr[2]));
				count++;
			}
		}
		catch(IOException e)
		{
			System.err.println("Failed to read parameters from " + paramFile);
			e.printStackTrace();
		}
		finally
		{
			closeStream(bReader);
		}
	}	
	
	/**
	 * Runs 1000 passes of stochastic gradient updates, visiting the training examples in
	 * order. After every pass the objective is printed, the parameters are saved and the
	 * training accuracy is reported.
	 * 
	 * @param paramFile file the parameters are saved to
	 * @return the parameter values after the last update (the original returned an array
	 *         of zeros that was never filled in)
	 */
	public double[] runBatchSGA(String paramFile)
	{
		double[] m_estimate = getValues();
		int trainingSize = m_trainingData.size();
		int countPasses = 0;
		int maxPasses = 1000;
		
		boolean maximize = true;
		LogFormula.Op Type = LogFormula.Op.PLUS;
		RootLogFormula fullLogLike = new RootLogFormula(this,Type,"lazyroot");
		while(countPasses<maxPasses)
		{
			for(int index = 0; index < trainingSize; index++)
			{
				m_estimate = applySgaUpdate(index, maximize);
			}
			countPasses++;
			LDouble l_value = fullLogLike.evaluate(this);
			double m_value = this.extractFunctionValueForLBFGS(l_value, maximize);
			System.out.println("Function value:"+m_value+"\n");
			fullLogLike.changedParamValues();
			// Rewind the getNextFormula cursor, as runCustomLBFGS does after each evaluation.
			m_currentTrainingExample=0;
			saveParameters(paramFile);
			classify();			
		}		
		saveParameters(paramFile);
		classify();
		return m_estimate;
	}
	
	
	/**
	 * Runs stochastic gradient updates on uniformly sampled training examples, for at most
	 * 1000 updates per training example. Every 100 updates the objective is evaluated;
	 * training stops when it changed by less than 0.001 since the previous check,
	 * otherwise the parameters are saved and the training accuracy is reported.
	 * 
	 * @param paramFile file the parameters are saved to
	 * @return the parameter values after the last update
	 */
	public double[] runTotallyRandomSGA(String paramFile)
	{
		double[] m_estimate = getValues();
		int trainingSize = m_trainingData.size();
		int countUpdates = 0;
		int maxUpdates = trainingSize*1000;
		Date d = new Date();
		Random rand = new Random(d.getTime());
		boolean maximize = true;
		LogFormula.Op Type = LogFormula.Op.PLUS;
		RootLogFormula fullLogLike = new RootLogFormula(this,Type,"lazyroot");
		// Double.MIN_VALUE is the smallest POSITIVE double, so using it as the initial
		// "previous value" stops training at once when the first objective is near 0.
		// Negative infinity can never be within the threshold of a real value.
		double oldValue = Double.NEGATIVE_INFINITY;
		double STOPPING_CRITERION=0.001;
		while(countUpdates<maxUpdates)
		{
			if(countUpdates%100==0)
			{
				LDouble l_value = fullLogLike.evaluate(this);
				double m_value = this.extractFunctionValueForLBFGS(l_value, maximize);
				if(Math.abs(m_value-oldValue)<STOPPING_CRITERION)
				{
					break;
				}
				System.out.println("Function value:"+m_value+"\n");
				fullLogLike.changedParamValues();
				// Rewind the getNextFormula cursor, as runCustomLBFGS does after each evaluation.
				m_currentTrainingExample=0;
				saveParameters(paramFile);
				classify();
				oldValue=m_value;
			}
			int index = rand.nextInt(trainingSize);
			m_estimate = applySgaUpdate(index, maximize);
			countUpdates++;
		}		
		saveParameters(paramFile);
		classify();
		return m_estimate;
	}
	
	/**
	 * Fits the parameters with LBFGS. Each iteration evaluates the objective and its
	 * gradient, lets {@code LBFGS.lbfgs} update the estimate in place and stores it back
	 * into the model; training accuracy is reported every 20 iterations. The loop ends
	 * when {@code iflag[0]} becomes 0 or the iteration limit is exceeded.
	 * 
	 * @param paramFile file the final parameters are saved to
	 * @return the final parameter estimate
	 * @throws Exception propagated from {@code LBFGS.lbfgs}
	 */
	public double[] runCustomLBFGS (String paramFile) throws Exception
	{    
		LogFormula.Op Type = LogFormula.Op.PLUS;
		RootLogFormula fullLogLike = new RootLogFormula(this,Type,"lazyroot");		
		double[] diagco = new double[numParams];
		int[] iprint = new int[2];
		iprint[0] = m_debug?1:-1;  //output at every iteration (0 for 1st and last, -1 for never)
		iprint[1] = 0; //output the minimum level of info
		int[] iflag = new int[1];
		iflag[0] = 0;
		double[] gradient = new double[numParams];
		double[] m_estimate = new double[numParams];
		int iteration = 0;
		boolean maximize = true;
		do {
			LDouble l_value = fullLogLike.evaluate(this);
			double m_value = this.extractFunctionValueForLBFGS(l_value, maximize);
			System.out.println("Function value:"+m_value);
			gradient = getGradient(false,maximize,l_value);
			m_estimate = getValues();	
			LBFGS.lbfgs(numParams,
					m_num_corrections, 
					m_estimate, 
					m_value,
					gradient, 
					false, //true if we're providing the diag of cov matrix Hk0 (?)
					diagco, //the cov matrix
					iprint, //type of output generated
					m_eps,
					xtol, //estimate of machine precision
					iflag //i don't get what this is about
			);
			setValues(m_estimate);
			iteration++;
			fullLogLike.changedParamValues();
			m_currentTrainingExample=0;
			if(iteration%20==0)
			{
				classify();
			}
		} while (iteration <= m_max_its&&iflag[0] != 0);
		saveParameters(paramFile);
		classify();
		return m_estimate;
	}

}