import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention


class Linear(keras.layers.Layer):
    """Fully connected projection followed by a ReLU non-linearity.

    Args:
        units: output size of the dense projection.
        use_bias: whether the dense projection has a bias term.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, use_bias, **kwargs):
        super().__init__(**kwargs)
        # Attribute names (and the 'linear_layer' sub-layer name) determine
        # tracked variable names, so they are kept unchanged.
        self.linear_layer = keras.layers.Dense(
            units, use_bias=use_bias, name='linear_layer')
        self.activation = keras.layers.ReLU()

    def call(self, x):
        """Return ``relu(dense(x))``.

        shapes:
            x: B x T x C
        """
        hidden = self.linear_layer(x)
        return self.activation(hidden)


class LinearBN(keras.layers.Layer):
    """Dense projection, then batch normalization, then ReLU.

    Args:
        units: output size of the dense projection.
        use_bias: whether the dense projection has a bias term.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, use_bias, **kwargs):
        super().__init__(**kwargs)
        # Creation order and names are kept identical so tracked variables
        # (and auto-generated sub-layer names) do not change.
        self.linear_layer = keras.layers.Dense(
            units, use_bias=use_bias, name='linear_layer')
        self.batch_normalization = keras.layers.BatchNormalization(
            axis=-1,
            momentum=0.90,
            epsilon=1e-5,
            name='batch_normalization')
        self.activation = keras.layers.ReLU()

    def call(self, x, training=None):
        """Return ``relu(batch_norm(dense(x)))``.

        ``training`` is forwarded only to the batch-normalization layer.

        shapes:
            x: B x T x C
        """
        projected = self.linear_layer(x)
        normalized = self.batch_normalization(projected, training=training)
        return self.activation(normalized)


class Prenet(keras.layers.Layer):
    """Stack of ``Linear`` / ``LinearBN`` layers with optional dropout.

    Args:
        prenet_type: "bn" to build ``LinearBN`` layers, "original" to build
            ``Linear`` layers.
        prenet_dropout: when truthy, dropout is applied after every layer.
        units: iterable of output sizes, one entry per layer.
        bias: whether each dense projection has a bias term.
        dropout_rate: dropout rate used when ``prenet_dropout`` is set.
            Defaults to 0.5, the value that was previously hard-coded.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Raises:
        RuntimeError: if ``prenet_type`` is not "bn" or "original".
    """
    def __init__(self,
                 prenet_type,
                 prenet_dropout,
                 units,
                 bias,
                 *,
                 dropout_rate=0.5,
                 **kwargs):
        super(Prenet, self).__init__(**kwargs)
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        if prenet_type == "bn":
            layer_cls = LinearBN
        elif prenet_type == "original":
            layer_cls = Linear
        else:
            raise RuntimeError(' [!] Unknown prenet type.')
        # Sub-layer names ('linear_layer_{idx}') are kept identical so
        # variable names do not change.
        self.linear_layers = [
            layer_cls(unit, use_bias=bias, name=f'linear_layer_{idx}')
            for idx, unit in enumerate(units)
        ]
        if prenet_dropout:
            self.dropout = keras.layers.Dropout(rate=dropout_rate)

    def call(self, x, training=None):
        """Run ``x`` through every layer, applying dropout after each one.

        shapes:
            x: B x T x C
        """
        for linear in self.linear_layers:
            # Forward `training` explicitly so LinearBN's batch norm sees it
            # even if this method is invoked outside a Keras call context;
            # Keras drops the kwarg for layers (Linear) whose call lacks it.
            x = linear(x, training=training)
            if self.prenet_dropout:
                x = self.dropout(x, training=training)
        return x


def _sigmoid_norm(score):
    """Squash ``score`` with a sigmoid and rescale so it sums to 1 along axis 1.

    Used as an alternative to softmax for normalizing attention scores.
    """
    activated = tf.nn.sigmoid(score)
    normalizer = tf.reduce_sum(activated, axis=1, keepdims=True)
    return activated / normalizer


