package com.medallia.word2vec;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

import java.io.File;
import java.io.IOException;
import java.util.List;

import org.apache.commons.io.FileUtils;
import org.apache.thrift.TException;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.medallia.word2vec.Searcher.Match;
import com.medallia.word2vec.Searcher.UnknownWordException;
import com.medallia.word2vec.Word2VecTrainerBuilder.TrainingProgressListener;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkType;
import com.medallia.word2vec.thrift.Word2VecModelThrift;
import com.medallia.word2vec.util.Common;
import com.medallia.word2vec.util.ThriftUtils;

/**
 * Tests for {@link Word2VecModel} and related classes.
 * <p>
 * Note that the implementation is expected to be deterministic if numThreads is
 * set to 1
 */
public class Word2VecTest {
	@Rule
	public ExpectedException expected = ExpectedException.none();

	/** Clean up after a test run */
	@After
	public void after() {
		// Unset the interrupted flag to avoid polluting other tests
		Thread.interrupted();
	}

	/** Test {@link NeuralNetworkType#CBOW} */
	@Test
	public void testCBOW() throws IOException, TException, InterruptedException {
		assertModelMatches("cbowBasic.model",
				Word2VecModel.trainer()
						.setMinVocabFrequency(6)
						.useNumThreads(1)
						.setWindowSize(8)
						.type(NeuralNetworkType.CBOW)
						.useHierarchicalSoftmax()
						.setLayerSize(25)
						.setDownSamplingRate(1e-3)
						.setNumIterations(1)
						.train(testData())
		);
	}

	/** Test {@link NeuralNetworkType#CBOW} with 15 iterations */
	@Test
	public void testCBOWwith15Iterations() throws IOException, TException, InterruptedException {
		assertModelMatches("cbowIterations.model",
				Word2VecModel.trainer()
					.setMinVocabFrequency(5)
					.useNumThreads(1)
					.setWindowSize(8)
					.type(NeuralNetworkType.CBOW)
					.useHierarchicalSoftmax()
					.setLayerSize(25)
					.useNegativeSamples(5)
					.setDownSamplingRate(1e-3)
					.setNumIterations(15)
					.train(testData())
			);
	}

	/** Test {@link NeuralNetworkType#SKIP_GRAM} */
	@Test
	public void testSkipGram() throws IOException, TException, InterruptedException {
		assertModelMatches("skipGramBasic.model",
				Word2VecModel.trainer()
					.setMinVocabFrequency(6)
					.useNumThreads(1)
					.setWindowSize(8)
					.type(NeuralNetworkType.SKIP_GRAM)
					.useHierarchicalSoftmax()
					.setLayerSize(25)
					.setDownSamplingRate(1e-3)
					.setNumIterations(1)
					.train(testData())
			);
	}

	/** Test {@link NeuralNetworkType#SKIP_GRAM} with 15 iterations */
	@Test
	public void testSkipGramWith15Iterations() throws IOException, TException, InterruptedException {
		assertModelMatches("skipGramIterations.model",
				Word2VecModel.trainer()
					.setMinVocabFrequency(6)
					.useNumThreads(1)
					.setWindowSize(8)
					.type(NeuralNetworkType.SKIP_GRAM)
					.useHierarchicalSoftmax()
					.setLayerSize(25)
					.setDownSamplingRate(1e-3)
					.setNumIterations(15)
					.train(testData())
			);
	}

	/** Test that we can interrupt the huffman encoding process */
	@Test
	public void testInterruptHuffman() throws IOException, InterruptedException {
		expected.expect(InterruptedException.class);
		trainer()
			.type(NeuralNetworkType.SKIP_GRAM)
			.setNumIterations(15)
			.setListener(new TrainingProgressListener() {
					@Override public void update(Stage stage, double progress) {
						if (stage == Stage.CREATE_HUFFMAN_ENCODING)
							Thread.currentThread().interrupt();
						else if (stage == Stage.TRAIN_NEURAL_NETWORK)
							fail("Should not have reached this stage");
					}
				})
			.train(testData());
	}

	/** Test that we can interrupt the neural network training process */
	@Test
	public void testInterruptNeuralNetworkTraining() throws InterruptedException, IOException {
		expected.expect(InterruptedException.class);
		trainer()
			.type(NeuralNetworkType.SKIP_GRAM)
			.setNumIterations(15)
			.setListener(new TrainingProgressListener() {
					@Override public void update(Stage stage, double progress) {
						if (stage == Stage.TRAIN_NEURAL_NETWORK)
							Thread.currentThread().interrupt();
					}
				})
			.train(testData());
	}

