package RobustHMM;
/*
 * Copyright (C) 2019  Evan Tarbell and Tao Liu

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

import net.sf.javaml.core.DenseInstance;

import org.apache.commons.math3.stat.descriptive.moment.Mean;
import org.apache.commons.math3.stat.descriptive.moment.Variance;

import be.ac.ulg.montefiore.run.jahmm.Hmm;
import be.ac.ulg.montefiore.run.jahmm.ObservationReal;
import be.ac.ulg.montefiore.run.jahmm.ObservationVector;
import be.ac.ulg.montefiore.run.jahmm.OpdfGaussian;
import be.ac.ulg.montefiore.run.jahmm.OpdfMultiGaussian;
import be.ac.ulg.montefiore.run.jahmm.io.HmmBinaryWriter;

/**
 * Builds a starting hidden Markov model from an observation file.
 *
 * <p>The model uses uniform initial and transition probabilities. For vector
 * observations, each state receives a diagonal-covariance multivariate Gaussian
 * whose per-dimension means and variances are spaced evenly across a range
 * derived from the data (see {@link #setVectorPDF(ArrayList)}).
 *
 * <p>Only the "Vector" observation method currently yields a model; "Real"
 * observations are not implemented yet, so {@link #getHMM()} returns
 * {@code null} for them.
 */
public class RandomInitHMM {
	//class variables
	private final String input;
	private final int numStates;
	private Hmm<?> hmm;
	
	//Main method variables
	private static int states = 0;
	private static String file = null;
	private static String output = null;
	
	/**
	 * Reads the observation file and builds the initial model.
	 *
	 * @param i path of the observation file handed to {@code ObservationReader}
	 * @param n number of hidden states; must be at least 1
	 * @throws FileNotFoundException propagated from reading the observation file
	 * @throws IllegalArgumentException if {@code n} is smaller than 1
	 */
	public RandomInitHMM(String i,int n) throws FileNotFoundException{
		if (n < 1){
			throw new IllegalArgumentException("Number of states must be at least 1, got " + n);
		}
		input = i;
		numStates = n;
		read();
	}
	
	/**
	 * Returns the model built by the constructor.
	 *
	 * @return the initial HMM, or {@code null} when the observation method was
	 *         "Real" (not implemented) or anything other than "Vector"
	 */
	public Hmm<?> getHMM(){return hmm;}

	/**
	 * Dispatches on the method reported by the reader and stores the resulting
	 * model in {@link #hmm}. An unrecognized (or null) method leaves the field
	 * {@code null}.
	 */
	private void read() throws FileNotFoundException{
		ObservationReader reader = new ObservationReader(input);
		String method = reader.getMethod();
		// Constant-first comparison so a null method falls through instead of throwing.
		if ("Vector".equals(method)){
			hmm = readVector(reader);
		}
		else if ("Real".equals(method)){
			hmm = readReal(reader);
		}
	}
	
	/**
	 * Unfinished alternative reader: parses a comma-separated file of numbers
	 * and computes the mean and variance of every feature column. The results
	 * are not used yet and nothing in this class calls the method.
	 *
	 * NOTE(review): assumes every line holds the same number of fields; a line
	 * longer than the last one read will overflow the matrix row.
	 */
	@SuppressWarnings("unused")
	private void read2() throws FileNotFoundException{
		ArrayList<double[]> data = new ArrayList<double[]>();
		int numFeat = 0;
		Scanner inFile = new Scanner((Readable) new FileReader(input));
		try{
			while (inFile.hasNext()){
				String line = inFile.nextLine();
				String[] features = line.split(",");
				double[] parsed = new double[features.length];
				numFeat = features.length;
				for (int i = 0;i < features.length;i++){
					parsed[i] = Double.parseDouble(features[i]);
				}
				data.add(parsed);
			}
		}
		finally{
			// Release the underlying file handle even when parsing fails.
			inFile.close();
		}
		
		// Rows are observations, columns are features.
		double[][] values = new double[data.size()][numFeat];
		for (int i = 0;i < data.size();i++){
			double[] temp = data.get(i);
			for (int a = 0;a < temp.length;a++){
				values[i][a] = temp[a];
			}
		}
		
		double[] mu = new double[numFeat];
		double[] var = new double[numFeat];
		Mean mean = new Mean();
		Variance variance = new Variance();
		// Statistics are per feature, so gather each column before evaluating;
		// evaluating values[i] directly would summarize observation i instead.
		double[] column = new double[values.length];
		for (int f = 0; f < numFeat;f++){
			for (int r = 0;r < values.length;r++){
				column[r] = values[r][f];
			}
			mu[f] = mean.evaluate(column);
			var[f] = variance.evaluate(column);
		}
		//TODO: build the per-state emission distributions from mu and var
	}
	
