package vn.vitk.tag;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.DefaultParamsReader;
import org.apache.spark.ml.util.DefaultParamsWriter;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.ml.util.SchemaUtils;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.mutable.WrappedArray;
import vn.vitk.util.Constants;

/**
 * A conditional Markov model (or Maximum-Entropy Markov model -- MEMM) for
 * sequence labeling.
 * <p>
 * Each input sentence is split on whitespace. At every position a context is
 * extracted and scored against the tag set with a log-linear (softmax) model.
 * The resulting per-position probability matrix is then searched with a
 * {@code ViterbiDecoder} for the best tag sequence.
 * 
 * @author Phuong LE-HONG
 *         <p>
 *         May 9, 2016, 5:52:38 PM
 */
public class CMMModel extends Model<CMMModel> implements MLWritable {

	private static final long serialVersionUID = -4855076361905836432L;

	/** Identifier of this model, fixed for the lifetime of the instance. */
	private final String uid;
	/** Fitted feature pipeline: stage 1 is a CountVectorizerModel, stage 2 a StringIndexerModel. */
	private final PipelineModel pipelineModel;
	private final ContextExtractor contextExtractor;
	/**
	 * Flattened weights. Label k (k >= 1) owns the slice starting at (k-1) * d,
	 * where d is the vocabulary size plus one bias term. Label 0 has no weights,
	 * so its score is always 0.
	 */
	private final Vector weights;
	/** Tag labels indexed by label id, taken from the StringIndexerModel. */
	private final String[] tags;
	/** Feature string to feature index, taken from the CountVectorizerModel vocabulary. */
	private final Map<String, Integer> featureMap;
	/** Word to the set of label ids it may take; words not present are scored against all labels. */
	private final Map<String, Set<Integer>> tagDictionary;

	/**
	 * Creates a conditional Markov model with a freshly generated unique ID.
	 * 
	 * @param pipelineModel
	 *            the fitted feature pipeline; stage 1 must be a
	 *            {@link CountVectorizerModel} (feature vocabulary) and stage 2
	 *            a {@link StringIndexerModel} (tag labels)
	 * @param weights
	 *            the flattened weight vector of the log-linear model
	 * @param markovOrder
	 *            the Markov order handed to the context extractor
	 * @param tagDictionary
	 *            maps a word to the label ids it may take; {@code null} is
	 *            treated as an empty dictionary
	 */
	public CMMModel(PipelineModel pipelineModel, Vector weights, MarkovOrder markovOrder, Map<String, Set<Integer>> tagDictionary) {
		this(randomUID(), pipelineModel, weights, markovOrder, tagDictionary);
	}

	/**
	 * Creates a conditional Markov model with a given unique ID. Used by
	 * {@link #copy(ParamMap)} and by the model reader so that the ID survives
	 * copying and a save/load round trip.
	 */
	private CMMModel(String uid, PipelineModel pipelineModel, Vector weights, MarkovOrder markovOrder, Map<String, Set<Integer>> tagDictionary) {
		this.uid = uid;
		this.pipelineModel = pipelineModel;
		this.contextExtractor = new ContextExtractor(markovOrder, Constants.REGEXP_FILE);
		this.weights = weights;
		this.tags = ((StringIndexerModel) (pipelineModel.stages()[2])).labels();
		String[] features = ((CountVectorizerModel) (pipelineModel.stages()[1])).vocabulary();
		this.featureMap = new HashMap<String, Integer>();
		for (int j = 0; j < features.length; j++) {
			featureMap.put(features[j], j);
		}
		this.tagDictionary = (tagDictionary != null) ? tagDictionary : new HashMap<String, Set<Integer>>();
	}

	/** Generates an ID of the form {@code cmmModel_<last 12 chars of a random UUID>}. */
	private static String randomUID() {
		String ruid = UUID.randomUUID().toString();
		int n = ruid.length();
		return "cmmModel" + "_" + ruid.substring(n - 12, n);
	}

