package vn.vitk.tag;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import vn.vitk.lang.CorpusPack;
import vn.vitk.lang.Language;
import vn.vitk.util.SparkContextFactory;

/**
 * @author Phuong LE-HONG
 * <p>
 * May 16, 2016, 9:40:34 AM
 * <p>
 * A part-of-speech tagger which uses a conditional Markov model (CMM).
 *   
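 * <p>
 * A minimal usage sketch (the file paths and the feature count are placeholders,
 * not values prescribed by this class):
 * <pre>{@code
 * JavaSparkContext jsc = SparkContextFactory.create();
 * Tagger tagger = new Tagger(jsc).setVerbose(true);
 * // train on tagged sentences of the form "He/P kicks/V a/D ball/N ./."
 * tagger.train("data/train.tagged", "models/cmm", new CMMParams().setNumFeatures(32768));
 * // tag a plain text file, one sentence per line
 * tagger.tag("data/input.txt", "data/output", OutputFormat.JSON);
 * }</pre>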
 */
public class Tagger implements Serializable {
	
	private static final long serialVersionUID = 8061440373376898771L;
	private transient JavaSparkContext jsc;
	private CMMModel cmmModel;
	
	private boolean verbose = false;
	
	public enum OutputFormat {
		TEXT,
		JSON,
		PARQUET
	}
	
	enum TaggerMode {
		TRAIN,
		TAG
	}
	
	
	/**
	 * Creates a tagger that uses the given Java Spark context.
	 * @param jsc a Java Spark context
	 */
	public Tagger(JavaSparkContext jsc) {
		this.jsc = jsc;
	}
	
	/**
	 * Trains a tagger. The file <code>inputFileName</code> contains tagged sentences
	 * in a simple format: "He/P kicks/V a/D ball/N ./.". After training, the
	 * tagger model is saved to the file <code>modelFileName</code>. The
	 * parameters used in training are given in the <code>params</code> argument.
	 * 
	 * @param inputFileName
	 * @param modelFileName
	 * @param params
	 * @return a {@link CMMModel}
	 */
	public CMMModel train(String inputFileName, String modelFileName, CMMParams params) {
		JavaRDD<String> lines = jsc.textFile(inputFileName);
		cmmModel = train(lines.collect(), modelFileName, params);
		return cmmModel;
	}
	
	/**
	 * Trains a tagger with data specified in a data frame. The data frame has
	 * two columns: the column "sentence" contains a word sequence and the column
	 * "partOfSpeech" contains the corresponding tag sequence. Each row of the
	 * data frame specifies one tagged sequence of the training set.
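	 * <p>
	 * A minimal sketch, assuming the data frame is built with
	 * {@link #createDataFrame(List)}; the sentence and the model path are
	 * placeholders:
	 * <pre>{@code
	 * // one row: sentence = "He kicks a ball .", partOfSpeech = "P V D N ."
	 * DataFrame dataset = tagger.createDataFrame(Arrays.asList("He/P kicks/V a/D ball/N ./."));
	 * CMMModel model = tagger.train(dataset, "models/cmm", new CMMParams());
	 * }</pre>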
	 * @param dataset
	 * @param modelFileName
	 * @param params
	 * @return a {@link CMMModel}
	 */
	public CMMModel train(DataFrame dataset, String modelFileName, CMMParams params) {
		CMM cmm = new CMM(params).setVerbose(verbose);
		cmmModel = cmm.fit(dataset);
		try {
			cmmModel.write().overwrite().save(modelFileName);
		} catch (IOException e) {
			e.printStackTrace();
		}
		return cmmModel;
	}
	
	/**
	 * Trains a tagger. Training data are tagged sequences in a simple format,
	 * one sequence per list element: "He/P kicks/V a/D ball/N ./.". After
	 * training, the tagger model is saved to the file <code>modelFileName</code>.
	 * The parameters used in training are given in the <code>params</code> argument.
	 * 
	 * @param taggedSentences
	 * @param modelFileName
	 * @param params
	 * @return a {@link CMMModel}
	 */
	public CMMModel train(List<String> taggedSentences, String modelFileName, CMMParams params) {
		DataFrame dataset = createDataFrame(taggedSentences);
		return train(dataset, modelFileName, params);
	}

	/**
	 * Creates a data frame from a list of tagged sentences.
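	 * <p>
	 * For example (an illustration, not corpus data), the tagged sentence
	 * "He/P kicks/V a/D ball/N ./." yields the row
	 * ("He kicks a ball .", "P V D N ."); a token "///" (a slash character
	 * tagged with a slash) contributes the word "/" and the tag "/".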
	 * @param taggedSentences
	 * @return a data frame of two columns: "sentence" and "partOfSpeech".
	 */
	public DataFrame createDataFrame(List<String> taggedSentences) {
		List<String> wordSequences = new LinkedList<String>();
		List<String> tagSequences = new LinkedList<String>();
		for (String taggedSentence : taggedSentences) {
			StringBuilder wordBuf = new StringBuilder();
			StringBuilder tagBuf = new StringBuilder();
			String[] tokens = taggedSentence.split("\\s+");
			for (String token : tokens) {
				String[] parts = token.split("/");
				if (parts.length == 2) {
					wordBuf.append(parts[0]);
					wordBuf.append(' ');
					tagBuf.append(parts[1]);
					tagBuf.append(' ');
				} else { // assume the token is "///": a slash character tagged with a slash
					wordBuf.append('/');
					wordBuf.append(' ');
					tagBuf.append('/');
					tagBuf.append(' ');
				}
			}
			wordSequences.add(wordBuf.toString().trim());
			tagSequences.add(tagBuf.toString().trim());
		}
		if (verbose) {
			System.out.println("Number of sentences = " + wordSequences.size());
		}
		List<Row> rows = new LinkedList<Row>();
		for (int i = 0; i < wordSequences.size(); i++) {
			rows.add(RowFactory.create(wordSequences.get(i), tagSequences.get(i)));
		}
		JavaRDD<Row> jrdd = jsc.parallelize(rows);
		StructType schema = new StructType(new StructField[]{
				new StructField("sentence", DataTypes.StringType, false, Metadata.empty()),
				new StructField("partOfSpeech", DataTypes.StringType, false, Metadata.empty())
			});
			
		return new SQLContext(jsc).createDataFrame(jrdd, schema);
	}
	
	/**
	 * Loads a {@link CMMModel} from a model file.
	 * @param modelFileName
	 * @return this object.
	 */
	public Tagger load(String modelFileName) {
		this.cmmModel = CMMModel.load(modelFileName);
		return this;
	}
	
	
	/**
	 * Tags a list of sentences and returns a list of tag sequences.
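	 * <p>
	 * A minimal sketch; the model path and the sentence are placeholders:
	 * <pre>{@code
	 * List<String> tags = tagger.load("models/cmm").tag(Arrays.asList("He kicks a ball ."));
	 * // tags.get(0) is a space-separated tag sequence such as "P V D N ."
	 * }</pre>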
	 * @param sentences
	 * @return a list of tag sequences.
	 */
	public List<String> tag(List<String> sentences) {
		List<Row> rows = new LinkedList<Row>();
		for (String sentence : sentences) {
			rows.add(RowFactory.create(sentence));
		}
		StructType schema = new StructType(new StructField[]{
			new StructField("sentence", DataTypes.StringType, false, Metadata.empty())	
		});
		SQLContext sqlContext = new SQLContext(jsc);
		DataFrame input = sqlContext.createDataFrame(rows, schema);
		if (cmmModel != null) {
			DataFrame output = cmmModel.transform(input).repartition(1);
			return output.javaRDD().map(new RowToStringFunction(1)).collect();
		} else {
			System.err.println("Tagging model is null. You need to create or load a model first.");
			return null;
		}
	}
	
	/**
	 * Tags a data frame containing a column named 'sentence' and writes the
	 * result to an output file in the desired output format.
	 * @param input
	 * @param outputFileName
	 * @param outputFormat
	 */
	public void tag(DataFrame input, String outputFileName, OutputFormat outputFormat) {
		long tic = System.currentTimeMillis();
		long duration = 0;
		if (cmmModel != null) {
			DataFrame output = cmmModel.transform(input).repartition(1);
			duration = System.currentTimeMillis() - tic;
			switch (outputFormat) {
			case JSON:
				output.write().json(outputFileName);
				break;
			case PARQUET:
				output.write().parquet(outputFileName);
				break;
			case TEXT:
				toTaggedSentence(output).repartition(1).saveAsTextFile(outputFileName);
//				output.select("prediction").write().text(outputFileName);
				break;
			}
		} else {
			System.err.println("Tagging model is null. You need to create or load a model first.");
		}
		if (verbose) {
			long n = input.count();
			System.out.println(" Number of sentences = " + n);
			System.out.println("  Total tagging time = " + duration + " milliseconds.");
			System.out.println("Average tagging time = " + ((float)duration) / n + " milliseconds.");
		}
	}
	
	/**
	 * Tags a list of sequences and writes the result to an output file with a
	 * desired output format.
	 * 
	 * @param sentences
	 * @param outputFileName
	 * @param outputFormat
	 */
	public void tag(List<String> sentences, String outputFileName, OutputFormat outputFormat) {
		List<Row> rows = new LinkedList<Row>();
		for (String sentence : sentences) {
			rows.add(RowFactory.create(sentence));
		}
		StructType schema = new StructType(new StructField[]{
			new StructField("sentence", DataTypes.StringType, false, Metadata.empty())	
		});
		SQLContext sqlContext = new SQLContext(jsc);
		DataFrame input = sqlContext.createDataFrame(rows, schema);
		tag(input, outputFileName, outputFormat);
	}
	
	/**
	 * Tags a distributed list of sentences and writes the result to an output file with 
	 * a desired output format.
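	 * <p>
	 * A minimal sketch (the paths are placeholders); each row wraps a single
	 * untagged sentence string:
	 * <pre>{@code
	 * JavaRDD<Row> rows = jsc.textFile("data/input.txt").map(new Function<String, Row>() {
	 *     public Row call(String line) {
	 *         return RowFactory.create(line);
	 *     }
	 * });
	 * tagger.tag(rows, "data/output", OutputFormat.TEXT);
	 * }</pre>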
	 * @param sentences
	 * @param outputFileName
	 * @param outputFormat
	 */
	public void tag(JavaRDD<Row> sentences, String outputFileName, OutputFormat outputFormat) {
		StructType schema = new StructType(new StructField[]{
			new StructField("sentence", DataTypes.StringType, false, Metadata.empty())	
		});
		SQLContext sqlContext = new SQLContext(jsc);
		DataFrame input = sqlContext.createDataFrame(sentences, schema);
		tag(input, outputFileName, outputFormat);
	}
	
	/**
	 * Tags a text file, one sentence per line, and writes the result to an
	 * output file in the desired output format.
	 * @param inputFileName
	 * @param outputFileName
	 * @param outputFormat
	 */
	public void tag(String inputFileName, String outputFileName, OutputFormat outputFormat) {
		List<String> sentences = jsc.textFile(inputFileName).collect();
		tag(sentences, outputFileName, outputFormat);
	}
	
	/**
	 * Tags a text file, one sentence per line, and writes the result to the console.
	 * @param inputFileName
	 */
	public void tag(String inputFileName) {
		List<String> sentences = jsc.textFile(inputFileName).collect();
		List<String> output = tag(sentences);
		for (int i = 0; i < sentences.size(); i++) {
			StringBuilder sb = new StringBuilder(64);
			String[] words = sentences.get(i).split("\\s+");
			String[] tags = output.get(i).split("\\s+");
			for (int j = 0; j < words.length; j++) {
				sb.append(words[j]);
				sb.append('/');
				sb.append(tags[j]);
				sb.append(' ');
			}
			System.out.println(sb.toString().trim());
		}
	}
	
	private JavaRDD<String> toTaggedSentence(DataFrame output) {
		return output.javaRDD().map(new Function<Row, String>() {
			private static final long serialVersionUID = 4208643510231783579L;
			@Override
			public String call(Row row) throws Exception {
				String[] tokens = row.getString(0).trim().split("\\s+");
				String[] tags = row.getString(1).trim().split("\\s+");
				if (tokens.length != tags.length) {
					System.err.println("Incompatible lengths!");
					return null;
				}
				StringBuilder sb = new StringBuilder(64);
				for (int j = 0; j < tokens.length; j++) {
					sb.append(tokens[j]);
					sb.append('/');
					sb.append(tags[j]);
					sb.append(' ');
				}
				return sb.toString().trim();
			}
		});
	}
	
	/**
	 * Evaluates the accuracy of a CMM model on a data frame of tagged sentences.
	 * @param dataset
	 * @return evaluation measures.
	 */
	public float[] evaluate(DataFrame dataset) {
		List<String> correctSequences = dataset.javaRDD().map(new RowToStringFunction(1)).collect();
		long beginTime = System.currentTimeMillis();
		DataFrame output = cmmModel.transform(dataset);
		long endTime = System.currentTimeMillis();
		if (verbose) {
			System.out.println(" Number of sentences = " + correctSequences.size());
			long duration = (endTime - beginTime);
			System.out.println("  Total tagging time = " + duration + " ms.");
			System.out.println("Average tagging time = " + ((float)duration) / correctSequences.size() + " ms.");
		}
		List<String> automaticSequences = output.javaRDD().map(new RowToStringFunction(1)).collect();
		return Evaluator.evaluate(automaticSequences, correctSequences);
	}
	
	/**
	 * Evaluates the tagger on a manually tagged data file.
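	 * <p>
	 * A minimal sketch (the paths are placeholders; the interpretation of the
	 * returned scores is defined by {@link Evaluator}):
	 * <pre>{@code
	 * float[] scores = tagger.load("models/cmm").evaluate("data/test.tagged");
	 * System.out.println(Arrays.toString(scores));
	 * }</pre>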
	 * @param inputFileName
	 * @return scores
	 */
	public float[] evaluate(String inputFileName) {
		JavaRDD<String> lines = jsc.textFile(inputFileName);
		DataFrame dataset = createDataFrame(lines.collect());
		return evaluate(dataset);
	}
	
	/**
	 * Evaluates the accuracy of a CMM model on a list of tagged sentences. 
	 * @param taggedSentences
	 * @return evaluation measures.
	 */
	public float[] evaluate(List<String> taggedSentences) {
		DataFrame dataset = createDataFrame(taggedSentences);
		return evaluate(dataset);
	}
	
	private class RowToStringFunction implements Function<Row, String> {
		private static final long serialVersionUID = -2245906041132281238L;
		private int columnIndex = 0;
		
		RowToStringFunction(int columnIndex) {
			this.columnIndex = columnIndex;
		}
		@Override
		public String call(Row row) throws Exception {
			return row.getString(columnIndex);
		}
	}
	
	/**
	 * Sets the verbose mode.
	 * @param verbose
	 * @return this object
	 */
	public Tagger setVerbose(boolean verbose) {
		this.verbose = verbose;
		return this;
	}

	void test(String modelFileName, TaggerMode mode) {
		String[] taggedSentences = {
				"tôi/P ăn/V quả/Nc chuối/N to/A ./.",
				"tôi/Np đá/V quả/Nc bóng/N ./.",
				"tôi/Np ăn/V quả/Nc bóng/N to/A ./.",
				"vải/N hoa/N bụng/N cóc/N",
				"tôi/P cóc/R sợ/A" 
		};
		if (mode == TaggerMode.TRAIN) {
			train(Arrays.asList(taggedSentences), modelFileName, 
					new CMMParams().setNumFeatures(30));
		} else if (mode == TaggerMode.TAG) {
			load(modelFileName);
			String[] sentences = {
				"tôi ăn quả chuối to .",
				"tôi đá quả bóng .",
				"tôi ăn quả bóng to .", 
				"vải hoa bụng cóc",
				"tôi cóc sợ" 
			};
			tag(Arrays.asList(sentences), "dat/tag/out", OutputFormat.JSON);
			evaluate(Arrays.asList(taggedSentences));
		}
	}
	
	void test(String inputFileName, int numFeatures, String modelFileName, TaggerMode mode) {
		if (mode == TaggerMode.TRAIN) {
			CMMParams params = new CMMParams()
				.setMaxIter(600)
				.setNumFeatures(numFeatures);
			train(inputFileName, modelFileName, params);
		}
		else if (mode == TaggerMode.TAG) {
			load(modelFileName);
			List<String> taggedSentences = jsc.textFile(inputFileName).collect();
			evaluate(taggedSentences);
		}
	}
	
	void testRandomSplit(String inputFileName, int numFeatures, String modelFileName) {
		CMMParams params = new CMMParams()
			.setMaxIter(600)
			.setRegParam(1E-6)
			.setMarkovOrder(2)
			.setNumFeatures(numFeatures);
		
		JavaRDD<String> lines = jsc.textFile(inputFileName);
		DataFrame dataset = createDataFrame(lines.collect());
		DataFrame[] splits = dataset.randomSplit(new double[]{0.9, 0.1}); 
		DataFrame trainingData = splits[0];
		System.out.println("Number of training sequences = " + trainingData.count());
		DataFrame testData = splits[1];
		System.out.println("Number of test sequences = " + testData.count());
		// train and save a model on the training data
		cmmModel = train(trainingData, modelFileName, params);
		// test the model on the test data
		System.out.println("Test accuracy:");
		evaluate(testData); 
		// test the model on the training data
		System.out.println("Training accuracy:");
		evaluate(trainingData);
	}
	
	/**
	 * For internal test only.
	 * @param args
	 */
	public static void main(String[] args) {
		CorpusPack cp = new CorpusPack(Language.VIETNAMESE);
		String modelFileName = cp.taggerModelFileName();
		
		JavaSparkContext jsc = SparkContextFactory.create();
		Tagger tagger = new Tagger(jsc).setVerbose(true);
		// 1. Toy dataset
		tagger.test(modelFileName, TaggerMode.TRAIN);
		tagger.test(modelFileName, TaggerMode.TAG);
		
		// 2. VTB, randomSplit (90%, 10%)
//		String corpusFileName = cp.taggerCorpusFileName();
//		tagger.testRandomSplit(corpusFileName, 160000, modelFileName);
	}
}