# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, LSTMCell, BasicRNNCell
from tensorflow.contrib.lookup import MutableHashTable
from .seq_helper import sequence_loss, sentence_ppx
from .attention_decoder import create_output_fn, prepare_attention, \
    attention_decoder_train, attention_decoder_inference
from .dynamic_decoder import dynamic_rnn_decoder
from ..dataset.knowledge_loader import UNK_ID, GO_ID, EOS_ID, NONE_ID


class TransDGModel(object):
    def __init__(self, word_embed, kd_embed, param_dict, use_trans_repr=True, use_trans_select=True,
                 vocab_size=30000, dim_emb=300, dim_trans=100, cell_class='GRU', num_units=512, num_layers=2,
                 max_length=60, lr_rate=0.0001, max_grad_norm=5.0, drop_rate=0.2, beam_size=1):
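        """TransDG dialogue generation model.

        :param word_embed: pre-trained word embedding matrix, shape [vocab_size, dim_emb].
        :param kd_embed: pre-trained knowledge embedding matrix (not used by _init_embed at present).
        :param param_dict: pre-trained parameters of the knowledge-selection layer.
        :param use_trans_repr: if True, fuse the externally computed translation representations
            (fed via the trans_reprs placeholder) into the encoder output.
        :param use_trans_select: if True, build the knowledge-selection layer for triple matching.
        The remaining arguments configure vocabulary size, embedding/RNN dimensions,
        maximum decoding length, and optimization.
        """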
        # initialize params
        self.use_trans_repr = use_trans_repr
        self.use_trans_select = use_trans_select
        self.vocab_size = vocab_size
        self.dim_emb = dim_emb
        self.dim_trans = dim_trans
        self.cell_class = cell_class
        self.num_units = num_units
        self.num_layers = num_layers
        self.lr_rate = lr_rate
        self.max_grad_norm = max_grad_norm
        self.drop_rate = drop_rate
        self.max_length = max_length
        self.beam_size = beam_size

        self.global_step = tf.Variable(0, trainable=False, name="global_step")
        self._init_embed(word_embed, kd_embed)
        self._init_placeholders()
        self._init_vocabs()

        self.select_mode = None
        if self.use_trans_select:
            self.select_layer = self._init_select_layer(param_dict=param_dict)
        else:
            self.select_layer = None

        # build model
        self.ppx_loss, self.loss = self.build_model(train_mode=True)
        self.generation = self.build_model(train_mode=False)

        # construct graphs for minimizing loss
        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr_rate)
        self.params = tf.global_variables()
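        # Note: self.params includes non-trainable variables (e.g., word_embed); their gradients
        # come back as None and are ignored by clip_by_global_norm and apply_gradients.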
        gradients = tf.gradients(self.loss, self.params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, self.max_grad_norm)
        self.update = optimizer.apply_gradients(zip(clipped_gradients, self.params), global_step=self.global_step)

    def _init_vocabs(self):
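        """Create mutable lookup tables mapping tokens to ids (and back) for the word and
        knowledge vocabularies; they are populated later by set_vocabs()."""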
        self.symbol2index = MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
                                             default_value=UNK_ID, shared_name="w2id_table",
                                             name="w2id_table", checkpoint=True)
        self.index2symbol = MutableHashTable(key_dtype=tf.int64, value_dtype=tf.string,
                                             default_value='_UNK', shared_name="id2w_table",
                                             name="id2w_table", checkpoint=True)
        self.kd2index = MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
                                         default_value=NONE_ID, shared_name="kd2id_table",
                                         name="kd2id_table", checkpoint=True)
        self.index2kd = MutableHashTable(key_dtype=tf.int64, value_dtype=tf.string,
                                         default_value='_NONE', shared_name="id2kd_table",
                                         name="id2kd_table", checkpoint=True)

    def _init_placeholders(self):
        self.posts = tf.placeholder(tf.string, (None, None), 'post')  # [batch, len]
        self.posts_length = tf.placeholder(tf.int32, (None,), 'post_lens')  # [batch]
        self.responses = tf.placeholder(tf.string, (None, None), 'resp')  # [batch, len]
        self.responses_length = tf.placeholder(tf.int32, (None,), 'resp_lens')  # [batch]
        self.corr_responses = tf.placeholder(tf.string, (None, None, None), 'corr_resps')  # [batch, topk, len]
        self.triples = tf.placeholder(tf.string, (None, None, None, 3), 'triples')
        self.trans_reprs = tf.placeholder(tf.float32, (None, None, self.num_units), 'trans_reprs')

    def _init_embed(self, word_embed, kd_embed=None):
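        # NOTE: kd_embed is accepted for interface parity but is not used here;
        # triple words are embedded with word_embed in build_model.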
        self.word_embed = tf.get_variable('word_embed', dtype=tf.float32,
                                          initializer=word_embed, trainable=False)  # [vocab_size, dim_emb]

    def _init_select_layer(self, param_dict):
        """
        :param param_dict: type dict
        :return: Defined bilinear layer or mlp layer
        """
        if "bilinear_mat" in param_dict.keys():
            self.select_mode = 'bilinear'

            def bilinear_layer(inputs1, inputs2, trainable=False):
                bilinear_mat = tf.get_variable('bilinear_mat', dtype=tf.float32,
                                               initializer=param_dict['bilinear_mat'], trainable=trainable)
                proj_repr = tf.matmul(inputs2, bilinear_mat)
                scores = tf.reduce_sum(inputs1 * proj_repr, axis=-1)
                return scores
            return bilinear_layer
        else:
            self.select_mode = 'mlp'

            def fully_connected_layer(inputs, trainable=False):
                fc1_weights = tf.get_variable('fc1_weights', dtype=tf.float32,
                                              initializer=param_dict['fc1_weights'], trainable=trainable)
                fc1_biases = tf.get_variable('fc1_biases', dtype=tf.float32,
                                             initializer=param_dict['fc1_biases'], trainable=trainable)
                hidden_outs = tf.nn.relu(tf.matmul(inputs, fc1_weights) + fc1_biases)

                fc2_weights = tf.get_variable('fc2_weights', dtype=tf.float32,
                                              initializer=param_dict['fc2_weights'], trainable=trainable)
                fc2_biases = tf.get_variable('fc2_biases', dtype=tf.float32,
                                             initializer=param_dict['fc2_biases'], trainable=trainable)
                scores = tf.matmul(hidden_outs, fc2_weights) + fc2_biases
                return scores
            return fully_connected_layer

    def build_model(self, train_mode=True):
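        """Build the computation graph.

        In train mode, returns (ppx_loss, loss) computed against the gold responses;
        otherwise returns the generated token strings from the inference decoder
        (argmax over the output distribution).
        """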
        # convert symbols to ids via the lookup tables and embed them
        batch_size = tf.shape(self.posts)[0]
        post_word_id = self.symbol2index.lookup(self.posts)
        post_word_input = tf.nn.embedding_lookup(self.word_embed, post_word_id)  # [batch, len, dim_emb]

        corr_responses_id = self.symbol2index.lookup(self.corr_responses)  # [batch, topk, len]
        corr_responses_input = tf.nn.embedding_lookup(self.word_embed, corr_responses_id)  # [batch, topk, len, dim_emb]

        triple_id = self.symbol2index.lookup(self.triples)
        triple_input = tf.nn.embedding_lookup(self.word_embed, triple_id)
        triple_num = tf.shape(self.triples)[1]
        triple_input = tf.reshape(triple_input, [batch_size, triple_num, -1, 3 * self.dim_emb])
        triple_input = tf.reduce_mean(triple_input, axis=2)  # [batch, triple_num, 3*dim_emb]

        resp_target = self.symbol2index.lookup(self.responses)
        decoder_len = tf.shape(self.responses)[1]
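        # Shift the targets right: prepend GO_ID and drop the last token to form the decoder input.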
        resp_word_id = tf.concat([tf.ones([batch_size, 1], dtype=tf.int64) * GO_ID,
                                  tf.split(resp_target, [decoder_len - 1, 1], 1)[0]], 1)  # [batch,len]
        resp_word_input = tf.nn.embedding_lookup(self.word_embed, resp_word_id)
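        # decoder_mask is 1.0 for the first responses_length positions and 0.0 for padding.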
        decoder_mask = tf.reshape(tf.cumsum(
            tf.one_hot(self.responses_length - 1, decoder_len), reverse=True, axis=1),
            [-1, decoder_len])

        encoder_output, encoder_state = self.build_encoder(post_word_input,
                                                           corr_responses_input)

        if train_mode:
            output_logits = self.build_decoder(encoder_output, encoder_state,
                                               triple_input, resp_word_input,
                                               train_mode=train_mode)
            sent_ppx = sentence_ppx(self.vocab_size, output_logits, resp_target, decoder_mask)
            seq_loss = sequence_loss(self.vocab_size, output_logits, resp_target, decoder_mask)
            ppx_loss = tf.identity(sent_ppx, name="ppx_loss")
            loss = tf.identity(seq_loss, name="loss")
            return ppx_loss, loss
        else:
            decoder_dist = self.build_decoder(encoder_output, encoder_state,
                                              triple_input, decoder_input=None,
                                              train_mode=train_mode)
            generation_index = tf.argmax(decoder_dist, 2)
            generation = self.index2symbol.lookup(generation_index)
            generation = tf.identity(generation, name='generation')
            return generation

    def build_encoder(self, post_word_input, corr_responses_input):
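        """Encode the post with a multi-layer RNN, then append a mutual-attention summary of
        the correlated responses and, if enabled, the projected translation representations
        to the encoder outputs."""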
        if self.cell_class == 'GRU':
            encoder_cell = MultiRNNCell([GRUCell(self.num_units) for _ in range(self.num_layers)])
        elif self.cell_class == 'LSTM':
            encoder_cell = MultiRNNCell([LSTMCell(self.num_units) for _ in range(self.num_layers)])
        else:
            encoder_cell = MultiRNNCell([BasicRNNCell(self.num_units) for _ in range(self.num_layers)])

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE) as scope:
            encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell,
                                                              post_word_input,
                                                              self.posts_length,
                                                              dtype=tf.float32, scope=scope)
        batch_size, encoder_len = tf.shape(self.posts)[0], tf.shape(self.posts)[1]
        corr_response_input = tf.reshape(corr_responses_input, [batch_size, -1, self.dim_emb])
        corr_cum_len = tf.shape(corr_response_input)[1]
        with tf.variable_scope('mutual_attention', reuse=tf.AUTO_REUSE):
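            # Mutual attention: every token of every correlated response attends over the post
            # encoder outputs with an additive (Bahdanau-style) score, and the resulting context
            # vectors are averaged into a single [batch, num_units] summary.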
            encoder_out_trans = tf.layers.dense(encoder_output, self.num_units,
                                                name='encoder_out_transform')
            corr_response_trans = tf.layers.dense(corr_response_input, self.num_units,
                                                  name='corr_response_transform')
            encoder_out_trans = tf.expand_dims(encoder_out_trans, axis=1)
            encoder_out_trans = tf.tile(encoder_out_trans, [1, corr_cum_len, 1, 1])
            encoder_out_trans = tf.reshape(encoder_out_trans, [-1, encoder_len, self.num_units])

            corr_response_trans = tf.reshape(corr_response_trans, [-1, self.num_units])
            corr_response_trans = tf.expand_dims(corr_response_trans, axis=1)

            # TODO: try bilinear attention
            v = tf.get_variable("attention_v", [self.num_units], dtype=tf.float32)
            score = tf.reduce_sum(v * tf.tanh(encoder_out_trans + corr_response_trans), axis=2)
            alignments = tf.nn.softmax(score)

            encoder_out_tiled = tf.expand_dims(encoder_output, axis=1)
            encoder_out_tiled = tf.tile(encoder_out_tiled, [1, corr_cum_len, 1, 1])
            encoder_out_tiled = tf.reshape(encoder_out_tiled, [-1, encoder_len, self.num_units])

            context_mutual = tf.reduce_sum(tf.expand_dims(alignments, 2) * encoder_out_tiled, axis=1)
            context_mutual = tf.reshape(context_mutual, [batch_size, -1, self.num_units])
            context_mutual = tf.reduce_mean(context_mutual, axis=1)
    
        encoder_output = tf.concat([encoder_output, tf.expand_dims(context_mutual, 1)], axis=1)

        if self.use_trans_repr:
            trans_output = tf.layers.dense(self.trans_reprs, self.num_units,
                                           name='trans_reprs_transform', reuse=tf.AUTO_REUSE)
            encoder_output = tf.concat([encoder_output, trans_output], axis=1)

        return encoder_output, encoder_state

    def build_decoder(self, encoder_output, encoder_state, triple_input, decoder_input, train_mode=True):
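        """Run the attention decoder over the encoder outputs (optionally augmented with the
        selected knowledge context): teacher-forced on the gold responses in train mode,
        step-by-step inference up to max_length otherwise."""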
        if self.cell_class == 'GRU':
            decoder_cell = MultiRNNCell([GRUCell(self.num_units) for _ in range(self.num_layers)])
        elif self.cell_class == 'LSTM':
            decoder_cell = MultiRNNCell([LSTMCell(self.num_units) for _ in range(self.num_layers)])
        else:
            decoder_cell = MultiRNNCell([BasicRNNCell(self.num_units) for _ in range(self.num_layers)])

        if train_mode:
            with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
                if self.use_trans_select:
                    kd_context = self.transfer_matching(encoder_output, triple_input)
                else:
                    kd_context = None
                # prepare attention
                attention_keys, attention_values, attention_construct_fn \
                    = prepare_attention(encoder_output, kd_context, 'bahdanau', self.num_units)
                decoder_fn_train = attention_decoder_train(
                    encoder_state=encoder_state,
                    attention_keys=attention_keys,
                    attention_values=attention_values,
                    attention_construct_fn=attention_construct_fn)
                # train decoder
                decoder_output, _, _ = dynamic_rnn_decoder(cell=decoder_cell,
                                                           decoder_fn=decoder_fn_train,
                                                           inputs=decoder_input,
                                                           sequence_length=self.responses_length,
                                                           scope=scope)
                output_fn = create_output_fn(vocab_size=self.vocab_size)
                output_logits = output_fn(decoder_output)
                return output_logits
        else:
            with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
                if self.use_trans_select:
                    kd_context = self.transfer_matching(encoder_output, triple_input)
                else:
                    kd_context = None
                attention_keys, attention_values, attention_construct_fn \
                    = prepare_attention(encoder_output, kd_context, 'bahdanau', self.num_units, reuse=tf.AUTO_REUSE)
                output_fn = create_output_fn(vocab_size=self.vocab_size)
                # inference decoder
                decoder_fn_inference = attention_decoder_inference(
                    num_units=self.num_units, num_decoder_symbols=self.vocab_size,
                    output_fn=output_fn, encoder_state=encoder_state,
                    attention_keys=attention_keys, attention_values=attention_values,
                    attention_construct_fn=attention_construct_fn, embeddings=self.word_embed,
                    start_of_sequence_id=GO_ID, end_of_sequence_id=EOS_ID, maximum_length=self.max_length)

                # get decoder output
                decoder_distribution, _, _ = dynamic_rnn_decoder(cell=decoder_cell,
                                                                 decoder_fn=decoder_fn_inference,
                                                                 scope=scope)
                return decoder_distribution

    def transfer_matching(self, context_repr, knowledge_repr):
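        """Score each candidate triple against the mean-pooled context with the pre-trained
        bilinear or MLP selector, and return the score-weighted knowledge context of
        shape [batch, dim_emb]."""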
        context = tf.reduce_mean(context_repr, axis=1)  # [batch, num_units]
        triple_num = tf.shape(self.triples)[1]
        context_tile = tf.tile(tf.expand_dims(context, axis=1), [1, triple_num, 1])  # [batch, triple_num, num_units]
        knowledge = tf.layers.dense(knowledge_repr, self.dim_emb,
                                    name='knowledge_transform')  # [batch, triple_num, dim_emb]

        if self.select_mode == 'bilinear':
            context_reshaped = tf.reshape(context_tile, [-1, self.num_units])
            knowledge_reshaped = tf.reshape(knowledge, [-1, self.dim_emb])
            em_scores = self.select_layer(context_reshaped, knowledge_reshaped)
        else:
            concat_repr = tf.concat([context_tile, knowledge], axis=-1)  # [batch, triple_num, num_units+dim_emb]
            concat_repr_reshaped = tf.reshape(concat_repr, [-1, self.num_units + self.dim_emb])  # [batch*triple_num, num_units+dim_emb]
            em_scores = self.select_layer(concat_repr_reshaped)

        batch_size = tf.shape(self.posts)[0]
        em_scores = tf.reshape(em_scores, [batch_size, triple_num])
        kd_context = tf.matmul(tf.expand_dims(em_scores, axis=1), knowledge)
        kd_context = tf.reshape(kd_context, [batch_size, self.dim_emb])

        return kd_context

    def set_vocabs(self, session, vocab, kd_vocab):
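        """Populate the word and knowledge lookup tables (both directions) from vocab lists."""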
        op_in = self.symbol2index.insert(tf.constant(vocab),
                                         tf.constant(list(range(self.vocab_size)), dtype=tf.int64))
        session.run(op_in)
        op_out = self.index2symbol.insert(tf.constant(list(range(self.vocab_size)), dtype=tf.int64),
                                          tf.constant(vocab))
        session.run(op_out)
        op_in = self.kd2index.insert(tf.constant(kd_vocab),
                                     tf.constant(list(range(len(kd_vocab))), dtype=tf.int64))
        session.run(op_in)
        op_out = self.index2kd.insert(tf.constant(list(range(len(kd_vocab))), dtype=tf.int64),
                                      tf.constant(kd_vocab))
        session.run(op_out)

    def show_parameters(self):
        for var in self.params:
            print("%s: %s" % (var.name, var.get_shape().as_list()))

    def train_batch(self, session, data, trans_reprs):
        input_feed = {self.posts: data['post'],
                      self.posts_length: data['post_len'],
                      self.responses: data['response'],
                      self.responses_length: data['response_len'],
                      self.corr_responses: data['corr_responses'],
                      self.triples: data['all_triples']
                      }
        if self.use_trans_repr:
            input_feed[self.trans_reprs] = trans_reprs

        output_feed = [self.ppx_loss, self.loss, self.update]
        outputs = session.run(output_feed, feed_dict=input_feed)
        return outputs[0], outputs[1]

    def eval_batch(self, session, data, trans_reprs):
        input_feed = {self.posts: data['post'],
                      self.posts_length: data['post_len'],
                      self.responses: data['response'],
                      self.responses_length: data['response_len'],
                      self.corr_responses: data['corr_responses'],
                      self.triples: data['all_triples']
                      }
        if self.use_trans_repr:
            input_feed[self.trans_reprs] = trans_reprs

        output_feed = [self.ppx_loss, self.loss]
        outputs = session.run(output_feed, feed_dict=input_feed)
        return outputs[0], outputs[1]

    def decode_batch(self, session, data, trans_reprs):
        input_feed = {self.posts: data['post'],
                      self.posts_length: data['post_len'],
                      self.responses: data['response'],
                      self.responses_length: data['response_len'],
                      self.corr_responses: data['corr_responses'],
                      self.triples: data['all_triples']
                      }
        if self.use_trans_repr:
            input_feed[self.trans_reprs] = trans_reprs

        output_feed = [self.generation, self.ppx_loss, self.loss]
        outputs = session.run(output_feed, feed_dict=input_feed)
        return outputs[0], outputs[1], outputs[2]