	/**
	 * Creates a copy of this model with the same UID, state and parent, with
	 * the extra params applied.
	 * <p>
	 * {@code defaultCopy} cannot be used: it reflectively requires a
	 * {@code (String uid)} constructor, which this class does not have.
	 */
	@Override
	public CMMModel copy(ParamMap extra) {
		CMMModel that = new CMMModel(uid, pipelineModel, weights, contextExtractor.getMarkovOrder(), tagDictionary);
		return copyValues(that, extra).setParent(parent());
	}

	/**
	 * An immutable unique ID for the object and its derivatives.
	 * @return an immutable unique ID for the object and its derivatives.
	 */
	@Override
	public String uid() {
		return uid;
	}

	/**
	 * Tags every sentence of a data frame.
	 * 
	 * @param dataset
	 *            a data frame whose first column holds whitespace-separated
	 *            sentences
	 * @return a data frame with columns {@code sentence} (the input text) and
	 *         {@code prediction} (space-separated tags, one per token)
	 */
	@Override
	public DataFrame transform(DataFrame dataset) {
		JavaRDD<Row> output = dataset.javaRDD().map(new DecodeFunction());
		return dataset.sqlContext().createDataFrame(output, outputSchema());
	}

	/** The schema of the data frame produced by {@link #transform(DataFrame)}. */
	private static StructType outputSchema() {
		return new StructType(new StructField[]{
			new StructField("sentence", DataTypes.StringType, false, Metadata.empty()),
			new StructField("prediction", DataTypes.StringType, false, Metadata.empty())
		});
	}

	/**
	 * Tags a single sentence row. This is a non-static inner class on purpose:
	 * it reads the enclosing model's tags, features, weights and dictionary.
	 */
	class DecodeFunction implements Function<Row, Row> {
		private static final long serialVersionUID = 5026042203808959533L;

		/** All non-reference label ids (1..numLabels-1); candidates for words missing from the tag dictionary. */
		private final Set<Integer> allLabels = new HashSet<Integer>();

		DecodeFunction() {
			// label 0 is skipped: its score is always 0.0 and never needs computing
			for (int k = 1; k < tags.length; k++) {
				allLabels.add(k);
			}
		}

		/**
		 * Tags the sentence in column 0 of {@code row}.
		 * @return a row (sentence, space-separated tags); blank input yields an empty prediction.
		 */
		@Override
		public Row call(Row row) throws Exception {
			String sentence = row.getString(0);
			// trim first: otherwise a leading blank produces an empty first token (and an extra tag)
			String trimmed = sentence.trim();
			if (trimmed.isEmpty()) {
				return RowFactory.create(sentence, "");
			}
			String[] words = trimmed.split("\\s+");
			List<Tuple2<String, String>> sequence = new ArrayList<Tuple2<String, String>>(words.length);
			for (String word : words) {
				sequence.add(new Tuple2<String, String>(word, "UNK"));
			}
			List<String> partsOfSpeech = decode(sequence);
			StringBuilder sb = new StringBuilder();
			for (String pos : partsOfSpeech) {
				sb.append(pos);
				sb.append(' ');
			}
			return RowFactory.create(sentence, sb.toString().trim());
		}

		/**
		 * Finds the best label sequence for an observation sequence.
		 * <p>
		 * Side effect: the tag slot of each element of {@code sequence} is
		 * overwritten with the locally best tag as decoding proceeds.
		 * 
		 * @param sequence (word, tag) pairs; tags are placeholders on input
		 * @return a label sequence, one tag per word.
		 */
		private List<String> decode(List<Tuple2<String, String>> sequence) {
			int n = sequence.size();
			double[][] score = new double[tags.length][n];

			for (int j = 0; j < n; j++) {
				LabeledContext context = contextExtractor.extract(sequence, j);
				Tuple2<double[], String> tuple = probability(context);
				double[] prob = tuple._1();
				for (int i = 0; i < prob.length; i++) {
					score[i][j] = prob[i];
				}
				// record the locally best tag at position j so that extractions at later
				// positions see it (presumably as previous-tag context, per the Markov order)
				sequence.set(j, new Tuple2<String, String>(sequence.get(j)._1(), tuple._2()));
			}

			int[] path = new ViterbiDecoder(score).bestPath();
			List<String> partsOfSpeech = new ArrayList<String>(path.length);
			for (int k : path) {
				partsOfSpeech.add(tags[k]);
			}
			return partsOfSpeech;
		}

