package edu.uw.easysrl.main;

import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.InputMismatchException;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.StringTokenizer;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import edu.uw.Taggerflow;
import edu.uw.TaggerflowProtos.TaggedSentence;
import edu.uw.TaggerflowProtos.TaggedToken;
import edu.uw.easysrl.main.EasySRL.InputFormat;
import edu.uw.easysrl.syntax.grammar.Category;
import edu.uw.easysrl.syntax.grammar.SyntaxTreeNode.SyntaxTreeNodeLeaf;
import edu.uw.easysrl.syntax.tagger.Tagger.ScoredCategory;
import edu.uw.easysrl.syntax.tagger.TaggerflowLSTM;
import edu.uw.easysrl.util.Util;

public abstract class InputReader {

	/**
	 * A single input token, optionally annotated with a part-of-speech tag and a named-entity tag.
	 *
	 * <p>{@code pos} and {@code ner} are {@code null} when the input did not supply them.
	 */
	public static class InputWord implements Serializable {
		private static final long serialVersionUID = -4110997736066926795L;

		public InputWord(final String word, final String pos, final String ner) {
			this.word = word;
			this.pos = pos;
			this.ner = ner;
		}

		/** Creates a word with no POS or NER annotation. */
		InputWord(final String word) {
			this(word, null, null);
		}

		public final String word;
		public final String pos;
		public final String ner;

		/** Wraps each string as an unannotated {@link InputWord}, preserving order. */
		public static List<InputWord> listOf(final String... words) {
			final List<InputWord> result = new ArrayList<>(words.length);
			for (final String word : words) {
				result.add(new InputWord(word));
			}
			return result;
		}

		/** Wraps each string as an unannotated {@link InputWord}, preserving order. */
		public static List<InputWord> listOf(final List<String> words) {
			final List<InputWord> result = new ArrayList<>(words.size());
			for (final String word : words) {
				result.add(new InputWord(word));
			}
			return result;
		}

		/** Converts syntax-tree leaves into input words, copying each leaf's word, POS tag and NER tag. */
		public static List<InputWord> fromLeaves(final List<SyntaxTreeNodeLeaf> leaves) {
			final List<InputWord> result = new ArrayList<>(leaves.size());
			for (final SyntaxTreeNodeLeaf leaf : leaves) {
				result.add(InputWord.valueOf(leaf));
			}

			return result;
		}

		private static InputWord valueOf(final SyntaxTreeNodeLeaf leaf) {
			return new InputWord(leaf.getWord(), leaf.getPos(), leaf.getNER());
		}

		@Override
		public int hashCode() {
			return Objects.hash(word, ner, pos);
		}

		/**
		 * Two words are equal when their word, POS and NER strings are all equal (nulls compare equal to each other).
		 */
		@Override
		public boolean equals(final Object obj) {
			if (this == obj) {
				return true;
			}
			// The equals contract requires false (not an exception) for null or objects of another type.
			if (!(obj instanceof InputWord)) {
				return false;
			}
			final InputWord other = (InputWord) obj;
			return Objects.equals(word, other.word) && Objects.equals(ner, other.ner) && Objects.equals(pos, other.pos);
		}

		/** Renders as {@code word}, {@code word|pos}, {@code word|ner} or {@code word|pos|ner}, omitting null fields. */
		@Override
		public String toString() {
			return word + (pos != null ? "|" + pos : "") + (ner != null ? "|" + ner : "");
		}
	}

	/**
	 * Lazily parses {@code input} one line at a time with {@link #readInput(String)}, skipping lines that are empty
	 * or start with {@code #}.
	 *
	 * <p>The file is opened once, when this method is called, and every {@link Iterator} produced by the returned
	 * {@link Iterable} reads from that same line iterator. Iterating a second time therefore resumes where the
	 * previous iteration stopped. Each iterator also reads one sentence ahead as soon as it is created.
	 *
	 * @param input the file to read
	 * @return an iterable over the parsed, non-comment, non-empty lines
	 * @throws IOException propagated from {@code Util.readFileLineByLine}
	 */
	public Iterable<InputToParser> readFile(final File input) throws IOException {
		final Iterator<String> inputIt = Util.readFileLineByLine(input);

		return () -> new Iterator<InputToParser>() {
			// Look-ahead: the next parsed sentence, or null once all lines have been consumed.
			private InputToParser next = getNext();

			@Override
			public boolean hasNext() {
				return next != null;
			}

			/** Parses the next line that is neither empty nor a {@code #} comment, or returns null at end of input. */
			private InputToParser getNext() {
				while (inputIt.hasNext()) {
					final String nextLine = inputIt.next();
					// Skip commented or empty lines.
					if (!nextLine.startsWith("#") && !nextLine.isEmpty()) {
						return readInput(nextLine);
					}
				}

				return null;
			}

			@Override
			public InputToParser next() {
				// Iterator contract: signal exhaustion with an exception rather than returning null.
				if (next == null) {
					throw new NoSuchElementException();
				}
				final InputToParser result = next;
				next = getNext();
				return result;
			}

			@Override
			public void remove() {
				throw new UnsupportedOperationException();
			}
		};
	}

