import tensorflow as tf
import numpy as np
from numpy import nan
import codecs
import os
from models import BaseModel, AttentionCell, highway_network, BiRNN, DenselyConnectedBiRNN, multi_conv1d
from tensorflow.python.ops.rnn import dynamic_rnn
from tensorflow.contrib.crf import crf_log_likelihood
from utils import Progbar, pad_char_sequences
from utils.punct_prepro import PUNCTUATION_VOCABULARY, PUNCTUATION_MAPPING, END, UNK, EOS_TOKENS, SPACE


class SequenceLabelModel(BaseModel):
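    """Sequence labeling model for punctuation restoration.

    Words (plus optional character features) are embedded, encoded by a densely
    connected bi-directional RNN, refined by an attention cell, and projected to
    per-token punctuation logits; training uses either a CRF log-likelihood or a
    masked softmax cross-entropy loss.
    """
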
    def __init__(self, config):
        super(SequenceLabelModel, self).__init__(config)

    def _add_placeholders(self):
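        """Defines input placeholders for tokens, tags, sequence lengths and run-time hyper-parameters."""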
        self.words = tf.placeholder(tf.int32, shape=[None, None], name="words")  # shape = (batch_size, max_time)
        self.tags = tf.placeholder(tf.int32, shape=[None, None], name="tags")  # shape = (batch_size, max_time - 1)
        self.seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len")
        if self.cfg["use_chars"]:
            # shape = (batch_size, max_time, max_word_length)
            self.chars = tf.placeholder(tf.int32, shape=[None, None, None], name="chars")
            self.char_seq_len = tf.placeholder(tf.int32, shape=[None, None], name="char_seq_len")
        # hyper-parameters
        self.batch_size = tf.placeholder(tf.int32, name="batch_size")
        self.is_train = tf.placeholder(tf.bool, shape=[], name="is_train")
        self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
        self.drop_rate = tf.placeholder(tf.float32, name="dropout_rate")
        self.lr = tf.placeholder(tf.float32, name="learning_rate")

    def _get_feed_dict(self, batch, keep_prob=1.0, is_train=False, lr=None):
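        """Maps a data batch (plus optional dropout keep_prob, train flag and learning rate) onto the placeholders."""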
        feed_dict = {self.words: batch["words"], self.seq_len: batch["seq_len"], self.batch_size: batch["batch_size"]}
        if "tags" in batch:
            feed_dict[self.tags] = batch["tags"]
        if self.cfg["use_chars"]:
            feed_dict[self.chars] = batch["chars"]
            feed_dict[self.char_seq_len] = batch["char_seq_len"]
        feed_dict[self.keep_prob] = keep_prob
        feed_dict[self.drop_rate] = 1.0 - keep_prob
        feed_dict[self.is_train] = is_train
        if lr is not None:
            feed_dict[self.lr] = lr
        return feed_dict

    def _build_embedding_op(self):
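        """Builds word embeddings (pretrained or learned), optional char representations (RNN or CNN),
        and an optional highway network on top of their concatenation."""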
        with tf.variable_scope("embeddings"):
            if not self.cfg["use_pretrained"]:
                self.word_embeddings = tf.get_variable(name="emb", dtype=tf.float32, trainable=True,
                                                       shape=[self.word_vocab_size, self.cfg["emb_dim"]])
            else:
                word_emb_1 = tf.Variable(np.load(self.cfg["pretrained_emb"])["embeddings"], name="word_emb_1",
                                         dtype=tf.float32, trainable=self.cfg["tuning_emb"])
                word_emb_2 = tf.get_variable(name="word_emb_2", shape=[3, self.cfg["emb_dim"]], dtype=tf.float32,
                                             trainable=True)  # For UNK, NUM and END
                self.word_embeddings = tf.concat([word_emb_1, word_emb_2], axis=0)
            word_emb = tf.nn.embedding_lookup(self.word_embeddings, self.words, name="word_emb")
            print("word embedding shape: {}".format(word_emb.get_shape().as_list()))
            if self.cfg["use_chars"]:
                self.char_embeddings = tf.get_variable(name="c_emb", dtype=tf.float32, trainable=True,
                                                       shape=[self.char_vocab_size, self.cfg["char_emb_dim"]])
                char_emb = tf.nn.embedding_lookup(self.char_embeddings, self.chars, name="chars_emb")
                # train char representation
                if self.cfg["char_represent_method"] == "rnn":
                    char_bi_rnn = BiRNN(self.cfg["char_num_units"], cell_type=self.cfg["cell_type"], scope="c_bi_rnn")
                    char_represent = char_bi_rnn(char_emb, self.char_seq_len, use_last_state=True)
                else:
                    char_represent = multi_conv1d(char_emb, self.cfg["filter_sizes"], self.cfg["channel_sizes"],
                                                  drop_rate=self.drop_rate, is_train=self.is_train)
                print("chars representation shape: {}".format(char_represent.get_shape().as_list()))
                word_emb = tf.concat([word_emb, char_represent], axis=-1)
            if self.cfg["use_highway"]:
                self.word_emb = highway_network(word_emb, self.cfg["highway_layers"], use_bias=True, bias_init=0.0,
                                                keep_prob=self.keep_prob, is_train=self.is_train)
            else:
                self.word_emb = tf.layers.dropout(word_emb, rate=self.drop_rate, training=self.is_train)
            print("word and chars concatenation shape: {}".format(self.word_emb.get_shape().as_list()))

    def _build_model_op(self):
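        """Encodes the embeddings with a densely connected BiRNN, applies an attention cell over the
        encoder outputs and projects the result to per-token tag logits."""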
        with tf.variable_scope("densely_connected_bi_rnn"):
            dense_bi_rnn = DenselyConnectedBiRNN(self.cfg["num_layers"], self.cfg["num_units_list"],
                                                 cell_type=self.cfg["cell_type"])
            context = dense_bi_rnn(self.word_emb, seq_len=self.seq_len)
            print("densely connected bi_rnn output shape: {}".format(context.get_shape().as_list()))

        with tf.variable_scope("attention"):
            p_context = tf.layers.dense(context, units=2 * self.cfg["num_units_list"][-1], use_bias=True,
                                        bias_initializer=tf.constant_initializer(0.0))
            context = tf.transpose(context, [1, 0, 2])
            p_context = tf.transpose(p_context, [1, 0, 2])
            attn_cell = AttentionCell(self.cfg["num_units_list"][-1], context, p_context)
            attn_outs, _ = dynamic_rnn(attn_cell, context[1:, :, :], sequence_length=self.seq_len - 1, dtype=tf.float32,
                                       time_major=True)
            attn_outs = tf.transpose(attn_outs, [1, 0, 2])
            print("attention output shape: {}".format(attn_outs.get_shape().as_list()))

        with tf.variable_scope("project"):
            self.logits = tf.layers.dense(attn_outs, units=self.tag_vocab_size, use_bias=True,
                                          bias_initializer=tf.constant_initializer(0.0))
            print("logits shape: {}".format(self.logits.get_shape().as_list()))

    def _build_loss_op(self):
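        """Builds the training loss: CRF log-likelihood or masked softmax cross-entropy, plus optional L2."""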
        if self.cfg["use_crf"]:
            crf_loss, self.trans_params = crf_log_likelihood(self.logits, self.tags, self.seq_len - 1)
            self.loss = tf.reduce_mean(-crf_loss)
        else:  # using softmax
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=self.tags)
            # logits cover max_time - 1 steps (attention starts at the second token), so mask with seq_len - 1
            mask = tf.sequence_mask(self.seq_len - 1)
            self.loss = tf.reduce_mean(tf.boolean_mask(losses, mask))
        if self.cfg["l2_reg"] is not None and self.cfg["l2_reg"] > 0.0:  # l2 regularization
            l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if "bias" not in v.name])
            self.loss += self.cfg["l2_reg"] * l2_loss
        tf.summary.scalar("loss", self.loss)

    def _predict_op(self, data):
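        """Runs the model on a batch and decodes tag ids (Viterbi decoding when the CRF layer is used)."""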
        feed_dict = self._get_feed_dict(data)
        if self.cfg["use_crf"]:
            logits, trans_params = self.sess.run([self.logits, self.trans_params], feed_dict=feed_dict)
            return self.viterbi_decode(logits, trans_params, data["seq_len"] - 1)
        else:
            # build the argmax op once so repeated predictions do not keep adding nodes to the graph
            if not hasattr(self, "pred_ids"):
                self.pred_ids = tf.cast(tf.argmax(self.logits, axis=-1), tf.int32)
            return self.sess.run(self.pred_ids, feed_dict=feed_dict)

    def train_epoch(self, train_set, valid_data, epoch):
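        """Runs one training epoch, tracking loss/perplexity and periodically writing validation summaries."""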
        num_batches = len(train_set)
        prog = Progbar(target=num_batches)
        total_cost, total_samples = 0, 0
        for i, batch in enumerate(train_set):
            feed_dict = self._get_feed_dict(batch, is_train=True, keep_prob=self.cfg["keep_prob"], lr=self.cfg["lr"])
            _, train_loss, summary = self.sess.run([self.train_op, self.loss, self.summary], feed_dict=feed_dict)
            cur_step = (epoch - 1) * num_batches + (i + 1)
            total_cost += train_loss
            total_samples += np.array(batch["words"]).shape[0]
            prog.update(i + 1, [("Global Step", int(cur_step)), ("Train Loss", train_loss),
                                ("Perplexity", np.exp(total_cost / total_samples))])
            self.train_writer.add_summary(summary, cur_step)
            if i % 100 == 0:
                valid_feed_dict = self._get_feed_dict(valid_data)
                valid_summary = self.sess.run(self.summary, feed_dict=valid_feed_dict)
                self.test_writer.add_summary(valid_summary, cur_step)

    def train(self, train_set, valid_data, valid_text, test_texts):  # test_texts: [ref, asr]
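        """Trains with early stopping; the best checkpoint is selected by F1 on the reference test text."""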
        self.logger.info("Start training...")
        best_f1, no_imprv_epoch = -np.inf, 0
        self._add_summary()
        for epoch in range(1, self.cfg["epochs"] + 1):
            self.logger.info("Epoch {}/{}:".format(epoch, self.cfg["epochs"]))
            self.train_epoch(train_set, valid_data, epoch)
            # self.evaluate(valid_text)
            ref_f1 = self.evaluate_punct(test_texts[0])["F1"] * 100.0  # use ref to compute best F1
            asr_f1 = self.evaluate_punct(test_texts[1])["F1"] * 100.0
            if ref_f1 >= best_f1:
                best_f1 = ref_f1
                no_imprv_epoch = 0
                self.save_session(epoch)
                self.logger.info(" -- new BEST score on ref dataset: {:04.2f}, on asr dataset: {:04.2f}"
                                 .format(best_f1, asr_f1))
            else:
                no_imprv_epoch += 1
                if no_imprv_epoch >= self.cfg["no_imprv_tolerance"]:
                    self.logger.info("early stop at {}th epoch without improvement, BEST score on ref dataset: {:04.2f}"
                                     .format(epoch, best_f1))
                    break
        self.train_writer.close()
        self.test_writer.close()

    def evaluate_punct(self, file):
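        """Re-punctuates the text in `file` with a sliding window of predictions, writes the output to a
        temporary file and scores it against the reference (F1, ERR, SER)."""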
        save_path = os.path.join(self.cfg["checkpoint_path"], "result.txt")
        with codecs.open(file, mode="r", encoding="utf-8") as f:
            text = f.read().split()
        text = [w for w in text if w not in self.tag_dict and w not in PUNCTUATION_MAPPING] + [END]
        index = 0
        with codecs.open(save_path, mode="w", encoding="utf-8") as f_out:
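            # process the text in overlapping windows of at most max_sequence_len tokens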
            while True:
                subseq = text[index: index + self.cfg["max_sequence_len"]]
                if len(subseq) == 0:
                    break
                # create feed data
                cvrt_seq = np.array([[self.word_dict.get(w, self.word_dict[UNK]) for w in subseq]], dtype=np.int32)
                seq_len = np.array([len(v) for v in cvrt_seq], dtype=np.int32)
                cvrt_seq_chars = []
                for word in subseq:
                    chars = [self.char_dict.get(c, self.char_dict[UNK]) for c in word]
                    cvrt_seq_chars.append(chars)
                cvrt_seq_chars, char_seq_len = pad_char_sequences([cvrt_seq_chars])
                cvrt_seq_chars = np.array(cvrt_seq_chars, dtype=np.int32)
                char_seq_len = np.array(char_seq_len, dtype=np.int32)
                data = {"words": cvrt_seq, "seq_len": seq_len, "chars": cvrt_seq_chars, "char_seq_len": char_seq_len,
                        "batch_size": 1}
                # predict
                predicts = self._predict_op(data)
                # write to file
                f_out.write(subseq[0])
                last_eos_idx = 0
                punctuations = []
                for preds_t in predicts[0]:
                    punctuation = self.rev_tag_dict[preds_t]
                    punctuations.append(punctuation)
                    if punctuation in EOS_TOKENS:
                        last_eos_idx = len(punctuations)
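                # choose how far to slide the window: past the last predicted sentence
                # boundary if any, otherwise keep a one-token overlap with the next window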
                if subseq[-1] == END:
                    step = len(subseq) - 1
                elif last_eos_idx != 0:
                    step = last_eos_idx
                else:
                    step = len(subseq) - 1
                for j in range(step):
                    f_out.write(" " + punctuations[j] + " " if punctuations[j] != SPACE else " ")
                    if j < step - 1:
                        f_out.write(subseq[1 + j])
                if subseq[-1] == END:
                    break
                index += step
        out_str, f1, err, ser = self.compute_score(file, save_path)
        score = {"F1": f1, "ERR": err, "SER": ser}
        self.logger.info("\nEvaluate on {}:\n{}\n".format(file, out_str))
        try:  # delete output file after compute scores
            os.remove(save_path)
        except OSError:
            pass
        return score

    def inference(self, sentence):
        pass  # TODO

    @staticmethod
    def compute_score(target_path, predicted_path):
        """Computes and prints the overall classification error and precision, recall, F-score over punctuations."""
        mappings, counter, t_i, p_i = {}, 0, 0, 0
        total_correct, correct, substitutions, deletions, insertions = 0, 0.0, 0.0, 0.0, 0.0
        true_pos, false_pos, false_neg = {}, {}, {}
        with codecs.open(target_path, "r", "utf-8") as f_target, codecs.open(predicted_path, "r", "utf-8") as f_predict:
            target_stream = f_target.read().split()
            predict_stream = f_predict.read().split()
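            # walk both token streams in lock-step, comparing the punctuation (if any)
            # attached to each word position; the words themselves must stay aligned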
            while True:
                if PUNCTUATION_MAPPING.get(target_stream[t_i], target_stream[t_i]) in PUNCTUATION_VOCABULARY:
                    # skip multiple consecutive punctuations
                    target_punct = " "
                    while PUNCTUATION_MAPPING.get(target_stream[t_i], target_stream[t_i]) in PUNCTUATION_VOCABULARY:
                        target_punct = PUNCTUATION_MAPPING.get(target_stream[t_i], target_stream[t_i])
                        target_punct = mappings.get(target_punct, target_punct)
                        t_i += 1
                else:
                    target_punct = " "
                if predict_stream[p_i] in PUNCTUATION_VOCABULARY:
                    predicted_punct = mappings.get(predict_stream[p_i], predict_stream[p_i])
                    p_i += 1
                else:
                    predicted_punct = " "
                is_correct = target_punct == predicted_punct
                counter += 1
                total_correct += is_correct
                if predicted_punct == " " and target_punct != " ":
                    deletions += 1
                elif predicted_punct != " " and target_punct == " ":
                    insertions += 1
                elif predicted_punct != " " and target_punct != " " and predicted_punct == target_punct:
                    correct += 1
                elif predicted_punct != " " and target_punct != " " and predicted_punct != target_punct:
                    substitutions += 1
                true_pos[target_punct] = true_pos.get(target_punct, 0.0) + float(is_correct)
                false_pos[predicted_punct] = false_pos.get(predicted_punct, 0.) + float(not is_correct)
                false_neg[target_punct] = false_neg.get(target_punct, 0.) + float(not is_correct)
                assert target_stream[t_i] == predict_stream[p_i] or predict_stream[p_i] == "<unk>", \
                    "File: %s \nError: %s (%s) != %s (%s) \nTarget context: %s \nPredicted context: %s" % \
                    (target_path, target_stream[t_i], t_i, predict_stream[p_i], p_i,
                     " ".join(target_stream[t_i - 2:t_i + 2]), " ".join(predict_stream[p_i - 2:p_i + 2]))
                t_i += 1
                p_i += 1
                if t_i >= len(target_stream) - 1 and p_i >= len(predict_stream) - 1:
                    break
        overall_tp, overall_fp, overall_fn = 0.0, 0.0, 0.0
        out_str = "-" * 46 + "\n"
        out_str += "{:<16} {:<9} {:<9} {:<9}\n".format("PUNCTUATION", "PRECISION", "RECALL", "F-SCORE")
        for p in PUNCTUATION_VOCABULARY:
            if p == SPACE:
                continue
            overall_tp += true_pos.get(p, 0.0)
            overall_fp += false_pos.get(p, 0.0)
            overall_fn += false_neg.get(p, 0.0)
            punctuation = p
            precision = (true_pos.get(p, 0.0) / (true_pos.get(p, 0.0) + false_pos[p])) if p in false_pos else nan
            recall = (true_pos.get(p, 0.0) / (true_pos.get(p, 0.0) + false_neg[p])) if p in false_neg else nan
            f_score = (2. * precision * recall / (precision + recall)) if (precision + recall) > 0 else nan
            out_str += u"{:<16} {:<9} {:<9} {:<9}\n".format(punctuation, "{:.2f}".format(precision * 100),
                                                            "{:.2f}".format(recall * 100),
                                                            "{:.2f}".format(f_score * 100))
        out_str += "-" * 46 + "\n"
        # guard on the full denominator so a perfect precision/recall is not reported as nan
        pre = overall_tp / (overall_tp + overall_fp) if (overall_tp + overall_fp) > 0 else nan
        rec = overall_tp / (overall_tp + overall_fn) if (overall_tp + overall_fn) > 0 else nan
        f1 = (2.0 * pre * rec) / (pre + rec) if (pre + rec) > 0 else nan
        out_str += "{:<16} {:<9} {:<9} {:<9}\n".format("Overall", "{:.2f}".format(pre * 100),
                                                       "{:.2f}".format(rec * 100), "{:.2f}".format(f1 * 100))
        err = round((100.0 - float(total_correct) / float(counter - 1) * 100.0), 2)
        ser = round((substitutions + deletions + insertions) / (correct + substitutions + deletions) * 100, 1)
        out_str += "ERR: %s%%\n" % err
        out_str += "SER: %s%%" % ser
        return out_str, f1, err, ser
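

# Minimal usage sketch (illustrative only; assumes a `config` dict holding the keys
# referenced above and data prepared by the project's loaders -- the file names below
# are placeholders, not part of this module):
#
#   model = SequenceLabelModel(config)
#   model.train(train_set, valid_data, valid_text="data/valid.txt",
#               test_texts=["data/test_ref.txt", "data/test_asr.txt"])
#   scores = model.evaluate_punct("data/test_ref.txt")  # -> {"F1": ..., "ERR": ..., "SER": ...}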