		/**
		 * Computes the probability mass function (pmf) of an unlabeled context.
		 * @param context
		 * @return a tuple of the pmf over the tagset and the best tag.
		 */
		private Tuple2<double[], String> probability(LabeledContext context) {
			// map the lower-cased, de-duplicated feature strings to known vocabulary indices
			String[] fs = context.getFeatureStrings().toLowerCase().split("\\s+");
			Set<String> fsset = new HashSet<String>(Arrays.asList(fs));
			List<Tuple2<Integer, Double>> x = new ArrayList<Tuple2<Integer, Double>>(fsset.size());
			for (String f : fsset) {
				Integer i = featureMap.get(f);
				if (i != null) {
					x.add(new Tuple2<Integer, Double>(i, 1.0));
				}
			}
			Vector features = MLUtils.appendBias(Vectors.sparse(featureMap.size(), x));
			// loop-invariant: vector length (vocabulary + bias) and its active indices
			int d = features.size();
			int[] active = features.toSparse().indices();

			int numLabels = tags.length;
			double[] score = new double[numLabels]; // zero-initialized; score[0] stays 0.0
			int maxLabel = 0;
			double maxScore = 0d;
			Set<Integer> labels = tagDictionary.get(context.getWord());
			if (labels == null) { // this is a rare/unknown word, try all possible labels
				labels = allLabels;
			}

			for (int k : labels) {
				if (k > 0) {
					int offset = (k - 1) * d;
					double s = 0d;
					for (int j : active) {
						s += weights.apply(offset + j);
					}
					score[k] = s;
					if (s > maxScore) {
						maxScore = s;
						maxLabel = k;
					}
				}
			}

			// prevent possible numerical overflow error in the case maxScore > 0
			if (maxScore > 0) {
				for (int k = 0; k < numLabels; k++) {
					score[k] -= maxScore;
				}
			}
			// normalize the score to get probability
			double z = 0d;
			for (int k = 0; k < numLabels; k++) {
				score[k] = Math.exp(score[k]);
				z += score[k];
			}
			for (int k = 0; k < numLabels; k++) {
				score[k] /= z;
			}
			return new Tuple2<double[], String>(score, tags[maxLabel]);
		}
	}

	/**
	 * Returns the schema produced by {@link #transform(DataFrame)}, which is
	 * always {@code [sentence: string, prediction: string]}.
	 * 
	 * @throws IllegalArgumentException
	 *             if the first input column is not a string column, because
	 *             {@code transform} reads the sentence from column 0
	 */
	@Override
	public StructType transformSchema(StructType schema) {
		StructField[] fields = schema.fields();
		if (fields.length == 0 || !DataTypes.StringType.equals(fields[0].dataType())) {
			throw new IllegalArgumentException("The first input column must be of string type, input schema: " + schema);
		}
		return outputSchema();
	}

	/**
	 * Saves this model to {@code path}, overwriting any existing content there.
	 */
	@Override
	public void save(String path) throws IOException {
		write().overwrite().save(path);
	}

	/**
	 * Returns a writer that stores the params metadata, the model data
	 * (Markov order, weights, tag dictionary) and the feature pipeline.
	 */
	@Override
	public MLWriter write() {
		return new CMMModelWriter(this);
	}

	/** Writes a model as: metadata, {@code data/} (parquet) and {@code pipelineModel/}. */
	private class CMMModelWriter extends MLWriter {
		private final CMMModel instance;

		CMMModelWriter(CMMModel instance) {
			this.instance = instance;
		}