class Attention(keras.layers.Layer):
    """Additive attention over a cached values sequence, with optional location features.

    Unnormalized scores are
    ``v(tanh(inputs_layer(values) + query_layer(query) [+ location features]))``.
    They are normalized over the values' time axis with ``norm``.
    ``process_values`` must be called with the values before ``call``.

    Args:
        attn_dim: size of the projected query/value/location features.
        use_loc_attn: add location features computed from the previous and
            cumulative attention weights.
        loc_attn_n_filters: Conv1D filter count for the location features.
        loc_attn_kernel_size: Conv1D kernel size for the location features.
        use_windowing, use_forward_attn, use_trans_agent,
        use_forward_attn_mask: stored, but not read by any method of this class.
        norm: 'softmax' or 'sigmoid' score normalization.
        **kwargs: forwarded to ``keras.layers.Layer``.

    Raises:
        ValueError: if ``norm`` is not 'softmax' or 'sigmoid'.

    TODO: implement forward_attention
    TODO: implement attention windowing
    TODO: apply score masking inside ``call``
    """
    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                 use_trans_agent, use_forward_attn_mask, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.use_loc_attn = use_loc_attn
        self.loc_attn_n_filters = loc_attn_n_filters
        self.loc_attn_kernel_size = loc_attn_kernel_size
        self.use_windowing = use_windowing
        self.norm = norm
        self.use_forward_attn = use_forward_attn
        self.use_trans_agent = use_trans_agent
        self.use_forward_attn_mask = use_forward_attn_mask
        # NOTE(review): sub-layer names (including the asymmetric
        # f'{self.name}/...' prefix on inputs_layer) determine variable names.
        # Presumably they match an existing weight-naming scheme, so do not
        # change them without checking weight loading.
        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
        self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
        self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
        if use_loc_attn:
            self.location_conv1d = keras.layers.Conv1D(
                filters=loc_attn_n_filters,
                kernel_size=loc_attn_kernel_size,
                padding='same',
                use_bias=False,
                name='location_layer/location_conv1d')
            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
        if norm == 'softmax':
            self.norm_func = tf.nn.softmax
        elif norm == 'sigmoid':
            self.norm_func = _sigmoid_norm
        else:
            raise ValueError("Unknown value for attention norm type")

    def init_states(self, batch_size, value_length):
        """Return the initial attention states.

        Returns:
            ``(attention_cum, attention_old)``, both zeros of shape
            ``[batch_size, value_length]``, when location attention is used;
            otherwise an empty tuple.
        """
        states = ()
        if self.use_loc_attn:
            attention_cum = tf.zeros([batch_size, value_length])
            attention_old = tf.zeros([batch_size, value_length])
            states = (attention_cum, attention_old)
        return states

    def process_values(self, values):
        """ cache values for decoder iterations

        Stores both the raw ``values`` (used for the context vector in
        ``call``) and their ``inputs_layer`` projection (used for scoring).
        ``call`` multiplies ``values`` with a B x 1 x T_in weight matrix, so it
        must be B x T_in x D.
        """
        #pylint: disable=attribute-defined-outside-init
        self.processed_values = self.inputs_layer(values)
        self.values = values

    def get_loc_attn(self, query, states):
        """ compute location attention, query layer and
        unnorm. attention weights

        ``states`` is ``(attention_cum, attention_old)``; both are stacked on a
        new last axis (old first) and passed through the location Conv1D and
        Dense layers.

        Returns:
            (score, processed_query): score is B x T_in after the squeeze.
        """
        attention_cum, attention_old = states
        attn_cat = tf.stack([attention_old, attention_cum], axis=2)

        processed_query = self.query_layer(tf.expand_dims(query, 1))
        processed_attn = self.location_dense(self.location_conv1d(attn_cat))
        score = self.v(
            tf.nn.tanh(self.processed_values + processed_query +
                       processed_attn))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def get_attn(self, query):
        """ compute query layer and unnormalized attention weights

        Returns:
            (score, processed_query): score is B x T_in after the squeeze.
        """
        processed_query = self.query_layer(tf.expand_dims(query, 1))
        score = self.v(tf.nn.tanh(self.processed_values + processed_query))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def apply_score_masking(self, score, mask):  #pylint: disable=no-self-use
        """ ignore sequence paddings

        Args:
            score: unnormalized scores as returned by ``get_attn`` /
                ``get_loc_attn`` (B x T_in).
            mask: tensor broadcastable to ``score``; truthy at valid
                (non-padded) positions.

        Returns:
            ``score`` with ``1e9`` subtracted at padded positions.
        """
        # Use the mask as-is: score is already B x T_in (squeezed), so adding
        # a trailing axis would broadcast (B, T) against (B, T, 1) incorrectly.
        padding_mask = math_ops.logical_not(tf.cast(mask, tf.bool))
        # Bias so padding positions do not contribute to attention distribution.
        # tf.where keeps score's dtype and avoids inf * 0 = NaN that
        # `score - 1e9 * cast(mask)` would produce for low-precision scores.
        return tf.where(padding_mask, score - 1.e9, score)

    def call(self, query, states):
        """Attend over the cached values with ``query``.

        shapes:
            query: B x D

        Returns:
            context_vector: B x D_values, attention-weighted sum of the values.
            attn_weights: B x T_in, normalized attention weights.
            states: ``(attention_cum + attn_weights, attn_weights)`` when
                location attention is used, otherwise ``()``.
        """
        if self.use_loc_attn:
            score, _ = self.get_loc_attn(query, states)
        else:
            score, _ = self.get_attn(query)

        # TODO: masking
        # if mask is not None:
        #     score = self.apply_score_masking(score, mask)

        # attn_weights shape == (batch_size, max_length)
        attn_weights = self.norm_func(score)

        # update attention states
        if self.use_loc_attn:
            states = (states[0] + attn_weights, attn_weights)
        else:
            states = ()

        # (B x 1 x T_in) @ (B x T_in x D) -> B x 1 x D, squeezed to (batch_size, hidden_size)
        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
        context_vector = tf.squeeze(context_vector, axis=1)
        return context_vector, attn_weights, states


# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
#     dtype = processed_query.dtype
#     num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
#     return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])


# class LocationSensitiveAttention(BahdanauAttention):
#     def __init__(self,
#                  units,
#                  memory=None,
#                  memory_sequence_length=None,
#                  normalize=False,
#                  probability_fn="softmax",
#                  kernel_initializer="glorot_uniform",
#                  dtype=None,
#                  name="LocationSensitiveAttention",
#                  location_attention_filters=32,
#                  location_attention_kernel_size=31):

#         super(LocationSensitiveAttention,
#                     self).__init__(units=units,
#                                     memory=memory,
#                                     memory_sequence_length=memory_sequence_length,
#                                     normalize=normalize,
#                                     probability_fn='softmax',  ## parent module default
#                                     kernel_initializer=kernel_initializer,
#                                     dtype=dtype,
#                                     name=name)
#         if probability_fn == 'sigmoid':
#             self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
#         self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
#         self.location_dense = keras.layers.Dense(units, use_bias=False)
#         # self.v = keras.layers.Dense(1, use_bias=True)

#     def  _location_sensitive_score(self, processed_query, keys, processed_loc):
#         processed_query = tf.expand_dims(processed_query, 1)
#         return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])

#     def _location_sensitive(self, alignment_cum, alignment_old):
#         alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
#         return self.location_dense(self.location_conv(alignment_cat))

#     def _sigmoid_normalization(self, score):
#         return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)

#     # def _apply_masking(self, score, mask):
#     #     padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
#     #     # Bias so padding positions do not contribute to attention distribution.
#     #     score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
#     #     return score

#     def _calculate_attention(self, query, state):
#         alignment_cum, alignment_old = state[:2]
#         processed_query = self.query_layer(
#             query) if self.query_layer else query
#         processed_loc = self._location_sensitive(alignment_cum, alignment_old)
#         score = self._location_sensitive_score(
#             processed_query,
#             self.keys,
#             processed_loc)
#         alignment = self.probability_fn(score, state)
#         alignment_cum = alignment_cum + alignment
#         state[0] = alignment_cum
#         state[1] = alignment
#         return alignment, state

#     def compute_context(self, alignments):
#         expanded_alignments = tf.expand_dims(alignments, 1)
#         context = tf.matmul(expanded_alignments, self.values)
#         context = tf.squeeze(context, [1])
#         return context

#     # def call(self, query, state):
#     #     alignment, next_state = self._calculate_attention(query, state)
#     #     return alignment, next_state