	/**
	 * Returns an iterator that parses {@code file} one line at a time with {@link #readInput(String)}.
	 *
	 * <p>Unlike {@link #readFile(File)}, no lines are skipped: empty lines and lines starting with {@code #} are
	 * passed to {@link #readInput(String)} as well.
	 *
	 * @param file the file to read
	 * @return an iterator over the parsed lines
	 * @throws IOException propagated from {@code Util.readFileLineByLine}
	 */
	public Iterator<InputToParser> readInput(final File file) throws IOException {
		final Iterator<String> lineIterator = Util.readFileLineByLine(file);
		final InputReader parser = this;
		return new Iterator<InputToParser>() {

			@Override
			public InputToParser next() {
				final String line = lineIterator.next();
				return parser.readInput(line);
			}

			@Override
			public boolean hasNext() {
				return lineIterator.hasNext();
			}
		};
	}

	/**
	 * Parses one line of input into an {@link InputToParser}, using this reader's input format.
	 *
	 * <p>Called once per line by {@link #readFile(File)} and {@link #readInput(File)}. Implementations in this file
	 * throw {@link InputMismatchException} for malformed entries; a reader that only supports batch mode may throw
	 * {@link UnsupportedOperationException} instead.
	 *
	 * @param line a single line of input
	 * @return the parsed sentence
	 */
	public abstract InputToParser readInput(String line);

	public static class InputToParser implements Serializable {
		private static final long serialVersionUID = 1L;

		private final List<InputWord> words;
		private final boolean isAlreadyTagged;

		public InputToParser(final List<InputWord> words, final List<Category> goldCategories,
				final List<List<ScoredCategory>> inputSupertags, final boolean isAlreadyTagged) {
			this.words = words;
			this.goldCategories = goldCategories;
			this.inputSupertags = inputSupertags;
			this.isAlreadyTagged = isAlreadyTagged;
		}

		private final List<Category> goldCategories;
		private final List<List<ScoredCategory>> inputSupertags;

		public int length() {
			return words.size();
		}

		/**
		 * If true, the Parser should not supertag the data itself, and use getInputSupertags() instead.
		 */
		public boolean isAlreadyTagged() {
			return isAlreadyTagged;
		}

		public List<List<ScoredCategory>> getInputSupertags() {
			return inputSupertags;
		}

		public boolean haveGoldCategories() {
			return getGoldCategories() != null;
		}

		public List<Category> getGoldCategories() {
			return goldCategories;
		}

		public List<InputWord> getInputWords() {
			return words;

		}

		public String getWordsAsString() {
			final StringBuilder result = new StringBuilder();
			for (final InputWord word : words) {
				result.append(word.word + " ");
			}

			return result.toString().trim();
		}

		public static InputToParser fromTokens(final List<String> tokens) {
			final List<InputWord> inputWords = new ArrayList<>(tokens.size());
			for (final String word : tokens) {
				inputWords.add(new InputWord(word, null, null));
			}
			return new InputToParser(inputWords, null, null, false);
		}

		public boolean isPOStagged() {
			return words.size() == 0 || words.get(0).pos != null;
		}
	}

	/**
	 * Reads pre-tokenized input: one sentence per line, tokens separated by spaces.
	 *
	 * <p>Double quotes are removed and runs of spaces are collapsed before splitting.
	 */
	private static class RawInputReader extends InputReader {
		/** Two or more consecutive spaces; compiled once instead of on every line. */
		private static final Pattern MULTIPLE_SPACES = Pattern.compile("  +");