	/**
	 * Builds a model over vector observations with uniform initial and
	 * transition probabilities and the emissions from
	 * {@link #setVectorPDF(ArrayList)}.
	 *
	 * @param reader source of the observations; its {@code getObs()} result is
	 *               cast, unchecked, to a list of {@code ObservationVector}
	 * @return the initial model
	 */
	@SuppressWarnings({"unchecked"})
	private Hmm<ObservationVector> readVector(ObservationReader reader){
		ArrayList<ObservationVector> obs = (ArrayList<ObservationVector>) reader.getObs();
		double[] initial = setInitial();
		double[][] trans = setTrans();
		List<OpdfMultiGaussian> opdf = setVectorPDF(obs);
		
		return new Hmm<ObservationVector>(initial, trans, opdf);
	}
	
	/**
	 * Placeholder for scalar observations. The uniform probabilities are
	 * prepared but no model is assembled yet.
	 *
	 * @return always {@code null} until the univariate emission builder exists
	 */
	@SuppressWarnings({ "unused", "unchecked" })
	private Hmm<ObservationReal> readReal(ObservationReader reader){
		ArrayList<ObservationReal> obs = (ArrayList<ObservationReal>) reader.getObs();
		double[] initial = setInitial();
		double[][] trans = setTrans();
		//TODO: write opdf maker for monovariate gaussian
		return null;
	}
	
	/**
	 * Placeholder for the univariate Gaussian emission builder.
	 *
	 * @return always {@code null}
	 */
	private List<OpdfGaussian> setRealPDF(ArrayList<ObservationReal> obs){
		//TODO: write this method
		return null;
	}
	
	/**
	 * Creates one diagonal-covariance multivariate Gaussian per state.
	 *
	 * <p>For each dimension the mean and variance over all observations are
	 * computed. State means are then spaced evenly from {@code mean - 3*sd} to
	 * {@code mean + 3*sd}, and state variances evenly from {@code 0.5*var} to
	 * {@code 3*var}. With a single state there is nothing to spread, so that
	 * state takes the data mean and variance directly.
	 *
	 * @param obs observations to summarize; must contain at least one element
	 * @return a list of {@code numStates} emission distributions, in state order
	 * @throws IllegalArgumentException if {@code obs} is null or empty
	 */
	private List<OpdfMultiGaussian> setVectorPDF(ArrayList<ObservationVector> obs){
		if (obs == null || obs.isEmpty()){
			throw new IllegalArgumentException(
					"Cannot initialize emission distributions: no observations available from " + input);
		}
		
		int dim = obs.get(0).dimension();
		int numObs = obs.size();
		Mean mu = new Mean();
		Variance v = new Variance();
		double[] vals = new double[numObs];
		double[][] stateMeans = new double[numStates][dim];
		double[][] stateVars = new double[numStates][dim];
		
		for (int i = 0;i < dim;i++){
			// Collect dimension i across all observations.
			for (int a = 0;a < numObs;a++){
				vals[a] = obs.get(a).value(i);
			}
			double mean = mu.evaluate(vals, 0, vals.length);
			double var = v.evaluate(vals);
			
			if (numStates == 1){
				// Dividing by (numStates - 1) would give an infinite/NaN step
				// and NaN parameters, so use the data statistics as they are.
				stateMeans[0][i] = mean;
				stateVars[0][i] = var;
				continue;
			}
			
			double sd = Math.sqrt(var);
			double meanLower = mean - (3.0*sd);
			double meanUpper = mean + (3.0*sd);
			double varLower = 0.5 * var;
			double varUpper = 3 * var;
			
			double meanStep = (meanUpper - meanLower) / ((double) numStates - 1);
			double varStep = (varUpper - varLower) / ((double) numStates - 1);
			for (int a = 0;a < numStates;a++){
				stateMeans[a][i] = meanLower + (a * meanStep);
				stateVars[a][i] = varLower + (a * varStep);
			}
		}
		
		List<OpdfMultiGaussian> opdf = new ArrayList<OpdfMultiGaussian>();
		for (int s = 0;s < numStates;s++){
			// Only the diagonal is set; off-diagonal covariances stay 0.
			double[][] cov = new double[dim][dim];
			for (int d = 0;d < dim;d++){
				cov[d][d] = stateVars[s][d];
			}
			opdf.add(new OpdfMultiGaussian(stateMeans[s],cov));
		}
		
		return opdf;
	}
	