		@Override
		public void saveImpl(String path) {
			// save metadata and params
			DefaultParamsWriter.saveMetadata(instance, path, sc(),
					DefaultParamsWriter.saveMetadata$default$4(),
					DefaultParamsWriter.saveMetadata$default$5());

			// save model data; the Markov order is stored 1-based (ordinal + 1) and decoded the same way on load
			Data data = new Data();
			data.setMarkovOrder(instance.contextExtractor.getMarkovOrder().ordinal() + 1);
			data.setWeights(instance.weights);
			data.setTagDictionary(instance.tagDictionary);
			List<Data> list = new ArrayList<Data>(1);
			list.add(data);
			String dataPath = new Path(path, "data").toString();
			sqlContext().createDataFrame(list, Data.class).write().parquet(dataPath);

			// save pipeline model; failing here must not leave a silently incomplete model behind
			String pipelinePath = new Path(path, "pipelineModel").toString();
			try {
				instance.pipelineModel.write().overwrite().save(pipelinePath);
			} catch (IOException e) {
				throw new RuntimeException("Cannot save the pipeline model to " + pipelinePath, e);
			}
		}
	}

	/** Java bean holding the persisted model data (one row of the {@code data} parquet file). */
	public class Data implements Serializable {

		private static final long serialVersionUID = 1L;
		private int markovOrder;
		private Vector weights;
		private Map<String, Set<Integer>> tagDictionary;

		public void setMarkovOrder(int markovOrder) {
			this.markovOrder = markovOrder;
		}

		public int getMarkovOrder() {
			return markovOrder;
		}

		public void setWeights(Vector weights) {
			this.weights = weights;
		}

		public Vector getWeights() {
			return weights;
		}

		public Map<String, Set<Integer>> getTagDictionary() {
			return tagDictionary;
		}

		public void setTagDictionary(Map<String, Set<Integer>> tagDictionary) {
			this.tagDictionary = tagDictionary;
		}
	}

	/**
	 * Returns a summary of the form
	 * {@code [markovOrder=..., numLabels=..., weights=...]}.
	 */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		sb.append("[markovOrder=").append(contextExtractor.getMarkovOrder());
		sb.append(", numLabels=").append(tags.length);
		sb.append(", weights=").append(weights);
		sb.append(']');
		return sb.toString();
	}

	/**
	 * Loads a {@link CMMModel} from an external file.
	 * @param path
	 * @return a CMM model.
	 */
	public static CMMModel load(String path) {
		return read().load(path);
	}

	/**
	 * This function is used in the reflection framework of Spark ML.
	 * @return a {@link MLReader}.
	 */
	public static MLReader<CMMModel> read() {
		return new CMMModelReader();
	}

	/** Reads a model written by {@code CMMModelWriter}, restoring its original UID. */
	private static class CMMModelReader extends MLReader<CMMModel> {
		@Override
		public CMMModel load(String path) {
			org.apache.spark.ml.util.DefaultParamsReader.Metadata metadata = DefaultParamsReader.loadMetadata(path, sc(), CMMModel.class.getName());
			String pipelinePath = new Path(path, "pipelineModel").toString();
			PipelineModel pipelineModel = PipelineModel.load(pipelinePath);
			String dataPath = new Path(path, "data").toString();
			DataFrame df = sqlContext().read().format("parquet").load(dataPath);
			Row row = df.select("markovOrder", "weights", "tagDictionary").head();
			// load the Markov order (stored 1-based)
			MarkovOrder order = MarkovOrder.values()[row.getInt(0) - 1];
			// load the weight vector
			Vector w = row.getAs(1);
			// load the tag dictionary; go through the generic Map/Seq interfaces because Scala
			// materializes small maps as Map1..Map4 rather than immutable.HashMap
			Map<String, Set<Integer>> tagDict = new HashMap<String, Set<Integer>>();
			if (!row.isNullAt(2)) {
				Map<String, scala.collection.Seq<Integer>> td = row.getJavaMap(2);
				for (Map.Entry<String, scala.collection.Seq<Integer>> entry : td.entrySet()) {
					Set<Integer> labels = new HashSet<Integer>();
					Iterator<Integer> it = entry.getValue().iterator();
					while (it.hasNext()) {
						labels.add(it.next());
					}
					tagDict.put(entry.getKey(), labels);
				}
			}
			// build a CMM model, keeping the UID it was saved with
			CMMModel model = new CMMModel(metadata.uid(), pipelineModel, w, order, tagDict);
			DefaultParamsReader.getAndSetParams(model, metadata);
			return model;
		}
	}

}