# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq
import toolbox
import batch as Batch
import numpy as np
import cPickle as pickle
import evaluation

import os

class Seq2seq(object):
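    """Character-level encoder-decoder with attention, built on the
    TF 1.x contrib legacy_seq2seq API."""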

    def __init__(self, trained_model):
        self.en_vec = None
        self.de_vec = None
        self.trans_output = None
        self.trans_labels = None
        self.feed_previous = None
        self.trans_l_rate = None
        self.trained = trained_model
        self.decode_step = None
        self.encode_step = None
        self.trans_train = None
        self.saver = None

    def define(self, char_num, rnn_dim, emb_dim, max_x, max_y, write_trans_model=True):
        self.decode_step = max_y
        self.encode_step = max_x
        self.en_vec = [tf.placeholder(tf.int32, [None], name='en_input' + str(i)) for i in range(max_x)]
        self.trans_labels = [tf.placeholder(tf.int32, [None], name='de_input' + str(i)) for i in range(max_y)]
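        # Loss weights: tf.sign maps padding ids (0) to 0 and real token ids to 1,
        # so padded positions do not contribute to the sequence loss.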
        weights = [tf.cast(tf.sign(ot_t), tf.float32) for ot_t in self.trans_labels]
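        # Decoder inputs: a GO symbol (id 0) followed by the gold labels shifted
        # right by one step (standard teacher forcing).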
        self.de_vec = [tf.zeros_like(self.trans_labels[0], tf.int32)] + self.trans_labels[:-1]
        self.feed_previous = tf.placeholder(tf.bool)
        self.trans_l_rate = tf.placeholder(tf.float32, [], name='learning_rate')
        seq_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, state_is_tuple=True)
        # Attention-based encoder-decoder over a shared input/output vocabulary;
        # feed_previous switches between teacher forcing (training) and feeding
        # back the previous prediction (decoding).
        self.trans_output, states = seq2seq.embedding_attention_seq2seq(self.en_vec, self.de_vec, seq_cell, char_num,
                                                                        char_num, emb_dim, feed_previous=self.feed_previous)

        loss = seq2seq.sequence_loss(self.trans_output, self.trans_labels, weights)
        optimizer = tf.train.AdagradOptimizer(learning_rate=self.trans_l_rate)

        # Clip gradients by global norm (max 5.0) to avoid exploding gradients,
        # then apply the clipped updates.
        params = tf.trainable_variables()
        gradients = tf.gradients(loss, params)
        clipped_gradients, norm = tf.clip_by_global_norm(gradients, 5.0)
        self.trans_train = optimizer.apply_gradients(zip(clipped_gradients, params))

        # Saver used to checkpoint the best-scoring weights during training.
        self.saver = tf.train.Saver()

        if write_trans_model:
            # Persist the hyper-parameters so the same graph can be rebuilt at tag time.
            param_dic = {'char_num': char_num, 'rnn_dim': rnn_dim, 'emb_dim': emb_dim,
                         'max_x': max_x, 'max_y': max_y}
            # Pickle needs a binary-mode file handle ('wb').
            with open(self.trained + '_model', 'wb') as f_model:
                pickle.dump(param_dic, f_model)
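
    # A minimal sketch (an assumption, not part of the original pipeline) of
    # reloading the pickled hyper-parameters before rebuilding the graph for
    # tagging with a saved model:
    #
    #     with open(trained_model + '_model', 'rb') as f:
    #         p = pickle.load(f)
    #     model = Seq2seq(trained_model)
    #     model.define(p['char_num'], p['rnn_dim'], p['emb_dim'], p['max_x'],
    #                  p['max_y'], write_trans_model=False)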

    def train(self, t_x, t_y, v_x, v_y, lrv, char2idx, sess, epochs, batch_size=10, reset=True):

        # Invert the vocabulary, strip zero-padding from the gold label ids and
        # convert them back to character sequences for evaluation.
        idx2char = {idx: ch for ch, idx in char2idx.items()}
        v_y_g = [np.trim_zeros(v_y_t) for v_y_t in v_y]
        gold_out = [toolbox.generate_trans_out(v_y_t, idx2char) for v_y_t in v_y_g]

        best_score = 0

        # Train from scratch unless a checkpoint already exists and reset is False.
        if reset or not os.path.isfile(self.trained + '_weights.index'):
            for epoch in range(epochs):
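                # One pass over the training data, then greedy-decode the dev set.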
                Batch.train_seq2seq(sess, model=self.en_vec + self.trans_labels, decoding=self.feed_previous,
                                    batch_size=batch_size, config=self.trans_train, lr=self.trans_l_rate, lrv=lrv,
                                    data=[t_x, t_y])
                pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output,
                                             decoding=self.feed_previous, decode_len=self.decode_step,
                                             data=[v_x], argmax=True, batch_size=100)
                pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred]

                c_scores = evaluation.trans_evaluator(gold_out, pred_out)

                print 'epoch: %d' % (epoch + 1)

                print 'ACC: %f' % c_scores[0]
                print 'Token F score: %f' % c_scores[1]

                # Checkpoint the weights whenever the token F score improves.
                if c_scores[1] > best_score:
                    best_score = c_scores[1]
                    self.saver.save(sess, self.trained + '_weights', write_meta_graph=False)

        # Restore the best checkpoint (or the existing one when resuming).
        if best_score > 0 or not reset:
            self.saver.restore(sess, self.trained + '_weights')

    def tag(self, t_x, char2idx, sess, batch_size=100):

        # Truncate inputs to the encoder length and pad the remainder with zeros.
        t_x = [t_x_t[:self.encode_step] for t_x_t in t_x]
        t_x = toolbox.pad_zeros(t_x, self.encode_step)

        idx2char = {idx: ch for ch, idx in char2idx.items()}

        pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output,
                                     decoding=self.feed_previous, decode_len=self.decode_step,
                                     data=[t_x], argmax=True, batch_size=batch_size)
        pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred]

        return pred_out
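

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original pipeline: the toy
    # vocabulary and hand-padded id sequences stand in for the real toolbox
    # preprocessing, and 'trans' is a hypothetical checkpoint prefix.
    char2idx = {'<P>': 0, 'a': 1, 'b': 2}
    t_x = [np.array([1, 2, 1, 0]), np.array([2, 2, 0, 0])]  # padded encoder inputs
    t_y = [np.array([2, 1, 0, 0]), np.array([1, 1, 2, 0])]  # padded decoder targets

    model = Seq2seq('trans')
    model.define(char_num=len(char2idx), rnn_dim=32, emb_dim=16, max_x=4, max_y=4)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        model.train(t_x, t_y, t_x, t_y, lrv=0.1, char2idx=char2idx, sess=sess,
                    epochs=1, batch_size=2)
        print model.tag(t_x, char2idx, sess)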