	/**
	 * @return a {@code numStates x numStates} matrix with every entry equal to
	 *         {@code 1/numStates}
	 */
	private double[][] setTrans(){
		double uniform = (double) 1/numStates;
		double[][] trans = new double[numStates][numStates];
		for (int i = 0;i < numStates;i++){
			for (int a = 0;a < numStates;a++){
				trans[i][a] = uniform;
			}
		}
		return trans;
	}
	
	/**
	 * @return a vector of length {@code numStates} with every entry equal to
	 *         {@code 1/numStates}
	 */
	private double[] setInitial(){
		double uniform = (double) 1/numStates;
		double[] initial = new double[numStates];
		for (int i = 0;i < numStates;i++){
			initial[i] = uniform;
		}
		return initial;
	}
	
	/**
	 * Command-line entry point: parses {@code -i}, {@code -n} and {@code -o},
	 * builds the initial model, prints it, and writes it in binary form to the
	 * output file. Prints the usage text and exits with status 1 when a
	 * required option is missing or malformed, or when no model could be built.
	 *
	 * @param args command-line arguments
	 * @throws IOException if reading the observations or writing the model fails
	 */
	public static void main(String[] args) throws IOException{
		for (int i = 0; i < args.length; i++) {
			String arg = args[i];
			// The option letter is the second character; shorter arguments
			// cannot name an option and are ignored like unknown flags.
			if (arg.length() < 2){
				continue;
			}
			char flag = arg.charAt(1);
			if (flag != 'i' && flag != 'n' && flag != 'o'){
				continue;
			}
			if (i + 1 >= args.length){
				System.err.println("Missing value for option " + arg);
				printUsage();
				System.exit(1);
			}
			String value = args[++i];
			
			switch (flag) {
			case 'i':
				file = value;
				break;
			case 'n':
				try{
					states = Integer.parseInt(value);
				}
				catch (NumberFormatException e){
					System.err.println("Number of states must be an integer, got: " + value);
					printUsage();
					System.exit(1);
				}
				break;
			case 'o':
				output = value;
				break;
			}
		}
		if (file == null || states <= 0 || output == null){
			printUsage();
			System.exit(1);
		}
		
		RandomInitHMM init = new RandomInitHMM(file,states);
		Hmm<?> hmm = init.getHMM();
		if (hmm == null){
			// Real-valued and unrecognized observation methods produce no model.
			System.err.println("No HMM could be built from " + file
					+ ": only \"Vector\" observation files are currently supported");
			System.exit(1);
		}
		System.out.println(hmm.toString());
		FileOutputStream out = new FileOutputStream(output);
		try{
			HmmBinaryWriter writer = new HmmBinaryWriter();
			writer.write(out, hmm);
		}
		finally{
			// Always release the output file, even if serialization fails.
			out.close();
		}
	}
	
	/** Prints the command-line usage text to standard output. */
	private static void printUsage(){
		System.out.println("Usage: java -jar RandomInitHMM.jar");
		System.out.println("Required Parameters:");
		System.out.println("-i <File> Observation File in proper format");
		System.out.println("-n <int> Number of States");
		System.out.println("-o <File> Output file to write binary HMM");
	}
}