		@Override
		public InputToParser readInput(final String line) {
			// TODO quotes: double quotes are currently stripped rather than tokenized.
			final String withoutQuotes = line.replace("\"", "");
			final String normalized = MULTIPLE_SPACES.matcher(withoutQuotes).replaceAll(" ").trim();
			// NOTE(review): a line that is blank after this normalization yields a single empty token.
			return InputToParser.fromTokens(Arrays.asList(normalized.split(" ")));
		}
	}

	/**
	 * Reads gold-standard input: one sentence per line, with space-separated entries of the form
	 * {@code word|POS|SUPERTAG}.
	 *
	 * <p>Each word's gold supertag is returned both as its gold category and as a single-entry supertag distribution
	 * scored {@link Double#MAX_VALUE}. {@code isAlreadyTagged} is nevertheless set to {@code false}. Entries whose
	 * word is a bare double quote are skipped.
	 */
	private static class GoldInputReader extends InputReader {

		@Override
		public InputToParser readInput(final String line) {
			final String[] goldEntries = line.split(" ");
			final List<InputWord> words = new ArrayList<>(goldEntries.length);
			final List<Category> goldCategories = new ArrayList<>(goldEntries.length);
			final List<List<ScoredCategory>> supertags = new ArrayList<>(goldEntries.length);
			for (final String entry : goldEntries) {
				final String[] goldFields = entry.split("\\|");

				// The length guard matters because String.split drops trailing empty strings: e.g. "|" gives an empty
				// array, which must be reported as malformed input rather than an ArrayIndexOutOfBoundsException.
				if (goldFields.length > 0 && goldFields[0].equals("\"")) {
					continue; // TODO quotes
				}
				if (goldFields.length < 3) {
					throw new InputMismatchException("Invalid input: expected \"word|POS|SUPERTAG\" but was: " + entry);
				}

				final String word = goldFields[0];
				final String pos = goldFields[1];
				final Category category = Category.valueOf(goldFields[2]);
				words.add(new InputWord(word, pos, null));
				goldCategories.add(category);
				supertags.add(Collections.singletonList(new ScoredCategory(category, Double.MAX_VALUE)));
			}
			return new InputToParser(words, goldCategories, supertags, false);
		}

		private GoldInputReader() {
		}
	}

	/**
	 * Reads input that already carries a distribution over supertags, one sentence per line.
	 *
	 * <p>Each space-separated entry has the form {@code word|CATEGORY=score|CATEGORY=score...}, for example
	 * {@code Word|N=3|NP=2}. The result is marked as already tagged and has no POS/NER tags and no gold categories.
	 */
	public static class SupertaggedInputReader extends InputReader {
		/**
		 * Log of the pruning beam. A category is dropped when {@code exp(score) < 1e-6 * exp(bestScore)}, which is
		 * tested here as {@code score < bestScore + log(1e-6)}. This avoids overflow/underflow in {@code Math.exp}.
		 */
		private static final double LOG_PRUNING_BEAM = Math.log(0.000001);

		// NOTE(review): not read anywhere in this class; kept so the public constructor is unchanged.
		private final List<Category> cats;

		public SupertaggedInputReader(final List<Category> cats) {
			this.cats = cats;
		}

		@Override
		public InputToParser readInput(final String line) {
			final String[] entries = line.split(" ");
			final List<InputWord> words = new ArrayList<>(entries.length);
			final List<List<ScoredCategory>> supertags = new ArrayList<>(entries.length);
			for (final String entry : entries) {
				// StringTokenizer (unlike String.split) skips empty fields, e.g. "Word||N=3".
				final StringTokenizer tokenizer = new StringTokenizer(entry, "|");
				if (!tokenizer.hasMoreTokens()) {
					throw malformed(entry);
				}
				final String word = tokenizer.nextToken();

				final List<ScoredCategory> tagDist = new ArrayList<>();
				while (tokenizer.hasMoreTokens()) {
					tagDist.add(parseScoredCategory(tokenizer.nextToken(), entry));
				}
				if (tagDist.isEmpty()) {
					throw malformed(entry);
				}

				words.add(new InputWord(word));
				supertags.add(pruneToBeam(tagDist));
			}
			// This format has no gold categories. Pass null (not an empty list) so haveGoldCategories() is false.
			return new InputToParser(words, null, supertags, true);
		}

