from __future__ import print_function

import numpy as np
import theano

import argparse
import logging
import itertools
from copy import deepcopy
import os
import sys

from data_generator import VisualWordDataGenerator
import models

# Set up logger
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger(__name__)

# Dimensionality of image feature vector
IMG_FEATS = 4096


class ExtractFinalHiddenStateActivations:

    def __init__(self, args):
        self.args = args
        self.args.generate_from_N_words = 0  # Default 0
        self.vocab = dict()
        self.unkdict = dict()
        self.counter = 0
        self.maxSeqLen = 0
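        # Fallback timestep index: if no <E> token is produced for a
        # sentence, we take the hidden state at the final generation
        # timestep instead (see new_generate_activations()).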
        self.MAX_HT = self.args.generation_timesteps - 1

        # consistent with models.py
        # maybe use_sourcelang isn't applicable here?
        self.use_sourcelang = args.source_vectors is not None
        self.use_image = not args.no_image

        if self.args.debug:
            theano.config.optimizer = 'None'
            theano.config.exception_verbosity = 'high'

        self.source_type = "predicted" if self.args.use_predicted_tokens else "gold"
        self.source_encoder = "mt_enc" if self.args.no_image else "vis_enc"
        self.source_dim = self.args.hidden_size

        self.h5_dataset_str = "%s-hidden_feats-%s-%d" % (self.source_type,
                                                         self.source_encoder,
                                                         self.source_dim)
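        # e.g. "gold-hidden_feats-vis_enc-256" for a 256-dim visual encoder
        # run over gold-standard tokens.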
        logger.info("Serialising into %s" % self.h5_dataset_str)

    def get_hidden_activations(self):
        '''
        In the model, we will merge the VGG image representation with
        the word embeddings. We need to feed the data as a list, in which
        the order of the elements in the list is _crucial_.
        '''
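        # Note: in this implementation the activation model is fed a dict of
        # named inputs, e.g. {'text': one-hot token arrays, 'img': image
        # feature vectors}; see the self.fhs.predict() call in
        # new_generate_activations().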

        self.data_generator = VisualWordDataGenerator(self.args,
                                                      self.args.dataset)
        self.args.checkpoint = self.find_best_checkpoint()
        self.data_generator.set_vocabulary(self.args.checkpoint)
        self.vocab_len = len(self.data_generator.index2word)
        t = (self.args.generation_timesteps if self.args.use_predicted_tokens
             else self.data_generator.max_seq_len)

        m = models.NIC(self.args.embed_size, self.args.hidden_size,
                       self.vocab_len,
                       self.args.dropin,
                       self.args.optimiser, self.args.l2reg,
                       weights=self.args.checkpoint,
                       gru=self.args.gru,
                       t=t)

        self.fhs = m.buildHSNActivations(use_image=self.use_image)
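        # If we are extracting activations over predicted tokens, we also
        # need the full generation model to produce those tokens first
        # (see get_predicted_tokens()).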
        if self.args.use_predicted_tokens and not self.args.no_image:
            gen_m = models.NIC(self.args.embed_size, self.args.hidden_size,
                               self.vocab_len,
                               self.args.dropin,
                               self.args.optimiser, self.args.l2reg,
                               weights=self.args.checkpoint,
                               gru=self.args.gru,
                               t=self.args.generation_timesteps)
            self.full_model = gen_m.buildKerasModel(use_image=self.use_image)

        self.new_generate_activations('train')
        self.new_generate_activations('val')
        #self.new_generate_activations('test')

    def new_generate_activations(self, split):
        '''
        Generate and serialise final-timestep hidden state activations
        into --dataset.
        TODO: we should be able to serialise predicted final states instead of
        gold-standard final states for val and test data.
        '''
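        # Sketch of the procedure: for every batch from the data generator we
        # (optionally) replace the gold tokens with predicted ones, run the
        # activation model to get per-timestep hidden states, keep the state
        # at the first <E> token (or at MAX_HT if no <E> is found), and then
        # write those vectors into the H5 dataset via to_h5_indices().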
        logger.info("%s: extracting final hidden state activations from this model", split)

        # Prepare the data generator based on whether we're going to work with
        # the gold standard input tokens or the automatically predicted tokens
        if self.args.use_predicted_tokens:
            the_generator = self.data_generator.generation_generator(split=split)
        else:
            the_generator = self.data_generator.fixed_generator(split=split)

        counter = 0
        hidden_states = []
        batch_start = 0
        batch_end = 0
        for data in the_generator:
            if self.args.use_predicted_tokens:
                tokens = self.get_predicted_tokens(data)
                data[0]['text'] = self.set_text_arrays(tokens, data[0]['text'])

            # We extract the FHS from either the oracle or the predicted
            # input tokens.
            hsn = self.fhs.predict({'text': data[0]['text'],
                                    'img': data[0]['img']},
                                   batch_size=self.args.batch_size,
                                   verbose=1)

            for idx, h in enumerate(hsn):
                # get final_hidden index on a sentence-by-sentence
                # basis by searching for the first <E> in each trainY
                eos = False
                for widx, warr in enumerate(data[1]['output'][idx]):
                    w = np.argmax(warr)
                    if self.data_generator.index2word[w] == "<E>":
                        final_hidden = h[widx]
                        hidden_states.append(final_hidden)
                        eos = True
                        logger.debug(widx)
                        break
                if not eos:
                    final_hidden = h[self.MAX_HT]
                    hidden_states.append(final_hidden)
                batch_end += 1

            # Note: serialisation happens over training batches too.
            # now serialise the hidden representations in the h5
            self.to_h5_indices(split, data[0]['indices'], hidden_states)

            batch_start = batch_end
            counter += len(hidden_states)
            hidden_states = []
            logger.info("Processed %d instances" % counter)
            if batch_end >= self.data_generator.split_sizes[split]:
                break

    def get_predicted_tokens(self, data):
        """
        We're not going to work with the gold standard input tokens.
        Instead we're going to automatically predict them and then extract
        the final hidden state from the inferred data.

        Helper function used by new_generate_activations().
        """
        # We are going to arg max decode a sequence.
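        # Greedy decoding sketch: at each timestep t we feed the sentence so
        # far to the full generation model, take the arg max of the
        # distribution predicted at position t-1, and write that word back
        # into the one-hot 'text' array before predicting the next step.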
        start_gen = self.args.generate_from_N_words + 1  # include BOS

        text = deepcopy(data[0]['text'])
        # Append the first start_gen words to the complete_sentences list
        # for each instance in the batch.
        complete_sentences = [[] for _ in range(text.shape[0])]
        for t in range(start_gen):  # minimum 1
            for i in range(text.shape[0]):
                w = np.argmax(text[i, t])
                complete_sentences[i].append(self.data_generator.index2word[w])
        del data[0]['text']
        text = self.reset_text_arrays(text, start_gen)
        Y_target = data[1]['output']
        data[0]['text'] = text

        for t in range(start_gen, self.args.generation_timesteps):
            logger.debug("Input token: %s" %
                    self.data_generator.index2word[np.argmax(data[0]['text'][0,t-1])])
            preds = self.full_model.predict(data[0], verbose=0)

            # Take the arg max of the distribution predicted at position t-1,
            # i.e. the model's guess for the word at timestep t.
            next_word_indices = np.argmax(preds[:, t-1], axis=1)
            logger.debug("Predicted token: %s" %
                         self.data_generator.index2word[next_word_indices[0]])
            # update the 'text' input (the sentence so far) with the
            # generated words.
            for i in range(len(next_word_indices)):
                data[0]['text'][i, t, next_word_indices[i]] = 1.
            next_words = [self.data_generator.index2word[x] for x in next_word_indices]
            for i in range(len(next_words)):
                complete_sentences[i].append(next_words[i])

        # extract each sentence until it hits the first end-of-string token
        pruned_sentences = []
        for s in complete_sentences:
            pruned_sentences.append([x for x
                                    in itertools.takewhile(
                                        lambda n: n != "<E>", s)])
        return pruned_sentences

    def set_text_arrays(self, predicted_tokens, text_arrays):
        """ Set the values of the text tokens in the text arrays
        based on the tokens predicted by the model.

        Helper function used by new_generate_activations() """
        new_arrays = deepcopy(text_arrays)
        # Set the one-hot entry for each predicted token in the copy, so the
        # caller's original array is left untouched.
        for toks, struct in zip(predicted_tokens, new_arrays):
            for tidx, t in enumerate(toks):
                struct[tidx, self.data_generator.word2index[t]] = 1
        return new_arrays

    def reset_text_arrays(self, text_arrays, fixed_words=1):
        """ Reset the values in the text data structure to zero so we cannot
        accidentally pass them into the model.

        Helper function for generate_sentences().
         """
        reset_arrays = deepcopy(text_arrays)
        reset_arrays[:,fixed_words:, :] = 0
        return reset_arrays

#    def make_generation_arrays(self, prefix, fixed_words,
#                               predicted_tokens=None):
#        '''
#        Create arrays that are used as input for generation / activation.
#        '''
#
#
#        if predicted_tokens is not None:
#            input_data, targets = self.data_generator.get_data_by_split(prefix,
#                                           self.use_sourcelang, self.use_image)
#            logger.info("Initialising generation arrays with predicted tokens")
#            gen_input_data = deepcopy(input_data)
#            tokens = gen_input_data[0]
#            tokens[:, fixed_words, :] = 0  # reset the inputs
#            for prediction, words, tgt in zip(predicted_tokens, tokens, targets):
#                for idx, t in enumerate(prediction):
#                    words[idx, self.data_generator.word2index[t]] = 1.
#            targets = self.data_generator.get_target_descriptions(tokens)
#            return gen_input_data, targets
#
#        else:
#            # Replace input words (input_data[0]) with zeros for generation,
#            # except for the first args.generate_from_N_words
#            # NOTE: this will include padding and BOS steps (fixed_words has been
#            # incremented accordingly already in generate_sentences().)
#            input_data = self.data_generator.get_generation_data_by_split(prefix,
#                                           self.use_sourcelang, self.use_image)
#            logger.info("Initialising with the first %d gold words (incl BOS)",
#                        fixed_words)
#            gen_input_data = deepcopy(input_data)
#            gen_input_data[0][:, fixed_words:, :] = 0
#            return gen_input_data
#
#    def generate_sentences(self, split, arrays=None):
#        """
#        Generates descriptions of images for --generation_timesteps
#        iterations through the LSTM. Each input description is clipped to
#        the first <BOS> token, or, if --generate_from_N_words is set, to the
#        first N following words (N + 1 BOS token).
#        This process can be additionally conditioned
#        on source language hidden representations, if provided by the
#        --source_vectors parameter.
#        The output is clipped to the first EOS generated, if it exists.
#
#        TODO: beam search
#        TODO: duplicated method with generate.py and Callbacks.py
#        """
#        logger.info("%s: generating descriptions", split)
#
#        start_gen = self.args.generate_from_N_words  # Default 0
#        start_gen = start_gen + 1  # include BOS
#
#        # prepare the datastructures for generation (no batching over val)
#        if arrays == None:
#            arrays = self.make_generation_arrays(split, start_gen)
#        N_sents = arrays[0].shape[0]
#
#        complete_sentences = [[] for _ in range(N_sents)]
#        for t in range(start_gen):  # minimum 1
#            for i in range(N_sents):
#                w = np.argmax(arrays[0][i, t])
#                complete_sentences[i].append(self.data_generator.index2word[w])
#
#        for t in range(start_gen, self.args.generation_timesteps):
#            # we take a view of the datastructures, which means we're only
#            # ever generating a prediction for the next word. This saves a
#            # lot of cycles.
#            preds = self.full_model.predict([arr[:, 0:t] for arr in arrays],
#                                            verbose=0)
#
#            # Look at the last indices for the words.
#            next_word_indices = np.argmax(preds[:, -1], axis=1)
#            # update array[0]/sentence-so-far with generated words.
#            for i in range(N_sents):
#                arrays[0][i, t, next_word_indices[i]] = 1.
#            next_words = [self.data_generator.index2word[x] for x in next_word_indices]
#            for i in range(len(next_words)):
#                complete_sentences[i].append(next_words[i])
#
#        # extract each sentence until it hits the first end-of-string token
#        pruned_sentences = []
#        for s in complete_sentences:
#            pruned_sentences.append([x for x
#                                     in itertools.takewhile(
#                                         lambda n: n != "<E>", s)])
#        return pruned_sentences

    def to_h5_indices(self, split, indices, hidden_states):
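        """ Serialise one final hidden state vector per (identifier,
        description index) pair in `indices` into the H5 dataset, under
        the self.h5_dataset_str dataset name.

        Helper function used by new_generate_activations(). """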
        hsn_shape = len(hidden_states[0])
        fhf_str = "final_hidden_features"
        logger.info("Serialising final hidden state features from %s to H5",
                    split)
        for idx, data_key in enumerate(indices):
            ident = data_key[0]
            desc_idx = data_key[1]
            self.data_generator.set_source_features(split, ident,
                                                    self.h5_dataset_str,
                                                    hidden_states[idx],
                                                    hsn_shape,
                                                    desc_idx)

#    def serialise_to_h5_keys(self, split, data_keys, hidden_states):
#        hsn_shape = len(hidden_states[0])
#        fhf_str = "final_hidden_features"
#        logger.info("Serialising final hidden state features from %s to H5",
#                    split)
#        for idx, data_key in enumerate(data_keys):
#            self.data_generator.set_source_features(split, data_key,
#                                                    self.h5_dataset_str,
#                                                    hidden_states[idx],
#                                                    hsn_shape)
#            #try:
#            #    hsn_data = self.data_generator.dataset[split][data_key].create_dataset(
#            #        fhf_str, (hsn_shape,), dtype='float32')
#            #except RuntimeError:
#            #    # the dataset already exists, retrieve it into RAM and then overwrite it
#            #    del self.data_generator.dataset[split][data_key][fhf_str]
#            #    hsn_data = self.data_generator.dataset[split][data_key].create_dataset(
#            #        fhf_str, (hsn_shape,), dtype='float32')
#            #try:
#            #    hsn_data[:] = hidden_states[idx]
#            #except IndexError:
#            #    raise IndexError("data_key %s of %s; index idx %d, len hidden %d" % (
#            #        data_key, len(data_keys), idx, len(hidden_states)))
#            #    break
#
#    def sentences_to_h5(self, split, sentences):
#        '''
#        Save the predicted sentences into the h5 dataset object.
#        This is useful for subsequently (i.e. in a different program)
#        extracting LM-only final hidden states from predicted sentences.
#        Specifically, this can be compared to generating LM-only hidden
#        states over gold-standard tokens.
#        '''
#        idx = 0
#        logger.info("Serialising sentences from %s to H5", split)
#        data_keys = self.data_generator.dataset[split]
#        if split == 'val' and self.args.small_val:
#            data_keys = ["%06d" % x for x in range(len(sentences))]
#        else:
#            data_keys = ["%06d" % x for x in range(len(sentences))]
#        for data_key in data_keys:
#            self.data_generator.set_predicted_description(split, data_key,
#                                                          sentences[idx][1:])
#            idx += 1
#
#    def sentences_to_h5_keys(self, split, data_keys, sentences):
#        logger.info("Serialising sentences from %s to H5",
#                    split)
#        for idx, data_key in enumerate(data_keys):
#            self.data_generator.set_predicted_description(split, data_key,
#                                                    sentences[idx])
#
#    def serialise_to_h5(self, split, hsn_shape, hidden_states,
#                        batch_start=None, batch_end=None):
#        """ Serialise the hidden representations from generate_activations
#        into the h5 dataset.
#        This assumes one hidden_state per image key, which is maybe not
#        appropriate if there are multiple descriptions/image.
#        """
#        idx = 0
#        logger.info("Serialising final hidden state features from %s to H5",
#                    split)
#        if batch_start is not None:
#            logger.info("Start at %d, end at %d", batch_start, batch_end)
#            data_keys = ["%06d" % x for x in range(batch_start, batch_end)]
#            assert len(hidden_states) == len(data_keys),\
#                    "keys: %d hidden %d; start %d end %d" % (len(data_keys),
#                                            len(hidden_states), batch_start,
#                                            batch_end)
#        else:
#            data_keys = self.data_generator.dataset[split]
#            if split == 'val' and self.args.small_val:
#                data_keys = ["%06d" % x for x in range(len(hidden_states))]
#            else:
#                data_keys = ["%06d" % x for x in range(len(hidden_states))]
#        for data_key in data_keys:
#            self.data_generator.set_source_features(split, data_key,
#                                                    self.h5_dataset_str,
#                                                    hidden_states[idx],
#                                                    hsn_shape)
#            #try:
#            #    hsn_data = self.data_generator.dataset[split][data_key].create_dataset(
#            #        fhf_str, (hsn_shape,), dtype='float32')
#            #except RuntimeError:
#            #    # the dataset already exists, retrieve it into RAM and then overwrite it
#            #    del self.data_generator.dataset[split][data_key][fhf_str]
#            #    hsn_data = self.data_generator.dataset[split][data_key].create_dataset(
#            #        fhf_str, (hsn_shape,), dtype='float32')
#            #try:
#            #    hsn_data[:] = hidden_states[idx]
#            #except IndexError:
#            #    raise IndexError("data_key %s of %s; index idx %d, len hidden %d" % (
#            #        data_key, len(data_keys),
#            #                      idx, len(hidden_states)))
#            #    break
#            idx += 1

    def find_best_checkpoint(self):
        '''
        Read the summary file from the checkpoints directory and extract the
        run ID of the best-scoring checkpoint. Then list the directory and
        return the exact path to the best model.

        Assumes only one matching prefix in the model checkpoints directory.
        '''
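        # The summary file is assumed to contain lines of the form
        # "Best Metric: <id> | ..." (or "Best loss: ..." when --best_pplx is
        # set); the integer before the first '|' is zero-padded and used as
        # the checkpoint filename prefix.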

        with open("%s/summary" % self.args.model_checkpoints) as f:
            summary_data = [x.replace("\n", "") for x in f]
        best_id = None
        target = "Best loss" if self.args.best_pplx else "Best Metric"
        for line in summary_data:
            if line.startswith(target):
                best_id = "%03d" % (int(line.split(":")[1].split("|")[0]))

        checkpoint = None
        if best_id is not None:
            checkpoints = os.listdir(self.args.model_checkpoints)
            for c in checkpoints:
                if c.startswith(best_id):
                    checkpoint = c
                    break
        return "%s/%s" % (self.args.model_checkpoints, checkpoint)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="""
                                     Serialise final RNN hidden state vector
                                     for each instance in a dataset.""")

    # General options
    parser.add_argument("--run_string", default="", type=str,
                        help="Optional string to help you identify the run")
    parser.add_argument("--debug", action="store_true",
                        help="Print debug messages to stdout?")
    parser.add_argument("--init_from_checkpoint", help="Initialise the model\
                        parameters from a pre-defined checkpoint? Useful to\
                        continue training a model.", default=None, type=str)
    parser.add_argument("--fixed_seed", action="store_true",
                        help="Start with a fixed random seed? Useful for\
                        reproding experiments. (default = False)")
    parser.add_argument("--num_sents", default=5, type=int,
                        help="Number of descriptions/image for training")
    parser.add_argument("--model_checkpoints", type=str, required=True,
                        help="Path to the checkpointed parameters")
    parser.add_argument("--best_pplx", action="store_true",
                        help="Use the best PPLX checkpoint instead of the\
                        best BLEU checkpoint? Default = False.")

    # Define the types of input data the model will receive
    parser.add_argument("--dataset", default="", type=str, help="Path to the\
                        HDF5 dataset to use for training / val input\
                        (defaults to flickr8k)")
    parser.add_argument("--supertrain_datasets", nargs="+", help="Paths to the\
                        datasets to use as additional training input (defaults\
                        to None)")
    parser.add_argument("--unk", type=int,
                        help="unknown character cut-off. Default=3", default=3)
    parser.add_argument("--maximum_length", type=int, default=50,
                        help="Maximum length of sequences permissible\
			in the training data (Default = 50)")
    parser.add_argument("--existing_vocab", type=str, default="",
                        help="Use an existing vocabulary model to define the\
                        vocabulary and UNKing in this dataset?\
                        (default = "", which means we will derive the\
                        vocabulary from the training dataset")
    parser.add_argument("--no_image", action="store_true",
                        help="Do not use image data.")
    parser.add_argument("--source_vectors", default=None, type=str,
                        help="Path to final hidden representations of\
                        encoder/source language VisualWordLSTM model.\
                        (default: None.) Expects a final_hidden_representation\
                        vector for each image in the dataset")
    parser.add_argument("--source_enc", type=str, default=None,
                        help="Which type of source encoder features? Expects\
                        either 'mt_enc' or 'vis_enc'. Required.")
    parser.add_argument("--source_type", type=str, default=None,
                        help="Source features over gold or predicted tokens?\
                        Expects 'gold' or 'predicted'. Required")
    parser.add_argument("--source_merge", type=str, default="sum",
                        help="How to merge source features. Only applies if \
                        there are multiple feature vectors. Expects 'sum', \
                        'avg', or 'concat'.")

    # Model hyperparameters
    parser.add_argument("--batch_size", default=100, type=int)
    parser.add_argument("--embed_size", default=256, type=int)
    parser.add_argument("--hidden_size", default=256, type=int)
    parser.add_argument("--dropin", default=0.5, type=float,
                        help="Prob. of dropping embedding units. Default=0.5")
    parser.add_argument("--gru", action="store_true", help="Use GRU instead\
                        of LSTM recurrent state? (default = False)")
    parser.add_argument("--mrnn", action="store_true", 
                        help="Use a Mao-style multimodal recurrent neural\
                        network?")
    parser.add_argument("--peeking_source", action="store_true",
                        help="Input the source features at every timestep?\
                        Default=False.")

    # Optimisation details
    parser.add_argument("--optimiser", default="adam", type=str,
                        help="Optimiser: rmsprop, momentum, adagrad, etc.")
    parser.add_argument("--lr", default=0.001, type=float)
    parser.add_argument("--beta1", default=None, type=float)
    parser.add_argument("--beta2", default=None, type=float)
    parser.add_argument("--epsilon", default=None, type=float)
    parser.add_argument("--stopping_loss", default="bleu", type=str,
                        help="minimise cross-entropy or maximise BLEU?")
    parser.add_argument("--l2reg", default=1e-8, type=float,
                        help="L2 cost penalty. Default=1e-8")
    parser.add_argument("--clipnorm", default=-1, type=float,
                        help="Clip gradients? (default = -1, which means\
                        don't clip the gradients.")
    parser.add_argument("--max_epochs", default=50, type=int,
                        help="Maxmimum number of training epochs. Used with\
                        --predefined_epochs")
    parser.add_argument("--patience", type=int, default=10, help="Training\
                        will be terminated if validation BLEU score does not\
                        increase for this number of epochs")
    parser.add_argument("--no_early_stopping", action="store_true")

    # Language generation details
    parser.add_argument("--generation_timesteps", default=30, type=int,
                        help="Maximum number of words to generate for unseen\
                        data (default=10).")

    # Legacy options
    parser.add_argument("--generate_from_N_words", type=int, default=0,
                        help="Use N words as starting point when generating\
                        strings. Useful mostly for mt-only model (in other\
                        cases, image provides enough useful starting\
                        context.)")
    parser.add_argument("--predefined_epochs", action="store_true",
                        help="Do you want to stop training after a specified\
                        number of epochs, regardless of early-stopping\
                        criteria? Use in conjunction with --max_epochs.")

    parser.add_argument("--h5_writeable", action="store_true",
                        help="Open the H5 file for write-access? Useful for\
                        serialising hidden states to disk. (default = False)")
    parser.add_argument("--use_predicted_tokens", action="store_true",
                        help="Generate final hidden state\
                        activations over oracle inputs or from predicted\
                        inputs? Default = False ( == Oracle)")

    w = ExtractFinalHiddenStateActivations(parser.parse_args())
    w.get_hidden_activations()
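
# Example invocation (script name and paths are illustrative):
#   python extract_hidden_state_activations.py \
#       --dataset /path/to/dataset.h5 \
#       --model_checkpoints checkpoints/my_run \
#       --h5_writeable --use_predicted_tokens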