import tensorflow as tf
from copy import copy

from tensor2tensor.utils import beam_search

from .transformer_decoder import TransformerDecoder

from . import modeling
from .top_utils import (TopLayer, gather_indexes,
                        make_cudnngru, create_seq_smooth_label,
                        dense_layer)


class SequenceLabel(TopLayer):
    '''Top model for sequence labeling.
    It is a dense layer that takes the body output features as input, with the
    following optional components:

    crf: Conditional Random Field. Takes the logits (output of the dense layer) as input.
    hidden_gru: Takes the body features as input and applies an RNN on top of them.
    label_smoothing: Hard label smoothing. Randomly replaces labels with some probability.
    '''
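
    # Illustrative usage sketch (not executed here): `params`, `features`,
    # `hidden_feature` and the problem name 'ner' are placeholders for the
    # objects the surrounding pipeline builds.
    #
    #     top = SequenceLabel(params)
    #     loss = top(features, hidden_feature,
    #                tf.estimator.ModeKeys.TRAIN, 'ner')
    #
    # `hidden_feature` is the dict produced by the body model and must contain
    # a 'seq' entry of shape [batch, seq_len, hidden_size].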

    def make_batch_loss(self, logits, seq_labels, seq_length, crf_transition_param):
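        # logits: [batch, seq_len, num_classes]; seq_labels: [batch, seq_len];
        # seq_length: [batch]. Returns a per-example loss of shape [batch]
        # (negative CRF log-likelihood, or token cross-entropy averaged over
        # the sequence dimension).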
        if self.params.crf:
            with tf.variable_scope('CRF'):
                log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
                    logits, seq_labels, seq_length,
                    transition_params=crf_transition_param)
                batch_loss = -log_likelihood
        else:
            # The labels may come in with a shorter sequence length than the
            # logits, so pad them on the right until the two match.
            pad_len = tf.shape(logits)[1] - tf.shape(seq_labels)[1]

            # paddings = [[batch before, batch after], [seq before, seq after]]:
            # pad zeros only at the end of the sequence dimension.
            pad_tensor = [[0, 0], [0, pad_len]]
            seq_labels = tf.pad(seq_labels, paddings=pad_tensor)

            batch_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=logits, labels=seq_labels), axis=1)

        if self.params.uncertain_weight_loss:
            batch_loss = self.uncertainty_weighted_loss(batch_loss)
        return batch_loss

    def __call__(self, features, hidden_feature, mode, problem_name, mask=None):
        hidden_feature = hidden_feature['seq']
        scope_name = self.params.share_top[problem_name]
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_feature = tf.nn.dropout(
                hidden_feature,
                keep_prob=self.params.dropout_keep_prob)

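        # `mask` is an optional 1-D tensor over the label classes; when given it
        # determines num_classes and is multiplied into the logits below,
        # zeroing out the logits of disallowed classes.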
        if mask is None:
            num_classes = self.params.num_classes[problem_name]
        else:
            num_classes = mask.shape[0]

        # make hidden model
        hidden_feature = self.make_hidden_model(
            features, hidden_feature, mode, True)
        logits = dense_layer(num_classes, hidden_feature, mode, 1.0, None)
        self.logits = logits
        if mask is not None:
            logits = logits*mask

        # CRF transition param
        crf_transition_param = tf.get_variable(
            'crf_transition', shape=[num_classes, num_classes])

        # sequence_weight = tf.cast(features["input_mask"], tf.float32)
        seq_length = tf.reduce_sum(features["input_mask"], axis=-1)

        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_labels = features['%s_label_ids' % problem_name]
            seq_labels = create_seq_smooth_label(
                self.params, seq_labels, num_classes)
            batch_loss = self.make_batch_loss(
                logits, seq_labels, seq_length, crf_transition_param)
            self.loss = self.create_loss(
                batch_loss, features['%s_loss_multiplier' % problem_name])
            # If a batch does not contain input instances from the current problem, the loss multiplier will be empty
            # and loss will be NaN. Replacing NaN with 0 fixes the problem.
            self.loss = tf.where(tf.math.is_nan(self.loss),
                                 tf.zeros_like(self.loss), self.loss)
            return self.loss

        elif mode == tf.estimator.ModeKeys.EVAL:
            seq_labels = features['%s_label_ids' % problem_name]
            batch_loss = self.make_batch_loss(
                logits, seq_labels, seq_length, crf_transition_param)

            seq_loss = tf.reduce_mean(batch_loss)

            return self.eval_metric_fn(
                features, logits, seq_loss, problem_name, features['input_mask'], pad_labels_to_logits=True)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            if self.params.crf:
                viterbi_sequence, viterbi_score = tf.contrib.crf.crf_decode(
                    logits, crf_transition_param, seq_length)
                self.prob = tf.identity(
                    viterbi_sequence, name='%s_predict' % scope_name)
            else:
                self.prob = tf.nn.softmax(
                    logits, name='%s_predict' % scope_name)

            return self.prob


class Classification(TopLayer):
    '''Top model for classification.
    It is a dense layer that takes the pooled body output features as input, with the
    following optional component:

    label_smoothing: Soft label smoothing.
    '''

    def create_batch_loss(self, labels, logits, num_classes):
        if self.params.label_smoothing > 0:
            one_hot_labels = tf.one_hot(labels, depth=num_classes)
            batch_loss = tf.losses.softmax_cross_entropy(
                one_hot_labels, logits,
                label_smoothing=self.params.label_smoothing)
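            # tf.losses.softmax_cross_entropy smooths the targets to
            # one_hot * (1 - smoothing) + smoothing / num_classes; e.g. with
            # smoothing 0.1 and 4 classes the true class gets 0.925 and every
            # other class gets 0.025.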
        else:
            batch_loss = tf.losses.sparse_softmax_cross_entropy(
                labels, logits)

        if self.params.uncertain_weight_loss:
            batch_loss = self.uncertainty_weighted_loss(batch_loss)
        return batch_loss

    def __call__(self, features, hidden_feature, mode, problem_name, mask=None):
        hidden_feature = hidden_feature['pooled']
        scope_name = self.params.share_top[problem_name]
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_feature = tf.nn.dropout(
                hidden_feature,
                keep_prob=self.params.dropout_keep_prob)

        if mask is None:
            num_classes = self.params.num_classes.get(problem_name, 2)
        else:
            num_classes = mask.shape[0]
        # make hidden model
        hidden_feature = self.make_hidden_model(
            features, hidden_feature, mode, 'pooled')
        logits = dense_layer(num_classes, hidden_feature, mode, 1.0, None)
        self.logits = logits
        if mask is not None:
            logits = logits*mask
        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = features['%s_label_ids' % problem_name]
            batch_loss = self.create_batch_loss(labels, logits, num_classes)
            self.loss = self.create_loss(
                batch_loss, features['%s_loss_multiplier' % problem_name])
            # If a batch does not contain input instances from the current problem, the loss multiplier will be empty
            # and loss will be NaN. Replacing NaN with 0 fixes the problem.
            self.loss = tf.where(tf.math.is_nan(self.loss),
                                 tf.zeros_like(self.loss), self.loss)
            return self.loss
        elif mode == tf.estimator.ModeKeys.EVAL:
            labels = features['%s_label_ids' % problem_name]
            batch_loss = self.create_batch_loss(labels, logits, num_classes)
            # average the per-example losses over the evaluation batch
            loss = tf.reduce_mean(batch_loss)

            return self.eval_metric_fn(
                features, logits, loss, problem_name, pad_labels_to_logits=False)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            prob = tf.nn.softmax(logits)
            self.prob = tf.identity(prob, name='%s_predict' % scope_name)
            return self.prob


class MaskLM(TopLayer):
    '''Top model for the masked language model.
    It is a dense layer that takes the body output features as input.
    Most of the logic is adapted from the original BERT code.
    '''

    def __call__(self, features, hidden_feature, mode, problem_name):
        """Get loss and log probs for the masked LM.

        DO NOT CHANGE THE VARIABLE SCOPE.
        """
        seq_hidden_feature = hidden_feature['seq']
        positions = features['masked_lm_positions']
        input_tensor = gather_indexes(seq_hidden_feature, positions)
        output_weights = hidden_feature['embed_table']
        label_ids = features['masked_lm_ids']
        label_weights = features['masked_lm_weights']

        with tf.variable_scope("cls/predictions"):
            # We apply one more non-linear transformation before the output layer.
            # This matrix is not used after pre-training.
            with tf.variable_scope("transform"):
                input_tensor = tf.layers.dense(
                    input_tensor,
                    units=self.params.mask_lm_hidden_size,
                    activation=modeling.get_activation(
                        self.params.mask_lm_hidden_act),
                    kernel_initializer=modeling.create_initializer(
                        self.params.mask_lm_initializer_range))
                input_tensor = modeling.layer_norm(input_tensor)

            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            output_bias = tf.get_variable(
                "output_bias",
                shape=[self.params.vocab_size],
                initializer=tf.zeros_initializer())

            logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)
            self.logits = logits
            log_probs = tf.nn.log_softmax(logits, axis=-1)

            if mode == tf.estimator.ModeKeys.PREDICT:
                self.prob = log_probs
                return self.prob

            else:

                label_ids = tf.reshape(label_ids, [-1])
                label_weights = tf.reshape(label_weights, [-1])

                one_hot_labels = tf.one_hot(
                    label_ids, depth=self.params.vocab_size, dtype=tf.float32)

                # The `positions` tensor might be zero-padded (if the sequence is too
                # short to have the maximum number of predictions). The `label_weights`
                # tensor has a value of 1.0 for every real prediction and 0.0 for the
                # padding predictions.
                per_example_loss = -tf.reduce_sum(
                    log_probs * one_hot_labels, axis=[-1])
                label_weights = tf.cast(label_weights, tf.float32)
                numerator = tf.reduce_sum(label_weights * per_example_loss)
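                # The small epsilon keeps the loss finite (it becomes 0 rather
                # than NaN) when a batch happens to contain no masked positions.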
                denominator = tf.reduce_sum(label_weights) + 1e-5
                loss = numerator / denominator

                if mode == tf.estimator.ModeKeys.TRAIN:
                    self.loss = loss
                    return self.loss

                else:
                    def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                                  masked_lm_weights):
                        """Computes the loss and accuracy of the model."""
                        masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                                         [-1, masked_lm_log_probs.shape[-1]])
                        masked_lm_predictions = tf.argmax(
                            masked_lm_log_probs, axis=-1, output_type=tf.int32)
                        masked_lm_example_loss = tf.reshape(
                            masked_lm_example_loss, [-1])
                        masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                        masked_lm_weights = tf.reshape(
                            masked_lm_weights, [-1])
                        masked_lm_accuracy = tf.metrics.accuracy(
                            labels=masked_lm_ids,
                            predictions=masked_lm_predictions,
                            weights=masked_lm_weights)
                        masked_lm_mean_loss = tf.metrics.mean(
                            values=masked_lm_example_loss, weights=masked_lm_weights)

                        return {
                            "masked_lm_accuracy": masked_lm_accuracy,
                            "masked_lm_loss": masked_lm_mean_loss,
                        }
                    eval_metrics = (metric_fn(
                        per_example_loss, log_probs, label_ids,
                        label_weights), loss)

                    self.eval_metrics = eval_metrics
                    return self.eval_metrics


class PreTrain(TopLayer):
    '''Top model for pretraining.
    It combines MaskLM with Classification (next sentence prediction).
    '''

    def __call__(self, features, hidden_feature, mode, problem_name):
        mask_lm_top = MaskLM(self.params)
        self.params.share_top['next_sentence'] = 'next_sentence'
        mask_lm_top_result = mask_lm_top(
            features, hidden_feature, mode, problem_name)
        with tf.variable_scope('next_sentence', reuse=tf.AUTO_REUSE):
            cls = Classification(self.params)
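            # Every pretraining instance carries a next-sentence label, so the
            # loss multiplier for this problem is simply a constant 1.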
            features['next_sentence_loss_multiplier'] = 1
            next_sentence_top_result = cls(
                features, hidden_feature, mode, 'next_sentence')
        if mode == tf.estimator.ModeKeys.TRAIN:
            self.loss = mask_lm_top_result+next_sentence_top_result
            return self.loss
        elif mode == tf.estimator.ModeKeys.EVAL:
            mask_lm_eval_dict, mask_lm_loss = mask_lm_top_result
            next_sentence_eval_dict, next_sentence_loss = next_sentence_top_result
            mask_lm_eval_dict.update(next_sentence_eval_dict)
            self.eval_metrics = (mask_lm_eval_dict,
                                 mask_lm_loss+next_sentence_loss)
            return self.eval_metrics
        elif mode == tf.estimator.ModeKeys.PREDICT:
            self.prob = mask_lm_top_result
            return self.prob


class Seq2Seq(TopLayer):
    '''Top model for seq2seq problems.
    This is essentially the decoder of an encoder-decoder framework.
    It uses a Transformer decoder architecture with beam search support.
    '''

    def _get_symbol_to_logit_fn(self,
                                max_seq_len,
                                embedding_table,
                                token_type_ids,
                                decoder,
                                num_classes,
                                encoder_output,
                                input_mask,
                                params):
        decoder_self_attention_mask = decoder.get_decoder_self_attention_mask(
            max_seq_len)

        batch_size = tf.shape(encoder_output)[0]
        max_seq_len = tf.shape(encoder_output)[1]
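        # The expand/tile/reshape below repeats the encoder output once per
        # beam so that every beam hypothesis attends to the same encoder states.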

        encoder_output = tf.expand_dims(encoder_output, axis=1)
        tile_dims = [1] * encoder_output.shape.ndims
        tile_dims[1] = params.beam_size

        encoder_output = tf.tile(encoder_output, tile_dims)
        encoder_output = tf.reshape(encoder_output,
                                    [-1, max_seq_len, params.bert_config.hidden_size])

        def symbols_to_logits_fn(ids, i, cache):
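            """Incremental decoding step used by tensor2tensor's beam search.

            ids: [batch * beam, decoded_len] token ids decoded so far.
            i: current decode position.
            cache: per-layer key/value tensors reused across steps.
            Returns the logits for the next position together with the
            (possibly updated) cache.
            """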

            decoder_inputs = tf.nn.embedding_lookup(
                embedding_table, ids)

            decoder_inputs = modeling.embedding_postprocessor(
                input_tensor=decoder_inputs,
                use_token_type=False,
                token_type_ids=token_type_ids,
                token_type_vocab_size=params.bert_config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=params.bert_config.initializer_range,
                max_position_embeddings=params.bert_config.max_position_embeddings,
                dropout_prob=self.params.bert_config.hidden_dropout_prob)
            final_decoder_input = decoder_inputs[:, -1:, :]
            # final_decoder_input = decoder_inputs
            self_attention_mask = decoder_self_attention_mask[:, i:i+1, :i+1]

            logits = decoder.decode(
                decoder_inputs=final_decoder_input,
                encoder_output=encoder_output,
                input_mask=input_mask,
                decoder_self_attention_mask=self_attention_mask,
                cache=cache,
                num_classes=num_classes,
                do_return_all_layers=False)

            return logits, cache
        return symbols_to_logits_fn

    def beam_search_decode(self, features, hidden_feature, mode, problem_name):
        # prepare inputs to attention
        key = 'ori_seq' if self.params.label_transfer else 'seq'
        encoder_outputs = hidden_feature[key]
        max_seq_len = self.params.max_seq_len
        embedding_table = hidden_feature['embed_table']
        token_type_ids = features['segment_ids']
        num_classes = self.params.num_classes[problem_name]
        batch_size = modeling.get_shape_list(
            encoder_outputs, expected_rank=3)[0]
        hidden_size = self.params.bert_config.hidden_size

        if self.params.problem_type[problem_name] == 'seq2seq_text':
            embedding_table = hidden_feature['embed_table']
        else:
            embedding_table = tf.get_variable(
                'tag_embed_table',
                shape=[num_classes, hidden_size])

        symbol_to_logit_fn = self._get_symbol_to_logit_fn(
            max_seq_len=max_seq_len,
            embedding_table=embedding_table,
            token_type_ids=token_type_ids,
            decoder=self.decoder,
            num_classes=num_classes,
            encoder_output=encoder_outputs,
            input_mask=features['input_mask'],
            params=self.params
        )

        # create cache for fast decode
        cache = {
            str(layer): {
                "key_layer": tf.zeros([batch_size, 0, hidden_size]),
                "value_layer": tf.zeros([batch_size, 0, hidden_size]),
            } for layer in range(self.params.decoder_num_hidden_layers)}
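        # Each decoder layer starts with empty key/value tensors; the decoder
        # presumably appends to them at each step so that earlier positions do
        # not need to be recomputed.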
        # cache['encoder_outputs'] = encoder_outputs
        # cache['encoder_decoder_attention_mask'] = features['input_mask']
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        decode_ids, _, _ = beam_search.beam_search(
            symbols_to_logits_fn=symbol_to_logit_fn,
            initial_ids=initial_ids,
            states=cache,
            vocab_size=self.params.num_classes[problem_name],
            beam_size=self.params.beam_size,
            alpha=self.params.beam_search_alpha,
            decode_length=self.params.decode_max_seq_len,
            eos_id=self.params.eos_id[problem_name])
        # Take the best beam for each batch element and drop the initial decode id
        top_decoded_ids = decode_ids[:, 0, 1:]
        self.prob = top_decoded_ids
        return self.prob

    def __call__(self, features, hidden_feature, mode, problem_name):
        self.decoder = TransformerDecoder(self.params)
        scope_name = self.params.share_top[problem_name]

        if mode != tf.estimator.ModeKeys.PREDICT:
            labels = features['%s_label_ids' % problem_name]

            logits = self.decoder.train_eval(
                features, hidden_feature, mode, problem_name)

            with tf.name_scope("shift_targets"):
                # Shift labels one position to the left (drop the first token and
                # pad a zero at the end) so that the logits at step i are scored
                # against the token at step i + 1, e.g. [A, B, C] -> [B, C, 0].
                shift_labels = tf.pad(
                    labels, [[0, 0], [0, 1]])[:, 1:]
            batch_loss = tf.losses.sparse_softmax_cross_entropy(
                shift_labels, logits)
            loss = self.create_loss(
                batch_loss, features['%s_loss_multiplier' % problem_name])
            # If a batch does not contain input instances from the current problem, the loss multiplier will be empty
            # and loss will be NaN. Replacing NaN with 0 fixes the problem.
            loss = tf.where(tf.math.is_nan(loss), tf.zeros_like(loss), loss)
            self.loss = loss

            if mode == tf.estimator.ModeKeys.TRAIN:
                return self.loss
            else:
                return self.eval_metric_fn(
                    features, logits, loss, problem_name, features['%s_mask' % problem_name])

        else:
            self.pred = tf.identity(self.beam_search_decode(
                features, hidden_feature, mode, problem_name),
                name='%s_predict' % scope_name)
            return self.pred


class MultiLabelClassification(TopLayer):
    '''Top model for multi-label classification.
    It is a dense layer that takes the pooled body output features as input, with the
    following optional component:

    label_smoothing: Soft label smoothing.
    '''

    def create_batch_loss(self, labels, logits, num_classes):
        labels = tf.cast(labels, tf.float32)
        batch_label_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=labels, logits=logits)
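        # One independent binary cross-entropy term per class; e.g. labels
        # [1, 0, 1] with logits [2.0, -1.0, 0.5] give three sigmoid
        # cross-entropy values, which are summed into one per-example loss below.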

        batch_loss = tf.reduce_sum(batch_label_loss, axis=1)

        if self.params.uncertain_weight_loss:
            batch_loss = self.uncertainty_weighted_loss(batch_loss)
        return batch_loss

    def __call__(self, features, hidden_feature, mode, problem_name, mask=None):
        hidden_feature = hidden_feature['pooled']
        scope_name = self.params.share_top[problem_name]
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_feature = tf.nn.dropout(
                hidden_feature,
                keep_prob=self.params.dropout_keep_prob)

        if mask is None:
            num_classes = self.params.num_classes[problem_name]
        else:
            num_classes = mask.shape[0]
        # make hidden model
        hidden_feature = self.make_hidden_model(
            features, hidden_feature, mode, 'pooled')
        logits = dense_layer(num_classes, hidden_feature, mode, 1.0, None)
        self.logits = logits
        if mask is not None:
            logits = logits*mask
        if mode == tf.estimator.ModeKeys.TRAIN:
            labels = features['%s_label_ids' % problem_name]
            batch_loss = self.create_batch_loss(labels, logits, num_classes)
            self.loss = self.create_loss(
                batch_loss, features['%s_loss_multiplier' % problem_name])
            # If a batch does not contain input instances from the current problem, the loss multiplier will be empty
            # and loss will be NaN. Replacing NaN with 0 fixes the problem.
            self.loss = tf.where(tf.math.is_nan(self.loss),
                                 tf.zeros_like(self.loss), self.loss)
            return self.loss
        elif mode == tf.estimator.ModeKeys.EVAL:
            labels = features['%s_label_ids' % problem_name]
            batch_loss = self.create_batch_loss(labels, logits, num_classes)
            # average the per-example losses over the evaluation batch
            loss = tf.reduce_mean(batch_loss)
            prob = tf.nn.sigmoid(logits)
            prob = tf.round(prob)
            prob = tf.expand_dims(prob, -1)
            return self.eval_metric_fn(
                features, prob, loss, problem_name)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            prob = tf.nn.sigmoid(logits)
            self.prob = tf.identity(prob, name='%s_predict' % scope_name)
            return self.prob