		/** Parses one {@code CATEGORY=score} field; {@code entry} is only used for the error message. */
		private static ScoredCategory parseScoredCategory(final String tagAndScore, final String entry) {
			final int equals = tagAndScore.indexOf('=');
			if (equals <= 0) {
				throw malformed(entry);
			}
			final Category category = Category.valueOf(tagAndScore.substring(0, equals));
			return new ScoredCategory(category, Double.parseDouble(tagAndScore.substring(equals + 1)));
		}

		/**
		 * Sorts {@code tagDist} in place by {@link ScoredCategory}'s natural ordering and treats the first element as
		 * the best. Truncates the list at the first later category that falls outside the beam.
		 * {@code tagDist} must be non-empty.
		 */
		private static List<ScoredCategory> pruneToBeam(final List<ScoredCategory> tagDist) {
			Collections.sort(tagDist);
			final double cutoff = tagDist.get(0).getScore() + LOG_PRUNING_BEAM;
			for (int i = 1; i < tagDist.size(); i++) {
				if (tagDist.get(i).getScore() < cutoff) {
					// Truncate in place rather than keeping a subList view. ArrayList's subList views are not
					// Serializable, but the list ends up inside a Serializable InputToParser.
					tagDist.subList(i, tagDist.size()).clear();
					break;
				}
			}
			return tagDist;
		}

		private static InputMismatchException malformed(final String entry) {
			return new InputMismatchException(
					"Invalid input: expected \"word|CATEGORY=score|...\" but was: " + entry);
		}
	}

	/**
	 * Reads part-of-speech tagged input: one sentence per line, with space-separated entries of the form
	 * {@code word|POS}. Entries whose word is a bare double quote are skipped.
	 */
	private static class POSTaggedInputReader extends InputReader {

		@Override
		public InputToParser readInput(final String line) {
			final String[] taggedEntries = line.split(" ");
			final List<InputWord> inputWords = new ArrayList<>(taggedEntries.length);
			for (final String entry : taggedEntries) {
				final String[] taggedFields = entry.split("\\|");

				// Check for quotes before validating the field count, so a bare quote token is skipped rather than
				// rejected. The length guard is needed because String.split drops trailing empty strings
				// (e.g. "|" splits into an empty array).
				if (taggedFields.length > 0 && taggedFields[0].equals("\"")) {
					continue; // TODO quotes
				}
				if (taggedFields.length < 2) {
					throw new InputMismatchException("Invalid input: expected \"word|POS\" but was: " + entry);
				}
				inputWords.add(new InputWord(taggedFields[0], taggedFields[1], null));
			}
			return new InputToParser(inputWords, null, null, false);
		}
	}

	/**
	 * Reads input tagged with part-of-speech and named-entity tags: one sentence per line, with space-separated
	 * entries of the form {@code word|POS|NER}. Entries whose word is a bare double quote are skipped.
	 */
	private static class POSandNERTaggedInputReader extends InputReader {
		@Override
		public InputToParser readInput(final String line) {
			final String[] taggedEntries = line.split(" ");
			final List<InputWord> inputWords = new ArrayList<>(taggedEntries.length);
			for (final String entry : taggedEntries) {
				final String[] taggedFields = entry.split("\\|");

				// String.split drops trailing empty strings, so e.g. "|" yields an empty array. Guard the index so
				// such entries are reported as malformed input below instead of an ArrayIndexOutOfBoundsException.
				if (taggedFields.length > 0 && taggedFields[0].equals("\"")) {
					continue; // TODO quotes
				}
				if (taggedFields.length < 3) {
					throw new InputMismatchException("Invalid input: expected \"word|POS|NER\" but was: " + entry + "\n"
							+ "The C&C can produce this format using: \"bin/pos -model models/pos | bin/ner -model models/ner -ofmt \"%w|%p|%n \\n\"\"");
				}
				inputWords.add(new InputWord(taggedFields[0], taggedFields[1], taggedFields[2]));
			}
			return new InputToParser(inputWords, null, null, false);
		}
	}