  /**
   * Test the search results are deterministic Note the actual values may not
   * make sense since the model we train isn't tuned
   */
	@Test
	public void testSearch() throws InterruptedException, IOException, UnknownWordException {
		Word2VecModel model = trainer()
			.type(NeuralNetworkType.SKIP_GRAM)
			.train(testData());

		List<Match> matches = model.forSearch().getMatches("anarchism", 5);

		assertEquals(
				ImmutableList.of("anarchism", "feminism", "trouble", "left", "capitalism"),
				Lists.transform(matches, Match.TO_WORD)
			);
	}

  /**
   * Test that the model can retrieve words by a vector.
   */
  @Test
    public void testGetWordByVector() throws InterruptedException, IOException, UnknownWordException {
        Word2VecModel model = trainer()
            .type(NeuralNetworkType.SKIP_GRAM)
            .train(testData());

        // This vector defines the word "anarchism" in the given model.
        double[] vectors = new double[] { 0.11410251703652753, 0.271180824514185, 0.03748515103121994, 0.20888126888511183, 0.009713531343874777, 0.4769425625416319, 0.1431890482445165, -0.1917578875330224, -0.33532561802423366,
            -0.08794543238607992, 0.20404593606213406, 0.26170074241479385, 0.10020961212561065, 0.11400571893146201, -0.07846426915175395, -0.19404092647187385, 0.13381991303455204, -4.6749635342694615E-4, -0.0820905789076496,
            -0.30157145455251866, 0.3652037905836543, -0.16466827556950117, -0.012965932276668056, 0.09896568721267748, -0.01925755122093615 };

        List<Match> matches = model.forSearch().getMatches(vectors, 5);

        assertEquals(
                ImmutableList.of("anarchism", "feminism", "trouble", "left", "capitalism"),
                Lists.transform(matches, Match.TO_WORD)
            );
    }
  
  /**
   * Test that the model can retrieve words by a vector.
   */
  @Test
    public void testGetWordByNotExistantVector() throws InterruptedException, IOException, UnknownWordException {
        Word2VecModel model = trainer()
            .type(NeuralNetworkType.SKIP_GRAM)
            .train(testData());

        double[] vectors = new double[] { 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0 };

        List<Match> matches = model.forSearch().getMatches(vectors, 5);

        assertEquals(
                ImmutableList.of("the", "of", "and", "in", "a"),
                Lists.transform(matches, Match.TO_WORD)
            );
    }

	/** Test reading Word2Vec C version txt output format into this library */
	@Test
	public void testTxtModelRead() throws IOException, UnknownWordException {
		String filename = "word2vec.c.output.model.txt";
		Word2VecModel word2VecModel = Word2VecModel.fromTextFile(filename, Common.readResource(Word2VecTest.class, filename));
    assertEquals(0.9927725293757652, word2VecModel.forSearch().cosineDistance("three", "five"), 1e-5);
	}

	/** @return {@link Word2VecTrainer} which by default uses all of the supported features */
	@VisibleForTesting
	public static Word2VecTrainerBuilder trainer() {
		return Word2VecModel.trainer()
			.setMinVocabFrequency(6)
			.useNumThreads(1)
			.setWindowSize(8)
			.type(NeuralNetworkType.CBOW)
			.useHierarchicalSoftmax()
			.setLayerSize(25)
			.setDownSamplingRate(1e-3)
			.setNumIterations(1);
	}

	/** @return raw test dataset. The tokens are separated by newlines. */
	@VisibleForTesting
	public static Iterable<List<String>> testData() throws IOException {
		List<String> lines = Common.readResource(Word2VecTest.class, "word2vec.short.txt");
		Iterable<List<String>> partitioned = Iterables.partition(lines, 1000);
		return partitioned;
	}

	private void assertModelMatches(String expectedResource, Word2VecModel model) throws TException {
		final String thrift;
		try {
			thrift = Common.readResourceToStringChecked(getClass(), expectedResource);
		} catch (IOException ioe) {
			String filename = "/tmp/" + expectedResource;
			try {
				FileUtils.writeStringToFile(
						new File(filename),
						ThriftUtils.serializeJson(model.toThrift())
				);
			} catch (IOException e) {
				throw new AssertionError("Could not read resource " + expectedResource + " and could not write expected output to /tmp");
			}
			throw new AssertionError("Could not read resource " + expectedResource + " wrote to " + filename);
		}

		Word2VecModelThrift expected = ThriftUtils.deserializeJson(
				new Word2VecModelThrift(),
				thrift
		);

		assertEquals("Mismatched vocab", expected.getVocab().size(), Iterables.size(model.getVocab()));

		assertEquals(expected, model.toThrift());
	}
}