import tensorflow as tf
import math
from tensor2tensor.utils import beam_search

from . import modeling


class TransformerDecoder(object):
    def __init__(self, params):
        self.params = params

    def get_decoder_self_attention_mask(self, length):
        """Calculate bias for decoder that maintains model's autoregressive property.
        Creates a tensor that masks out locations that correspond to illegal
        connections, so prediction at position i cannot draw information from future
        positions.
        Args:
            length: int length of sequences in batch.
        Returns:
            float tensor of shape [1, 1, length, length]
        """
        with tf.name_scope("decoder_self_attention_mask"):
            valid_locs = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
            valid_locs = tf.reshape(valid_locs, [1, length, length])
        return valid_locs

    def decode(
            self,
            decoder_inputs,
            encoder_output,
            input_mask,
            decoder_self_attention_mask,
            cache,
            num_classes,
            do_return_all_layers,
            enc_dec_attention_mask=None,
            add_self_attention=True,
            add_enc_dec_attention=True):
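        """Run the Transformer decoder stack over `decoder_inputs`.

        Args:
            decoder_inputs: float Tensor of shape [batch_size, decode_seq_length,
                hidden_size], the position-encoded decoder input embeddings.
            encoder_output: float Tensor of shape [batch_size, encode_seq_length,
                hidden_size] produced by the encoder.
            input_mask: int Tensor of shape [batch_size, encode_seq_length], used
                to derive the encoder-decoder attention mask.
            decoder_self_attention_mask: mask preserving the autoregressive
                property, broadcastable to [batch_size, decode_seq_length,
                decode_seq_length].
            cache: dict for fast decoding that maps str(layer_idx) to a dict with
                "key_layer" and "value_layer" tensors, or None during training.
            num_classes: int. If truthy, a final dense projection to `num_classes`
                logits is applied.
            do_return_all_layers: bool. If True, return the output of every layer.
            enc_dec_attention_mask: optional precomputed encoder-decoder attention
                mask; if None, it is built from `input_mask`.
            add_self_attention: bool, whether to run decoder self-attention.
            add_enc_dec_attention: bool, whether to run encoder-decoder attention.

        Returns:
            Logits of shape [batch_size, decode_seq_length, num_classes], the final
            hidden states if `num_classes` is falsy, or a list of per-layer outputs
            when `do_return_all_layers` is True.
        """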
        input_tensor = decoder_inputs
        num_hidden_layers = self.params.decoder_num_hidden_layers
        hidden_size = self.params.bert_config.hidden_size
        num_attention_heads = self.params.bert_config.num_attention_heads
        initializer_range = self.params.bert_config.initializer_range
        attention_probs_dropout_prob = self.params.bert_config.attention_probs_dropout_prob

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))

        attention_head_size = int(hidden_size / num_attention_heads)
        encode_shape = modeling.get_shape_list(
            encoder_output, expected_rank=3)
        batch_size = encode_shape[0]
        encode_seq_length = encode_shape[1]
        input_width = encode_shape[2]

        input_shape = modeling.get_shape_list(input_tensor, expected_rank=3)
        decode_seq_length = input_shape[1]

        # create encoder-decoder attention mask
        attention_mask_shape = modeling.get_shape_list(
            input_mask, expected_rank=2)[1]

        # During fast decoding the decoder batch is batch_size * beam_size, so
        # broadcast the encoder input_mask to that batch size.
        if enc_dec_attention_mask is None:
            input_batch_size = modeling.get_shape_list(
                decoder_inputs, expected_rank=3)[0]
            input_mask = tf.broadcast_to(
                input_mask, [input_batch_size, attention_mask_shape])
            attention_mask = modeling.create_attention_mask_from_input_mask(
                decoder_inputs, input_mask
            )
        else:
            attention_mask = enc_dec_attention_mask

        # The Transformer adds residual (sum) connections on every layer, so the
        # input width needs to match the hidden size.
        if input_width != hidden_size:
            raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                             (input_width, hidden_size))

        prev_output = modeling.reshape_to_matrix(input_tensor)

        all_layer_outputs = []
        for layer_idx in range(num_hidden_layers):
            with tf.variable_scope("layer_%d" % layer_idx):
                layer_input = prev_output

                if cache is not None:
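                    # `cache` maps str(layer_idx) to a dict holding the
                    # "key_layer" and "value_layer" tensors that
                    # attention_layer_with_cache reads from and appends to at
                    # every decoding step.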
                    layer_cache = cache[str(layer_idx)]
                    if layer_idx == 0:
                        layer_input = tf.expand_dims(
                            layer_input, axis=1)
                        # update batch_size to batch_size*beam_size
                        batch_size = modeling.get_shape_list(
                            layer_input, expected_rank=3)[0]
                else:
                    layer_cache = None

                with tf.variable_scope("attention"):
                    attention_heads = []
                    if add_self_attention:
                        with tf.variable_scope("self"):
                            attention_head = attention_layer_with_cache(
                                from_tensor=layer_input,
                                to_tensor=layer_input,
                                attention_mask=decoder_self_attention_mask,
                                num_attention_heads=num_attention_heads,
                                size_per_head=attention_head_size,
                                attention_probs_dropout_prob=attention_probs_dropout_prob,
                                initializer_range=initializer_range,
                                do_return_2d_tensor=False,
                                batch_size=batch_size,
                                from_seq_length=decode_seq_length,
                                to_seq_length=decode_seq_length,
                                cache=layer_cache)
                            attention_heads.append(attention_head)

                            self_attention_output = None
                            if len(attention_heads) == 1:
                                self_attention_output = attention_heads[0]
                            else:
                                # In the case where we have other sequences, we just concatenate
                                # them to the self-attention head before the projection.
                                self_attention_output = tf.concat(
                                    attention_heads, axis=-1)
                        if cache is not None:
                            self_attention_output = tf.reshape(
                                self_attention_output, [batch_size, -1, hidden_size])
                    else:
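                        # Self-attention disabled: reshape the layer input and
                        # pass it straight to the next sub-layer.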
                        self_attention_output = tf.reshape(
                            layer_input, [batch_size, -1, hidden_size])

                    if add_enc_dec_attention:
                        with tf.variable_scope('enc_dec_attention'):
                            attention_heads = []
                            attention_head = attention_layer_with_cache(
                                from_tensor=self_attention_output,
                                to_tensor=encoder_output,
                                attention_mask=attention_mask,
                                num_attention_heads=num_attention_heads,
                                size_per_head=attention_head_size,
                                attention_probs_dropout_prob=attention_probs_dropout_prob,
                                initializer_range=initializer_range,
                                do_return_2d_tensor=True,
                                batch_size=batch_size,
                                from_seq_length=decode_seq_length,
                                to_seq_length=encode_seq_length,
                                cache=None)
                            attention_heads.append(attention_head)

                            attention_output = None
                            if len(attention_heads) == 1:
                                attention_output = attention_heads[0]
                            else:
                                # In the case where we have other sequences, we just concatenate
                                # them to the self-attention head before the projection.
                                attention_output = tf.concat(
                                    attention_heads, axis=-1)
                        if cache is not None:
                            attention_output = tf.reshape(
                                attention_output, [batch_size, -1, hidden_size])
                    else:
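                        # Encoder-decoder attention disabled: use the
                        # self-attention output directly (flattened to 2-D).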
                        attention_output = tf.reshape(
                            self_attention_output, [-1, hidden_size])

                    # Run a linear projection of `hidden_size` then add a residual
                    # with `layer_input`.
                    with tf.variable_scope("output"):
                        attention_output = tf.layers.dense(
                            attention_output,
                            hidden_size,
                            kernel_initializer=modeling.create_initializer(
                                initializer_range))
                        attention_output = modeling.dropout(
                            attention_output,
                            self.params.bert_config.hidden_dropout_prob)
                        attention_output = modeling.layer_norm(
                            attention_output + layer_input)

                # The activation is only applied to the "intermediate" hidden layer.
                with tf.variable_scope("intermediate"):
                    intermediate_output = tf.layers.dense(
                        attention_output,
                        self.params.bert_config.intermediate_size,
                        activation=modeling.gelu,
                        kernel_initializer=modeling.create_initializer(
                            initializer_range))

                # Down-project back to `hidden_size` then add the residual.
                with tf.variable_scope("output"):
                    layer_output = tf.layers.dense(
                        intermediate_output,
                        hidden_size,
                        kernel_initializer=modeling.create_initializer(
                            initializer_range))
                    layer_output = modeling.dropout(
                        layer_output,
                        self.params.bert_config.hidden_dropout_prob)
                    layer_output = modeling.layer_norm(
                        layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

        if do_return_all_layers:
            final_outputs = []
            for layer_output in all_layer_outputs:
                final_output = modeling.reshape_from_matrix(
                    layer_output, input_shape)
                final_outputs.append(final_output)
            return final_outputs
        else:
            if cache is None:
                final_output = modeling.reshape_from_matrix(
                    prev_output, input_shape)
            else:
                final_output = prev_output

        if num_classes:
            dense_layer = tf.layers.Dense(
                num_classes,
                activation=None,
                kernel_initializer=tf.orthogonal_initializer()
            )
            logits = dense_layer(final_output)
        else:
            logits = final_output
        return logits

    def train_eval(self, features, hidden_feature, mode, problem_name):
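        """Teacher-forced decoding for training and evaluation.

        Looks up decoder input embeddings for the target label ids (from the shared
        `embed_table` for seq2seq_text problems, otherwise from a per-problem tag
        embedding table), adds position embeddings, combines the label mask with
        the causal self-attention mask, and runs `decode` without a cache.

        Returns:
            Logits of shape [batch_size, seq_length, num_classes].
        """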

        # prepare inputs to attention
        key = 'ori_seq' if self.params.label_transfer else 'seq'
        encoder_output = hidden_feature[key]

        label_ids = features['%s_label_ids' % problem_name]
        input_mask = features['input_mask']
        num_classes = self.params.num_classes[problem_name]

        if self.params.problem_type[problem_name] == 'seq2seq_text':
            embed_table = hidden_feature['embed_table']
        else:
            embed_table = tf.get_variable(
                'tag_embed_table', shape=[
                    num_classes, self.params.mask_lm_hidden_size],
                initializer=tf.orthogonal_initializer())
        decoder_inputs = tf.nn.embedding_lookup(
            embed_table, label_ids)

        # with tf.name_scope("shift_targets"):
        #     # Shift targets to the right, and remove the last element
        #     decoder_inputs = tf.pad(
        #         decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]

        decoder_inputs = modeling.embedding_postprocessor(
            input_tensor=decoder_inputs,
            use_token_type=False,
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=self.params.bert_config.initializer_range,
            max_position_embeddings=self.params.bert_config.max_position_embeddings,
            dropout_prob=self.params.bert_config.hidden_dropout_prob)

        # attention_mask = modeling.create_attention_mask_from_input_mask(
        #     label_ids, input_mask)
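        # Combine the padding mask [batch, 1, seq_len] with the causal mask
        # [1, seq_len, seq_len]; broadcasting yields a [batch, seq_len, seq_len]
        # decoder self-attention mask.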
        label_mask = tf.expand_dims(
            tf.cast(features['%s_mask' % problem_name], tf.float32), axis=1)
        decoder_self_attention_mask = label_mask * self.get_decoder_self_attention_mask(
            self.params.decode_max_seq_len)

        decode_output = self.decode(
            decoder_inputs=decoder_inputs,
            encoder_output=encoder_output,
            input_mask=input_mask,
            decoder_self_attention_mask=decoder_self_attention_mask,
            cache=None,
            num_classes=num_classes,
            do_return_all_layers=False
        )
        return decode_output


def attention_layer_with_cache(from_tensor,
                               to_tensor,
                               attention_mask=None,
                               num_attention_heads=1,
                               size_per_head=512,
                               query_act=None,
                               key_act=None,
                               value_act=None,
                               attention_probs_dropout_prob=0.0,
                               initializer_range=0.02,
                               do_return_2d_tensor=False,
                               batch_size=None,
                               from_seq_length=None,
                               to_seq_length=None,
                               decoder_self_attention_mask=None,
                               cache=None):
    """
    This is a modification of the attention layer from BERT that supports
    fast decoding with a key/value cache.

    Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention
    Is All You Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
      from_tensor: float Tensor of shape [batch_size, from_seq_length,
        from_width].
      to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
      attention_mask: (optional) int32 Tensor of shape [batch_size,
        from_seq_length, to_seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions in
        the mask that are 0, and will be unchanged for positions that are 1.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each attention head.
      query_act: (optional) Activation function for the query transform.
      key_act: (optional) Activation function for the key transform.
      value_act: (optional) Activation function for the value transform.
      attention_probs_dropout_prob: (optional) float. Dropout probability of the
        attention probabilities.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
        * from_seq_length, num_attention_heads * size_per_head]. If False, the
        output will be of shape [batch_size, from_seq_length, num_attention_heads
        * size_per_head].
      batch_size: (Optional) int. If the input is 2D, this might be the batch size
        of the 3D version of the `from_tensor` and `to_tensor`.
      from_seq_length: (Optional) If the input is 2D, this might be the seq length
        of the 3D version of the `from_tensor`.
      to_seq_length: (Optional) If the input is 2D, this might be the seq length
        of the 3D version of the `to_tensor`.

    Returns:
      float Tensor of shape [batch_size, from_seq_length,
        num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
        true, this will be of shape [batch_size * from_seq_length,
        num_attention_heads * size_per_head]).

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])

        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = modeling.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = modeling.get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    from_tensor_2d = modeling.reshape_to_matrix(from_tensor)
    to_tensor_2d = modeling.reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        kernel_initializer=modeling.create_initializer(initializer_range))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        kernel_initializer=modeling.create_initializer(initializer_range))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        kernel_initializer=modeling.create_initializer(initializer_range))

    if cache is not None:
        n_time_h = key_layer.get_shape()[1]

        key_layer_to_cache = tf.reshape(
            key_layer, [batch_size, -1, n_time_h])
        value_layer_to_cache = tf.reshape(
            value_layer, [batch_size, -1, n_time_h])
        # Combine cached keys and values with new keys and values.
        key_layer_from_cache = tf.concat(
            [cache["key_layer"], key_layer_to_cache], axis=1)
        value_layer_from_cache = tf.concat(
            [cache["value_layer"], value_layer_to_cache], axis=1)

        # update seq length
        # from_seq_length = key_layer_from_cache.get_shape()[1]
        from_seq_length = modeling.get_shape_list(
            key_layer_from_cache, expected_rank=[3])[1]
        to_seq_length = modeling.get_shape_list(
            value_layer_from_cache, expected_rank=[3])[1]

        # Update cache
        cache["key_layer"] = key_layer_from_cache
        cache["value_layer"] = value_layer_from_cache

        key_layer = tf.reshape(key_layer_from_cache, [-1, n_time_h])
        value_layer = tf.reshape(value_layer_from_cache, [-1, n_time_h])
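        # key_layer/value_layer now cover every position decoded so far, flattened
        # back to [batch_size * to_seq_length, num_attention_heads * size_per_head].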

    # `query_layer` = [B, N, F, H]
    # In decoder self-attention with a cache, the query seq_length is always 1
    # (only the newest position is fed in at each decoding step).
    if cache is not None:
        query_layer = transpose_for_scores(
            query_layer, batch_size,
            num_attention_heads, 1,
            size_per_head)
    else:
        query_layer = transpose_for_scores(
            query_layer, batch_size,
            num_attention_heads, from_seq_length,
            size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(
        key_layer, batch_size, num_attention_heads,
        to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = modeling.dropout(
        attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
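

# Illustrative sketch, assuming `batch_size`, `beam_size`, `hidden_size` and
# `params` are defined by the surrounding decoding loop: an empty per-layer
# key/value cache that attention_layer_with_cache can append to on the first
# decoding step.
#
# cache = {
#     str(layer): {
#         "key_layer": tf.zeros([batch_size * beam_size, 0, hidden_size]),
#         "value_layer": tf.zeros([batch_size * beam_size, 0, hidden_size]),
#     }
#     for layer in range(params.decoder_num_hidden_layers)
# }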