# -*- coding: utf-8 -*-

"""
Interactive evaluation for the RTE networks.
"""

from __future__ import print_function

import argparse
import tensorflow as tf
import numpy as np
import matplotlib
matplotlib.use('TkAgg')  # necessary on OS X
from matplotlib import pyplot as pl

from classifiers import multimlp
import utils
import ioutils


class SentenceWrapper(object):
    """
    Class for the basic sentence preprocessing needed to make it readable
    by the networks.
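
    Rough usage sketch (the vocabulary below is made up for illustration):
        word_dict = {utils.GO: 0, utils.PADDING: 1, 'a': 2, 'dog': 3, 'runs': 4}
        wrapped = SentenceWrapper('A dog runs', word_dict, lowercase=True)
        wrapped.tokens_with_null   # [utils.GO, 'a', 'dog', 'runs']
        wrapped.convert_sentence() # array of shape (1, 4)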
    """
    def __init__(self, sentence, word_dict, lowercase, language='en'):
        self.sentence = sentence
        tokenize = utils.get_tokenizer(language)
        text = sentence.lower() if lowercase else sentence
        self.tokens = tokenize(text)
        self.indices = [word_dict[token] for token in self.tokens_with_null]
        self.padding_index = word_dict[utils.PADDING]

    def __len__(self):
        return len(self.tokens)

    @property
    def tokens_with_null(self):
        return [utils.GO] + self.tokens

    def convert_sentence(self):
        """
        Convert the token sequence into the input array used by the network.

        No padding is applied here; the sentence is returned as a single-item
        batch with its own length.

        :return: a numpy array of shape (1, num_tokens)
        """
        indices = np.array(self.indices)
        return indices.reshape((1, -1))


def print_attention(tokens1, tokens2, attention):
    """
    Print the attention from tokens1 over tokens2
    """
    # multiply by 10 to make it easier to visualize
    attention_bigger = attention * 10
    max_length_sent1 = max(len(t) for t in tokens1)

    # create formatting strings to match the token widths (at least 4
    # characters wide, so the columns line up with the header below)
    att_formatters = ['{:>%d.2f}' % max(len(t), 4) for t in tokens2]

    # the header line has whitespace in place of the first sentence column,
    # followed by the tokens of the second sentence
    blank = ' ' * max_length_sent1

    # take at least length 4 to fit the 9.99 format
    formatted_sent2 = ['{:>4}'.format(token) for token in tokens2]
    first_line = blank + '\t' + '\t'.join(formatted_sent2)
    print(first_line)

    for token, att in zip(tokens1, attention_bigger):
        values = [fmt.format(x)
                  for x, fmt in zip(att, att_formatters)]
        fmt_str = '{:>%d}' % max_length_sent1
        formatted_token = fmt_str.format(token)
        line = formatted_token + '\t' + '\t'.join(values)
        print(line)
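
# Illustration of the layout printed above, with invented values
# (each number is the attention weight multiplied by 10):
#          the     cat
#     a   0.90    9.10
#   dog   8.70    1.30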


def plot_attention(tokens1, tokens2, attention):
    """
    Plot a colormap showing the attention values from tokens1 over
    tokens2.
    """
    len1 = len(tokens1)
    len2 = len(tokens2)
    extent = [0, len2, 0, len1]
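    # matshow draws row 0 of the matrix at the top; with this extent the
    # y axis grows upwards, so the y tick labels are passed in reverse order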
    pl.matshow(attention, extent=extent, aspect='auto')
    ticks1 = np.arange(len1) + 0.5
    ticks2 = np.arange(len2) + 0.5
    pl.xticks(ticks2, tokens2, rotation=45)
    pl.yticks(ticks1, reversed(tokens1))
    ax = pl.gca()
    ax.xaxis.set_ticks_position('bottom')
    pl.colorbar()
    pl.title('Alignments')
    pl.show(block=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('load', help='Directory with saved model files')
    parser.add_argument('embeddings', help='Text or numpy file with word embeddings')
    parser.add_argument('--vocab', help='Vocabulary file (only needed if a numpy '
                                        'embeddings file is given)')
    parser.add_argument('-a', help='Plot attention values graph', dest='attention',
                        action='store_true')
    parser.add_argument('-i', help='Run inference classifier', dest='inference',
                        action='store_true')
    args = parser.parse_args()

    utils.config_logger(verbose=False)
    logger = utils.get_logger()
    params = ioutils.load_params(args.load)
    if args.inference:
        label_dict = ioutils.load_label_dict(args.load)
        number_to_label = {v: k for (k, v) in label_dict.items()}

    logger.info('Reading model')
    sess = tf.InteractiveSession()
    model_class = utils.get_model_class(params)
    model = model_class.load(args.load, sess)
    word_dict, embeddings = ioutils.load_embeddings(args.embeddings, args.vocab,
                                                    generate=False,
                                                    load_extra_from=args.load,
                                                    normalize=True)
    model.initialize_embeddings(sess, embeddings)

    ops = []
    if args.inference:
        ops.append(model.answer)
    if args.attention:
        ops.append(model.inter_att1)
        ops.append(model.inter_att2)
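    # sess.run returns the results in the same order the ops were added,
    # so they are popped off the front in that order inside the loop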

    while True:
        sent1 = raw_input('Type sentence 1: ').decode('utf-8')
        sent2 = raw_input('Type sentence 2: ').decode('utf-8')
        sent1 = SentenceWrapper(sent1, word_dict,
                                params['lowercase'], params['language'])
        sent2 = SentenceWrapper(sent2, word_dict,
                                params['lowercase'], params['language'])

        vector1 = sent1.convert_sentence()
        vector2 = sent2.convert_sentence()
        size1 = len(sent1.tokens_with_null)
        size2 = len(sent2.tokens_with_null)

        feeds = {model.sentence1: vector1,
                 model.sentence2: vector2,
                 model.sentence1_size: [size1],
                 model.sentence2_size: [size2],
                 model.dropout_keep: 1.0}

        results = sess.run(ops, feed_dict=feeds)
        if args.inference:
            answer = results.pop(0)
            print('Model answer:', number_to_label[answer[0]])

        if args.attention:
            att1 = results.pop(0)
            att2 = results.pop(0)
            print('Attention from sentence 1 over sentence 2:')
            print_attention(sent1.tokens_with_null,
                            sent2.tokens_with_null, att1[0])
            plot_attention(sent1.tokens_with_null,
                           sent2.tokens_with_null, att1[0])
            print('Attention from sentence 2 over sentence 1:')
            print_attention(sent2.tokens_with_null,
                            sent1.tokens_with_null, att2[0])
            plot_attention(sent2.tokens_with_null,
                           sent1.tokens_with_null, att2[0])

        print()