import tensorflow as tf
import numpy as np
from tensorflow.contrib.crf import viterbi_decode as crf_viterbi_decode, crf_log_likelihood
from tensorflow.python.ops.rnn_cell import LSTMCell, GRUCell, MultiRNNCell
from utils import CoNLLeval, load_dataset, get_logger, process_batch_data, align_data
from utils.common import word_convert, UNK
import os


class BaseModel:
    def __init__(self, config):
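        """Build the full graph (placeholders, embeddings, model, loss and train ops) and start a session."""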
        self.cfg = config
        self._initialize_config()
        self.sess, self.saver = None, None
        self._add_placeholders()
        self._build_embedding_op()
        self._build_model_op()
        self._build_loss_op()
        self._build_train_op()
        num_params = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
        self.logger.info("params number: {}".format(num_params))
        self.initialize_session()

    def _initialize_config(self):
        # create folders and logger
        if not os.path.exists(self.cfg["checkpoint_path"]):
            os.makedirs(self.cfg["checkpoint_path"])
        if not os.path.exists(self.cfg["summary_path"]):
            os.makedirs(self.cfg["summary_path"])
        self.logger = get_logger(os.path.join(self.cfg["checkpoint_path"], "log.txt"))
        # load dictionary
        dict_data = load_dataset(self.cfg["vocab"])
        self.word_dict, self.char_dict = dict_data["word_dict"], dict_data["char_dict"]
        self.tag_dict = dict_data["tag_dict"]
        del dict_data
        self.word_vocab_size = len(self.word_dict)
        self.char_vocab_size = len(self.char_dict)
        self.tag_vocab_size = len(self.tag_dict)
        self.rev_word_dict = {idx: word for word, idx in self.word_dict.items()}
        self.rev_char_dict = {idx: char for char, idx in self.char_dict.items()}
        self.rev_tag_dict = {idx: tag for tag, idx in self.tag_dict.items()}

    def initialize_session(self):
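        """Create a session with GPU memory growth enabled, build the saver and initialize all variables."""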
        sess_config = tf.ConfigProto()
        sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=sess_config)
        self.saver = tf.train.Saver(max_to_keep=self.cfg["max_to_keep"])
        self.sess.run(tf.global_variables_initializer())

    def restore_last_session(self, ckpt_path=None):
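        """Restore the latest checkpoint from ckpt_path, or from the configured checkpoint_path if none is given."""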
        if ckpt_path is not None:
            ckpt = tf.train.get_checkpoint_state(ckpt_path)
        else:
            ckpt = tf.train.get_checkpoint_state(self.cfg["checkpoint_path"])  # get checkpoint state
        if ckpt and ckpt.model_checkpoint_path:  # restore session
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)

    def save_session(self, epoch):
        ckpt_prefix = os.path.join(self.cfg["checkpoint_path"], self.cfg["model_name"])
        self.saver.save(self.sess, ckpt_prefix, global_step=epoch)

    def close_session(self):
        self.sess.close()

    def _add_summary(self):
        self.summary = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(os.path.join(self.cfg["summary_path"], "train"), self.sess.graph)
        self.test_writer = tf.summary.FileWriter(os.path.join(self.cfg["summary_path"], "test"))

    def reinitialize_weights(self, scope_name=None):
        """Reinitialize parameters in a scope"""
        if scope_name is None:
            self.sess.run(tf.global_variables_initializer())
        else:
            variables = tf.contrib.framework.get_variables(scope_name)
            self.sess.run(tf.variables_initializer(variables))

    @staticmethod
    def variable_summaries(variable, name=None):
        with tf.name_scope(name or "summary"):
            mean = tf.reduce_mean(variable)
            tf.summary.scalar("mean", mean)  # add mean value
            stddev = tf.sqrt(tf.reduce_mean(tf.square(variable - mean)))
            tf.summary.scalar("stddev", stddev)  # add standard deviation value
            tf.summary.scalar("max", tf.reduce_max(variable))  # add maximal value
            tf.summary.scalar("min", tf.reduce_min(variable))  # add minimal value
            tf.summary.histogram("histogram", variable)  # add histogram

    @staticmethod
    def viterbi_decode(logits, trans_params, seq_len):
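        """Decode the best tag sequence for each example with Viterbi, using only the valid (unpadded) timesteps."""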
        viterbi_sequences = []
        for logit, lens in zip(logits, seq_len):
            logit = logit[:lens]  # keep only the valid steps
            viterbi_seq, viterbi_score = crf_viterbi_decode(logit, trans_params)
            viterbi_sequences += [viterbi_seq]
        return viterbi_sequences

    def _create_single_rnn_cell(self, num_units):
        cell = GRUCell(num_units) if self.cfg["cell_type"] == "gru" else LSTMCell(num_units)
        return cell

    def _create_rnn_cell(self):
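        """Return a single RNN cell, or a stacked MultiRNNCell when num_layers is greater than one."""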
        if self.cfg["num_layers"] is None or self.cfg["num_layers"] <= 1:
            return self._create_single_rnn_cell(self.cfg["num_units"])
        else:
            return MultiRNNCell([self._create_single_rnn_cell(self.cfg["num_units"])
                                 for _ in range(self.cfg["num_layers"])])

    def _add_placeholders(self):
        raise NotImplementedError("To be implemented...")

    def _get_feed_dict(self, data):
        raise NotImplementedError("To be implemented...")

    def _build_embedding_op(self):
        raise NotImplementedError("To be implemented...")

    def _build_model_op(self):
        raise NotImplementedError("To be implemented...")

    def _build_loss_op(self):
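        """Use the CRF negative log-likelihood when use_crf is set, otherwise masked softmax cross-entropy."""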
        if self.cfg["use_crf"]:
            crf_loss, self.trans_params = crf_log_likelihood(self.logits, self.tags, self.seq_len)
            self.loss = tf.reduce_mean(-crf_loss)
        else:  # using softmax
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.tags)
            mask = tf.sequence_mask(self.seq_len)
            self.loss = tf.reduce_mean(tf.boolean_mask(losses, mask))
        tf.summary.scalar("loss", self.loss)

    def _build_train_op(self):
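        """Pick the optimizer named in the config and optionally clip gradients by global norm."""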
        with tf.variable_scope("train_step"):
            if self.cfg["optimizer"] == 'adagrad':
                optimizer = tf.train.AdagradOptimizer(learning_rate=self.lr)
            elif self.cfg["optimizer"] == 'sgd':
                optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.lr)
            elif self.cfg["optimizer"] == 'rmsprop':
                optimizer = tf.train.RMSPropOptimizer(learning_rate=self.lr)
            elif self.cfg["optimizer"] == 'adadelta':
                optimizer = tf.train.AdadeltaOptimizer(learning_rate=self.lr)
            else:  # default adam optimizer
                if self.cfg["optimizer"] != 'adam':
                    print('Unsupported optimizing method {}. Using default adam optimizer.'
                          .format(self.cfg["optimizer"]))
                optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
        if self.cfg["grad_clip"] is not None and self.cfg["grad_clip"] > 0:
            grads, vs = zip(*optimizer.compute_gradients(self.loss))
            grads, _ = tf.clip_by_global_norm(grads, self.cfg["grad_clip"])
            self.train_op = optimizer.apply_gradients(zip(grads, vs))
        else:
            self.train_op = optimizer.minimize(self.loss)

    def _predict_op(self, data):
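        """Run a forward pass and return predicted tag indices (Viterbi decoding with CRF, argmax otherwise)."""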
        feed_dict = self._get_feed_dict(data)
        if self.cfg["use_crf"]:
            logits, trans_params, seq_len = self.sess.run([self.logits, self.trans_params, self.seq_len],
                                                          feed_dict=feed_dict)
            return self.viterbi_decode(logits, trans_params, seq_len)
        else:
            # compute the argmax outside the graph so repeated prediction calls do not add new ops
            logits = self.sess.run(self.logits, feed_dict=feed_dict)
            return np.argmax(logits, axis=-1).astype(np.int32)

    def train_epoch(self, train_set, valid_data, epoch):
        raise NotImplementedError("To be implemented...")

    def train(self, train_set, valid_data, valid_set, test_set):
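        """Train with optional learning-rate decay, checkpoint the best test F1 and stop early without improvement."""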
        self.logger.info("Start training...")
        best_f1, no_imprv_epoch, init_lr = -np.inf, 0, self.cfg["lr"]
        self._add_summary()
        for epoch in range(1, self.cfg["epochs"] + 1):
            self.logger.info('Epoch {}/{}:'.format(epoch, self.cfg["epochs"]))
            self.train_epoch(train_set, valid_data, epoch)  # train epochs
            self.evaluate(valid_set, "dev")
            score = self.evaluate(test_set, "test")
            if self.cfg["use_lr_decay"]:  # learning rate decay
                self.cfg["lr"] = max(init_lr / (1.0 + self.cfg["lr_decay"] * epoch), self.cfg["minimal_lr"])
            if score["FB1"] > best_f1:
                best_f1 = score["FB1"]
                no_imprv_epoch = 0
                self.save_session(epoch)
                self.logger.info(' -- new BEST score on test dataset: {:04.2f}'.format(best_f1))
            else:
                no_imprv_epoch += 1
                if no_imprv_epoch >= self.cfg["no_imprv_tolerance"]:
                    self.logger.info('early stop at epoch {} with no improvement, BEST score on the test set: {:04.2f}'
                                     .format(epoch, best_f1))
                    break
        self.train_writer.close()
        self.test_writer.close()

    def evaluate(self, dataset, name):
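        """Predict over a dataset and score accuracy/precision/recall/F1 with the CoNLL evaluation script."""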
        save_path = os.path.join(self.cfg["checkpoint_path"], "result.txt")
        predictions, groundtruth, words_list = list(), list(), list()
        for data in dataset:
            predicts = self._predict_op(data)
            for tags, preds, words, seq_len in zip(data["tags"], predicts, data["words"], data["seq_len"]):
                tags = [self.rev_tag_dict[x] for x in tags[:seq_len]]
                preds = [self.rev_tag_dict[x] for x in preds[:seq_len]]
                words = [self.rev_word_dict[x] for x in words[:seq_len]]
                predictions.append(preds)
                groundtruth.append(tags)
                words_list.append(words)
        ce = CoNLLeval()
        score = ce.conlleval(predictions, groundtruth, words_list, save_path)
        self.logger.info("{} dataset -- acc: {:04.2f}, pre: {:04.2f}, rec: {:04.2f}, FB1: {:04.2f}"
                         .format(name, score["accuracy"], score["precision"], score["recall"], score["FB1"]))
        return score

    def words_to_indices(self, words):
        """
        Convert input words into batchnized word/chars indices for inference
        :param words: input words
        :return: batchnized word indices
        """
        chars_idx = []
        for word in words:
            chars = [self.char_dict.get(char, self.char_dict[UNK]) for char in word]
            chars_idx.append(chars)
        words = [word_convert(word) for word in words]
        words_idx = [self.word_dict.get(word, self.word_dict[UNK]) for word in words]
        return process_batch_data([words_idx], [chars_idx])

    def inference(self, sentence):
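        """Tag a single whitespace-tokenized sentence and return the aligned input/output as a string."""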
        words = sentence.strip().split(" ")
        data = self.words_to_indices(words)
        predicts = self._predict_op(data)
        predicts = [self.rev_tag_dict[idx] for idx in list(predicts[0])]
        results = align_data({"input": words, "output": predicts})
        return "{}\n{}".format(results["input"], results["output"])