package edu.uw.easysrl.syntax.training;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import lbfgsb.DifferentiableFunction;
import lbfgsb.FunctionValues;
import lbfgsb.IterationFinishedListener;
import lbfgsb.LBFGSBException;
import lbfgsb.Minimizer;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;

import edu.uw.easysrl.syntax.training.ClassifierTrainer.AbstractFeature;
import edu.uw.easysrl.syntax.training.ClassifierTrainer.AbstractTrainingExample;
import edu.uw.easysrl.util.Util;

/**
 * Trainer for a general-purpose log-linear classifier. Weights are optimized
 * with L-BFGS-B under an L2 (Gaussian) prior, and the iteration with the best
 * development accuracy is kept.
 */
public abstract class ClassifierTrainer<T extends AbstractTrainingExample<L>, F extends AbstractFeature<T, L>, L> {

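	/**
	 * Immutable identifier for a feature instantiation: the feature's values
	 * plus the feature's unique id, with a precomputed hash code for fast map
	 * lookups.
	 */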
	private static class FeatureKey implements Serializable {
		private static final long serialVersionUID = 1L;
		private final List<Object> values;
		private final int hashCode;

		private FeatureKey(final List<Object> values) {
			this.values = values;
			this.hashCode = values.hashCode();
		}

		@Override
		public boolean equals(final Object other) {
			// Guard against null and non-FeatureKey arguments.
			return other instanceof FeatureKey && hashCode == other.hashCode()
					&& values.equals(((FeatureKey) other).values);
		}

		@Override
		public int hashCode() {
			return hashCode;
		}

		@Override
		public String toString() {
			return values.toString();
		}
	}

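	/** A training example: a gold label and the set of labels to choose from. */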
	public static abstract class AbstractTrainingExample<L> {

		public abstract Collection<L> getPossibleLabels();

		public abstract L getLabel();

	}

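	/**
	 * A training example with the weight indices for every possible label
	 * precomputed, so that scoring only has to sum weights.
	 */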
	private static class CachedTrainingExample<L> {
		private final Map<L, int[]> labelToFeatures;
		private final L label;
		private final String asString;

		public CachedTrainingExample(final Map<L, int[]> labelToFeatures, final L label, final String asString) {
			super();
			this.labelToFeatures = ImmutableMap.copyOf(labelToFeatures);
			this.label = label;
			this.asString = asString;
		}

		@Override
		public String toString() {
			return asString;
		}

	}

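	/**
	 * A trained log-linear classifier: each possible label is scored as the
	 * sum of the weights of the features that fire for it.
	 */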
	public static class AbstractClassifier<T extends AbstractTrainingExample<L>, F extends AbstractFeature<T, L>, L>
			implements Serializable {
		private static final long serialVersionUID = 1L;

		private final double[] weights;
		private final Collection<F> features;
		private final Map<FeatureKey, Integer> featureToIndex;

		public AbstractClassifier(final double[] weights, final Collection<F> features,
				final Map<FeatureKey, Integer> featureToIndex) {
			super();
			this.weights = weights;
			this.features = features;
			this.featureToIndex = featureToIndex;
		}

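		/**
		 * Returns the highest-scoring label for the example, or null if the
		 * example has no possible labels.
		 */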
		public L classify(final T ex) {

			double bestScore = Double.NEGATIVE_INFINITY;
			L bestLabel = null;

			final Collection<L> labels = ex.getPossibleLabels();

			for (final L label : labels) {
				double score = 0.0;
				for (final F feature : features) {
					score += weights[feature.getIndex(ex, featureToIndex, label)];
				}

				if (score > bestScore) {
					bestLabel = label;
					bestScore = score;
				}
			}

			return bestLabel;
		}

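		/**
		 * Returns the conditional probability p(label | ex): the label's
		 * exponentiated score, normalized over all possible labels.
		 */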
		public double probability(final T ex, final L label) {

			double labelScore = 0.0;
			double totalScore = 0.0;

			final Collection<L> labels = ex.getPossibleLabels();

			for (final L other : labels) {
				double score = 0.0;
				for (final F feature : features) {
					score += weights[feature.getIndex(ex, featureToIndex, other)];
				}

				score = Math.exp(score);
				// Compare by value: the query label need not be the same
				// object as the labels returned by getPossibleLabels().
				if (label.equals(other)) {
					labelScore = score;
				}

				totalScore += score;
			}

			return labelScore / totalScore;
		}

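		/**
		 * Fast-path classification for cached examples, which sums
		 * precomputed weight indices instead of recomputing feature keys.
		 */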
		private L classify(final CachedTrainingExample<L> ex) {

			double bestScore = Double.NEGATIVE_INFINITY;
			L bestLabel = null;

			for (final Entry<L, int[]> labelToFeatures : ex.labelToFeatures.entrySet()) {
				double score = 0.0;
				for (final int index : labelToFeatures.getValue()) {
					score += weights[index];
				}

				if (score > bestScore) {
					bestLabel = labelToFeatures.getKey();
					bestScore = score;
				}
			}

			return bestLabel;
		}
	}

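	/** Returns the feature templates used by this classifier. */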
	public abstract Collection<F> getFeatures();

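	/**
	 * Assigns a dense weight index to every feature instantiation that fires
	 * on the gold label of at least minFeatureCount training examples. Each
	 * feature also gets a default key, used as a backoff for instantiations
	 * unseen (or pruned) at training time.
	 */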
	private Map<FeatureKey, Integer> makeKeyToIndexMap(final List<T> data, final Collection<F> features,
			final int minFeatureCount) {
		final Multiset<FeatureKey> keyCount = HashMultiset.create();
		final Map<FeatureKey, Integer> result = new HashMap<>();
		for (final T ex : data) {

			for (final F feature : features) {
				final FeatureKey key = feature.getFeatureKey(ex, ex.getLabel());
				keyCount.add(key);
			}

		}

		for (final com.google.common.collect.Multiset.Entry<FeatureKey> entry : keyCount.entrySet()) {
			if (entry.getCount() >= minFeatureCount) {
				result.put(entry.getElement(), result.size());
			}
		}

		for (final F feature : features) {
			result.put(feature.getDefault(), result.size());
		}

		return result;
		// return ImmutableMap.copyOf(result);
	}

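	/**
	 * Trains a classifier: builds the feature index from the training data,
	 * caches feature indices for both data splits, and optimizes the weights
	 * with Gaussian regularization parameter sigmaSquared.
	 */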
	public AbstractClassifier<T, F, L> train(final int minFeatureCount, final double sigmaSquared)
			throws IOException, LBFGSBException {
		final Collection<F> features = getFeatures();
		final List<T> devData = getTrainingData(true);
		final List<T> trainingData = getTrainingData(false);
		final Map<FeatureKey, Integer> featureToIndex = makeKeyToIndexMap(trainingData, features, minFeatureCount);
		final double[] weights = train(features, cache(trainingData, featureToIndex, features),
				cache(devData, featureToIndex, features), sigmaSquared, featureToIndex);

		return new AbstractClassifier<>(weights, features, featureToIndex);
	}

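	/**
	 * Precomputes, for each example and each of its possible labels, the
	 * weight indices of the features that fire, so the loss function avoids
	 * recomputing feature keys on every optimizer iteration.
	 */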
	private List<CachedTrainingExample<L>> cache(final List<T> trainingData,
			final Map<FeatureKey, Integer> featureToIndex, final Collection<F> features) {
		final List<CachedTrainingExample<L>> result = new ArrayList<>();
		for (final T ex : trainingData) {
			final Map<L, int[]> labelToFeatures = new HashMap<>();
			for (final L label : ex.getPossibleLabels()) {
				final int[] indices = new int[features.size()];
				int i = 0;
				for (final F feature : features) {
					indices[i] = feature.getIndex(ex, featureToIndex, label);
					i++;
				}
				labelToFeatures.put(label, indices);
			}
			result.add(new CachedTrainingExample<>(labelToFeatures, ex.getLabel(), ex.toString()));
		}
		return result;
	}

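	/**
	 * Optimizes the weights with L-BFGS-B (up to 250 iterations), evaluating
	 * on the development data after every iteration and returning the weights
	 * that scored best.
	 */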
	private double[] train(final Collection<F> features, final List<CachedTrainingExample<L>> data,
			final Collection<CachedTrainingExample<L>> dev, final double sigmaSquared,
			final Map<FeatureKey, Integer> featureToIndex) throws LBFGSBException {
		final int numWeights = featureToIndex.size() + 1;
		final double[] weights = new double[numWeights];

		final Minimizer alg = new Minimizer();
		alg.setDebugLevel(5);

		final double[] bestWeights = new double[numWeights];

		alg.getStopConditions().setMaxIterationsInactive();
		alg.getStopConditions().setMaxIterations(250);
		final IterationFinishedListener iterationFinishedListener = new IterationFinishedListener() {
			private int iteration;
			private double bestScore = Double.NEGATIVE_INFINITY;

			@Override
			public boolean iterationFinished(final double[] newWeights, final double functionValue,
					final double[] gradient) {
				final AbstractClassifier<T, F, L> classifier = new AbstractClassifier<>(newWeights, features,
						featureToIndex);
				int right = 0;

				for (final CachedTrainingExample<L> ex : dev) {
					final L predicted = classifier.classify(ex);
					// Compare by value; predicted may be null if an example
					// has no possible labels.
					if (ex.label.equals(predicted)) {
						right++;
					}
				}
				System.out.println("Iteration: " + iteration++);
				System.out.println("Accuracy = " + Util.twoDP(100.0 * right / dev.size()));

				// Remember the weights from the best-scoring iteration.
				if (right > bestScore) {
					bestScore = right;
					System.arraycopy(newWeights, 0, bestWeights, 0, newWeights.length);
				}
				return true;
			}
		};
		alg.setIterationFinishedListener(iterationFinishedListener);
		alg.run(new ParallelLossFunction<>(new LossFunction<>(sigmaSquared),
				Runtime.getRuntime().availableProcessors(), data), weights);
		return bestWeights;
	}

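	/**
	 * Splits the training data into batches, computes the loss and gradient
	 * of each batch on its own thread, and sums the partial results.
	 */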
	private static class ParallelLossFunction<L> implements DifferentiableFunction {
		private final LossFunction<L> lossFunction;
		private final int numThreads;
		private final List<CachedTrainingExample<L>> trainingData;

		public ParallelLossFunction(final LossFunction<L> lossFunction, final int numThreads,
				final List<CachedTrainingExample<L>> trainingData) {
			super();
			this.lossFunction = lossFunction;
			this.numThreads = numThreads;
			this.trainingData = trainingData;
		}

		@Override
		public IterationResult getValues(final double[] featureWeights) {
			final Collection<Callable<IterationResult>> tasks = new ArrayList<>();

			int totalCorrect = 0;
			// Guard against a batch size of 0 when there are fewer examples
			// than threads, which would make Lists.partition throw.
			final int batchSize = Math.max(1, trainingData.size() / numThreads);
			for (final List<CachedTrainingExample<L>> batch : Lists.partition(trainingData, batchSize)) {
				tasks.add(new Callable<IterationResult>() {
					@Override
					public IterationResult call() throws Exception {
						return lossFunction.getValues(featureWeights, batch);
					}
				});
			}

			final ExecutorService executor = Executors.newFixedThreadPool(numThreads);
			try {
				final List<Future<IterationResult>> results = executor.invokeAll(tasks);

				final double[] totalGradient = new double[featureWeights.length];
				double totalLoss = 0.0;

				for (final Future<IterationResult> result : results) {
					final IterationResult values = result.get();
					totalLoss += values.functionValue;
					Util.add(totalGradient, values.gradient);
					totalCorrect += values.correctPredictions;
				}

				System.out.println();
				System.out.println("Training accuracy=" + Util.twoDP(100.0 * totalCorrect / trainingData.size()));
				System.out.println("Loss=" + Util.twoDP(totalLoss));
				return new IterationResult(totalCorrect, totalLoss, totalGradient);
			} catch (InterruptedException | ExecutionException e) {
				throw new RuntimeException(e);
			} finally {
				// Always reclaim the worker threads, even if a task failed.
				executor.shutdown();
			}
		}
	}

	private static class IterationResult extends FunctionValues {
		private final int correctPredictions;

		public IterationResult(final int correctPredictions, final double functionValue, final double[] gradient) {
			super(functionValue, gradient);
			this.correctPredictions = correctPredictions;
		}
	}

	private static class LossFunction<L> {
		private final double sigmaSquared;

		private LossFunction(final double sigmaSquared) {
			super();
			this.sigmaSquared = sigmaSquared;
		}

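		/**
		 * Computes the negated, L2-regularized conditional log-likelihood of
		 * a batch and its gradient. The gradient of the log-likelihood is the
		 * gold feature expectation minus the model feature expectation; both
		 * loss and gradient are negated because the optimizer minimizes.
		 */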
		public IterationResult getValues(final double[] featureWeights,
				final Collection<CachedTrainingExample<L>> trainingData) {
			final double[] modelExpectation = new double[featureWeights.length];
			final double[] goldExpectation = new double[featureWeights.length];

			double loglikelihood = 0.0;
			final AtomicInteger correct = new AtomicInteger();

			for (final CachedTrainingExample<L> trainingExample : trainingData) {
				loglikelihood = loglikelihood
						+ computeExpectationsForTrainingExample(trainingExample, featureWeights, modelExpectation,
								goldExpectation, correct);
			}
			final double[] gradient = Util.subtract(goldExpectation, modelExpectation);

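			// Gaussian prior: penalize each weight by w^2 / (2 * sigmaSquared)
			// in the log-likelihood and w / sigmaSquared in the gradient.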
			for (int i = 0; i < featureWeights.length; i++) {
				loglikelihood = loglikelihood - (Math.pow(featureWeights[i], 2) / (2.0 * sigmaSquared));
				gradient[i] = gradient[i] - (featureWeights[i] / sigmaSquared);
			}

			loglikelihood = -loglikelihood;
			for (int i = 0; i < featureWeights.length; i++) {
				gradient[i] = -gradient[i];
			}

			return new IterationResult(correct.get(), loglikelihood, gradient);
		}

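		/**
		 * Adds one example's contribution to the model and gold feature
		 * expectations, increments the correct-prediction count if the
		 * model's best label matches the gold label, and returns the
		 * example's log-likelihood.
		 */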
		private double computeExpectationsForTrainingExample(final CachedTrainingExample<L> trainingExample,
				final double[] featureWeights, final double[] modelExpectation, final double[] goldExpectation,
				final AtomicInteger correct) {

			if (!trainingExample.labelToFeatures.containsKey(trainingExample.label)) {
				// Gold label is not possible.
				return 0;
			}

			L bestLabel = null;
			double bestScore = Double.NEGATIVE_INFINITY;

			double total = 0.0;
			final Map<L, Double> labelToScore = new HashMap<>();
			for (final Entry<L, int[]> labelToFeatures : trainingExample.labelToFeatures.entrySet()) {
				double score = 0.0;
				for (final int index : labelToFeatures.getValue()) {
					score += featureWeights[index];
				}

				score = Math.exp(score);
				total += score;
				labelToScore.put(labelToFeatures.getKey(), score);

				if (score > bestScore) {
					bestScore = score;
					bestLabel = labelToFeatures.getKey();
				}
			}

			if (bestLabel.equals(trainingExample.label)) {
				correct.getAndIncrement();
			}

			for (final Entry<L, int[]> labelToFeatures : trainingExample.labelToFeatures.entrySet()) {
				final double pLabel = labelToScore.get(labelToFeatures.getKey()) / total;
				for (final int index : labelToFeatures.getValue()) {
					modelExpectation[index] += pLabel;
				}
			}

			final double loglikelihood = Math.log(labelToScore.get(trainingExample.label) / total);

			for (final int index : trainingExample.labelToFeatures.get(trainingExample.label)) {
				goldExpectation[index] += 1;
			}

			return loglikelihood;
		}

	}

	public static abstract class AbstractFeature<T, L> implements Serializable {
		private static final long serialVersionUID = 1L;
		private final String name;
		private final int id;
		private final static AtomicInteger ids = new AtomicInteger();

		private final FeatureKey defaultKey;

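		/**
		 * Builds the key identifying this feature's instantiation for the
		 * given example and label. The feature's unique id is appended so
		 * that keys from different feature types cannot collide.
		 */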
		public FeatureKey getFeatureKey(final T trainingExample, final L label) {
			final List<Object> value = new ArrayList<>();
			getValue(value, trainingExample, label);
			value.add(id);
			return new FeatureKey(value);
		}

		FeatureKey getDefault() {
			return defaultKey;
		}

		public String getName() {
			return name;
		}

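		/**
		 * Looks up the weight index for this feature's instantiation, falling
		 * back to the feature's default key when the instantiation was not
		 * seen at training time.
		 */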
		public int getIndex(final T trainingExample, final Map<FeatureKey, Integer> keyToIndex, final L label) {
			Integer result = keyToIndex.get(getFeatureKey(trainingExample, label));
			if (result == null) {
				result = keyToIndex.get(defaultKey);
			}

			return result;
		}

		public AbstractFeature(final String name) {
			super();
			this.name = name;
			this.id = ids.getAndIncrement();
			this.defaultKey = new FeatureKey(Arrays.asList(id));
		}

		public abstract void getValue(List<Object> result, T trainingExample, L label);
	}

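	/**
	 * Returns the development split if isDev is true, and the training split
	 * otherwise.
	 */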
	public abstract List<T> getTrainingData(boolean isDev) throws IOException;

}