package com.medallia.word2vec;

import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSortedMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.collect.Multisets;
import com.google.common.primitives.Doubles;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener.Stage;
import com.medallia.word2vec.huffman.HuffmanCoding;
import com.medallia.word2vec.huffman.HuffmanCoding.HuffmanNode;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkTrainer.NeuralNetworkModel;
import com.medallia.word2vec.util.AC;
import com.medallia.word2vec.util.ProfilingTimer;
import org.apache.commons.logging.Log;

import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Responsible for training a word2vec model.
 *
 * <p>Training runs in four profiled stages: acquiring word frequencies, filtering and sorting
 * the vocabulary, building a Huffman encoding of the vocabulary, and training the neural network.
 */
class Word2VecTrainer {
	/** Words whose count is below this value are dropped from the vocabulary. */
	private final int minFrequency;
	/** Precomputed word counts; when absent, counts are computed from the training sentences. */
	private final Optional<Multiset<String>> vocab;
	/** Configuration used to create the neural network trainer. */
	private final NeuralNetworkConfig neuralNetworkConfig;

	/**
	 * @param minFrequency minimum count a word needs to be kept in the vocabulary; must not be null
	 * @param vocab precomputed word counts, or an absent {@link Optional} to count words from the
	 *     training sentences; must not be null
	 * @param neuralNetworkConfig configuration used to create the neural network trainer; must not be null
	 * @throws NullPointerException if any argument is null
	 */
	Word2VecTrainer(
			Integer minFrequency,
			Optional<Multiset<String>> vocab,
			NeuralNetworkConfig neuralNetworkConfig) {
		// Fail fast with a named message rather than an anonymous unboxing NPE here,
		// or an NPE that would otherwise only surface once train() is called.
		this.minFrequency = Objects.requireNonNull(minFrequency, "minFrequency");
		this.vocab = Objects.requireNonNull(vocab, "vocab");
		this.neuralNetworkConfig = Objects.requireNonNull(neuralNetworkConfig, "neuralNetworkConfig");
	}

	/** @return {@link Multiset} containing unique tokens and their counts */
	private static Multiset<String> count(Iterable<String> tokens) {
		return HashMultiset.create(tokens);
	}

	/**
	 * Drops tokens whose count is below {@link #minFrequency} and orders the remaining ones.
	 *
	 * @param counts token counts to filter; only read, never modified
	 * @return tokens with their count, sorted by frequency decreasing, then lexicographically ascending
	 */
	private ImmutableMultiset<String> filterAndSort(final Multiset<String> counts) {
		final Multiset<String> frequent = Multisets.filter(
				counts,
				new Predicate<String>() {
					@Override
					public boolean apply(String s) {
						return counts.count(s) >= minFrequency;
					}
				});
		// Guava's multiset offers no single call that orders by both count and element, so sort
		// lexicographically first and then by count. This is deterministic because
		// copyHighestCountFirst breaks count ties by the iteration order of its input, which here
		// is the lexicographic order of the ImmutableSortedMultiset. Not terribly efficient.
		return Multisets.copyHighestCountFirst(ImmutableSortedMultiset.copyOf(frequent));
	}

	/**
	 * Trains a model using the given data.
	 *
	 * @param log log handed to the profiling timer that times each training stage
	 * @param listener receives progress updates for the training stages
	 * @param sentences training corpus, one token list per sentence. When no precomputed vocabulary
	 *     was supplied it is iterated once to count words and again to train, so it must be
	 *     re-iterable.
	 * @return a {@link Word2VecModel} built from the filtered vocabulary and the trained vectors
	 * @throws InterruptedException if interrupted during training
	 */
	Word2VecModel train(Log log, TrainingProgressListener listener, Iterable<List<String>> sentences) throws InterruptedException {
		try (ProfilingTimer timer = ProfilingTimer.createLoggingSubtasks(log, "Training word2vec")) {
			final Multiset<String> counts;
			try (AC ac = timer.start("Acquiring word frequencies")) {
				listener.update(Stage.ACQUIRE_VOCAB, 0.0);
				// Prefer the caller-supplied counts; otherwise count every token of the corpus.
				counts = vocab.isPresent()
						? vocab.get()
						: count(Iterables.concat(sentences));
			}

			// Deliberately not named "vocab": that name belongs to the field holding the
			// optional precomputed counts, and shadowing it made the code above ambiguous.
			final ImmutableMultiset<String> sortedVocab;
			try (AC ac = timer.start("Filtering and sorting vocabulary")) {
				listener.update(Stage.FILTER_SORT_VOCAB, 0.0);
				sortedVocab = filterAndSort(counts);
			}

			final Map<String, HuffmanNode> huffmanNodes;
			try (AC task = timer.start("Create Huffman encoding")) {
				huffmanNodes = new HuffmanCoding(sortedVocab, listener).encode();
			}

			final NeuralNetworkModel model;
			try (AC task = timer.start("Training model %s", neuralNetworkConfig)) {
				model = neuralNetworkConfig.createTrainer(sortedVocab, huffmanNodes, listener).train(sentences);
			}

			return new Word2VecModel(sortedVocab.elementSet(), model.layerSize(), Doubles.concat(model.vectors()));
		}
	}
}