	/**
	 * Creates a line-based reader for the given input format.
	 *
	 * <p>{@link SupertaggedInputReader} and {@link TensorFlowInputReader} need constructor arguments and are not
	 * created here.
	 *
	 * @param inputFormat the format of the input lines
	 * @return a new reader for {@code inputFormat}
	 * @throws IllegalArgumentException if no reader is defined for {@code inputFormat}
	 * @throws NullPointerException if {@code inputFormat} is null
	 */
	public static InputReader make(final InputFormat inputFormat) {
		switch (inputFormat) {
		case TOKENIZED:
			return new RawInputReader();
		case GOLD:
			return new GoldInputReader();
		case POSTAGGED:
			return new POSTaggedInputReader();
		case POSANDNERTAGGED:
			return new POSandNERTaggedInputReader();
		default:
			// An unsupported argument is a recoverable caller error, not a JVM-level Error.
			throw new IllegalArgumentException("Unknown input format: " + inputFormat);
		}
	}

	/**
	 * Runs a TensorFlow library which deals with loading and tagging the file.
	 *
	 * <p>Only batch mode via {@link #readFile(File)} is supported; {@link #readInput(String)} throws. Time spent in
	 * calls into the tagger is accumulated in a stopwatch, readable via {@link #getSupertaggingTime(TimeUnit)}.
	 */
	public static class TensorFlowInputReader extends InputReader {
		private final Taggerflow tagger;
		private final List<Category> categories;
		private final int maxBatchSize;
		private final Stopwatch gpuTime = Stopwatch.createUnstarted();

		/**
		 * @param folder model folder passed to {@link Taggerflow}
		 * @param categories categories passed to {@code TaggerflowLSTM.getScoredCategories}
		 * @param maxBatchSize maximum batch size passed to {@code Taggerflow.predict}
		 */
		public TensorFlowInputReader(final File folder, final List<Category> categories, final int maxBatchSize) {
			tagger = new Taggerflow(folder, 1e-5);
			this.categories = categories;
			this.maxBatchSize = maxBatchSize;
		}

		@Override
		public InputToParser readInput(final String line) {
			throw new UnsupportedOperationException("TensorFlowInputReader can only be used in batch mode");
		}

		/** Returns the accumulated time spent inside the tagger. */
		public long getSupertaggingTime(final TimeUnit timeUnit) {
			return gpuTime.elapsed(timeUnit);
		}

		public void resetSupertaggingTime() {
			gpuTime.reset();
		}

		/**
		 * Tags {@code file} with TensorFlow. Each call to {@code iterator()} on the result runs the tagger again.
		 * Sentences for which the tagger returns no tokens are skipped.
		 */
		@Override
		public Iterable<InputToParser> readFile(final File file) throws IOException {
			return () -> {
				final Iterator<TaggedSentence> taggedSentenceIterator;
				gpuTime.start();
				try {
					taggedSentenceIterator = tagger.predict(file.getAbsolutePath(), maxBatchSize).iterator();
				} finally {
					// Always stop: a stopwatch left running makes every later start() throw.
					gpuTime.stop();
				}

				return new Iterator<InputToParser>() {
					// Look-ahead buffer. Holds the next non-empty sentence, or null if it is not computed yet or the
					// input is exhausted.
					private InputToParser pending;

					@Override
					public boolean hasNext() {
						if (pending == null) {
							pending = advance();
						}
						return pending != null;
					}

					@Override
					public InputToParser next() {
						if (!hasNext()) {
							throw new NoSuchElementException();
						}
						final InputToParser result = pending;
						pending = null;
						return result;
					}

					/** Returns the next sentence with at least one token, or null when the tagger output is exhausted. */
					private InputToParser advance() {
						// Iterative (not recursive) skipping, so trailing or consecutive empty sentences cannot make
						// next() fail after hasNext() returned true, nor overflow the stack.
						while (taggedSentenceIterator.hasNext()) {
							final TaggedSentence sentence;
							gpuTime.start();
							try {
								sentence = taggedSentenceIterator.next();
							} finally {
								gpuTime.stop();
							}

							final List<InputWord> words = sentence.getTokenList().stream().map(TaggedToken::getWord)
									.map(InputWord::new).collect(Collectors.toList());
							if (words.isEmpty()) {
								continue; // FIXME sentences not parsed by TensorFlow
							}

							final List<List<ScoredCategory>> tagDist = TaggerflowLSTM
									.getScoredCategories(sentence, categories);
							Preconditions.checkState(words.size() == tagDist.size(),
									"Tagger returned %s tag distributions for %s words", tagDist.size(), words.size());
							return new InputToParser(words, null, tagDist, true);
						}
						return null;
					}
				};
			};
		}
	}

}