# Copyright 2019 Babylon Partners. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Relational graph attention layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.keras.engine import InputSpec

from rgat import layers as rgat_layers
from rgat.ops import math_ops as rgat_math_ops
from rgat.layers.graph_utils import HeadAggregation
from rgat.layers.graph_utils import AttentionModes
from rgat.layers.graph_utils import AttentionStyles

from .relational_graph_attention_logits import RelationalGraphAttentionLogits


class RelationalGraphAttention(keras_layers.Layer):
    """Layer implementing Relational Graph Attention of
    https://openreview.net/forum?id=Bklzkh0qFm with sparse supports.

    Must be called using both inputs and support:
        inputs = get_inputs()
        support = get_support()

        rgat_layer = RelationalGraphAttention(...)
        outputs = rgat_layer(inputs=inputs, support=support)
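
    A minimal sketch (the node features and edges below are arbitrary),
    assuming the support is a `tf.SparseTensor` of dense shape `[N, R * N]`
    whose columns index (relation, from-node) pairs as `r * N + m`, matching
    the reshapes in `call`:

        import tensorflow as tf

        # Two nodes with three features each.
        inputs = tf.constant([[1., 0., 2.],
                              [0., 3., 1.]])

        # Two relations over two nodes: relation 0 contains the self loops
        # [0, 0 * 2 + 0] and [1, 0 * 2 + 1]; relation 1 contains the edge
        # 1 -> 0, i.e. entry [0, 1 * 2 + 1].
        support = tf.SparseTensor(indices=[[0, 0], [0, 3], [1, 1]],
                                  values=[1., 1., 1.],
                                  dense_shape=[2, 4])

        rgat_layer = RelationalGraphAttention(units=4, relations=2)
        outputs = rgat_layer(inputs=inputs, support=support)  # shape [2, 4]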

    Has alias of `RGAT`.

    Arguments:
        units (int): The dimensionality of the output space.
        relations (int): The number of relation types the layer will handle.
        heads (int): The number of attention heads to use (see
            https://arxiv.org/abs/1710.10903). Defaults to `1`.
        head_aggregation (str): The attention head aggregation method to use
            (see https://arxiv.org/abs/1710.10903). Can be one of `'mean'` or
            `'concat'`. Defaults to `'mean'`.
        attention_mode (str): The relational attention mode to use (see
            https://openreview.net/forum?id=Bklzkh0qFm). Can be one of `'argat'`
            or `'wirgat'`. Defaults to `'argat'`.
        attention_style (str): The style of attention to use. Set to `'dot'`
            for transformer-style multiplicative attention, or `'sum'` for
            GAT-style additive attention. Defaults to `'sum'`.
        attention_units (int): The dimensionality of the attention space. If
            using `'sum'` style attention, this must be set to `1`.
        attn_use_edge_features (bool): Whether the layer can use edge features.
            Defaults to `False`.
        kernel_basis_size (int): The number of basis kernels used to construct
            the relational kernels, i.e. W_r = sum_i c_{i,r} W'_i, where
            r = 1, 2, ..., relations, and i = 1, 2, ..., kernel_basis_size.
            If `None` (default), there is no basis decomposition.
        attn_kernel_basis_size (int): The number of basis kernels to create the
            relational attention kernels from. Defaults to `None`.
        activation (callable): Activation function. Set it to `None` to maintain
            a linear activation.
        attn_activation (callable): Activation function to apply to the
            attention logits prior to the softmax. Defaults to the leaky ReLU
            of https://arxiv.org/abs/1710.10903; when using `'dot'` style
            attention, this can be set to `None`.
        use_bias (bool): Whether the layer uses a bias. Defaults to `False`.
        batch_normalisation (bool): Whether the layer uses batch normalisation.
            Defaults to `False`.
        kernel_initializer (callable): Initializer function for the graph
            convolution weight matrix. Defaults to the `glorot_uniform`
            initializer.
        bias_initializer (callable): Initializer function for the bias. Defaults
            to `zeros`.
        attn_kernel_initializer (callable): Initializer function for the
            attention weight matrix. Defaults to the `glorot_uniform`
            initializer.
        kernel_regularizer (callable): Regularizer function for the graph
            convolution weight matrix. Defaults to `None`.
        bias_regularizer (callable): Regularizer function for the bias. Defaults
            to `None`.
        attn_kernel_regularizer (callable): Regularizer function for the graph
            attention weight matrix. Defaults to `None`.
        activity_regularizer (callable): Regularizer function for the output.
            Defaults to `None`.
        feature_dropout (float): The dropout rate for node feature
            representations, between 0 and 1. E.g. rate=0.1 would drop out 10%
            of node input units.
        support_dropout (float): The dropout rate for edges in the support,
            between 0 and 1. E.g. rate=0.1 would drop out 10%
            of the edges in the support.
        edge_feature_dropout (float): The dropout rate for edge feature
            representations, between 0 and 1.
        name (string): The name of the layer. Defaults to `'rgat'`.

    """
    def __init__(self,
                 units,
                 relations,
                 heads=1,
                 head_aggregation=HeadAggregation.MEAN,
                 attention_mode=AttentionModes.ARGAT,
                 attention_style=AttentionStyles.SUM,
                 attention_units=1,
                 attn_use_edge_features=False,
                 kernel_basis_size=None,
                 attn_kernel_basis_size=None,
                 activation=None,
                 attn_activation=tf.nn.leaky_relu,
                 use_bias=False,
                 batch_normalisation=False,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 attn_kernel_initializer='glorot_uniform',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 attn_kernel_regularizer=None,
                 activity_regularizer=None,
                 feature_dropout=None,
                 support_dropout=None,
                 edge_feature_dropout=None,
                 name='rgat',
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)

        super(RelationalGraphAttention, self).__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            name=name, **kwargs)

        self.units = int(units)
        self.relations = int(relations)
        self.heads = int(heads)
        self.head_aggregation = HeadAggregation.validate(head_aggregation)
        self.attention_mode = AttentionModes.validate(attention_mode)
        self.attention_style = AttentionStyles.validate(attention_style)
        self.attention_units = attention_units
        self.attn_use_edge_features = attn_use_edge_features

        self.kernel_basis_size = (int(kernel_basis_size)
                                  if kernel_basis_size else None)
        self.attn_kernel_basis_size = (int(attn_kernel_basis_size)
                                       if attn_kernel_basis_size else None)

        self.activation = activations.get(activation)
        self.attn_activation = activations.get(attn_activation)

        self.use_bias = use_bias
        self.batch_normalisation = batch_normalisation

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)

        self.feature_dropout = feature_dropout
        self.support_dropout = support_dropout
        self.edge_feature_dropout = edge_feature_dropout

        self.supports_masking = True
        self.input_spec = InputSpec(min_ndim=2)

        self.dense_layer = rgat_layers.BasisDecompositionDense(
            units=self.relations * self.heads * self.units,
            basis_size=self.kernel_basis_size,
            coefficients_size=self.relations * self.heads,
            use_bias=False,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            name=name + '_basis_decomposition_dense',
            **kwargs)
        self.attention_logits = RelationalGraphAttentionLogits(
            relations=self.relations,
            heads=self.heads,
            attention_style=self.attention_style,
            attention_units=self.attention_units,
            basis_size=self.attn_kernel_basis_size,
            activation=self.attn_activation,
            use_edge_features=self.attn_use_edge_features,
            kernel_initializer=self.attn_kernel_initializer,
            kernel_regularizer=self.attn_kernel_regularizer,
            feature_dropout=self.feature_dropout,
            edge_feature_dropout=self.edge_feature_dropout,
            batch_normalisation=self.batch_normalisation,
            name="logits",
            **kwargs)
        if self.head_aggregation == HeadAggregation.PROJECTION:
            self.projection_layer = keras_layers.Dense(
                units=self.units,
                use_bias=False,
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                name="projection",
                **kwargs)
        if self.batch_normalisation:
            self.batch_normalisation_layer = tf.layers.BatchNormalization()

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        if input_shape[-1].value is None:
            raise ValueError(
                'The last dimension of the inputs to '
                '`RelationalGraphAttention` should be defined. Found `None`.')
        self.input_spec = InputSpec(min_ndim=2,
                                    axes={-1: input_shape[-1].value})
        self.dense_layer.build(input_shape=input_shape)
        if self.use_bias:
            bias_size = self.units
            if self.head_aggregation == HeadAggregation.CONCAT:
                bias_size = bias_size * self.heads
            self.bias = self.add_variable('bias',
                                          shape=(bias_size,),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          dtype=self.dtype,
                                          trainable=True)
        else:
            self.bias = None
        self.built = True

    def _attn_r_n_m_h(self):
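        """Attention coefficients from the last `call`, reshaped to
        `[relations, N, N, heads]` (to-nodes then from-nodes).

        Requires `call` to have been run, since it reads the cached
        coefficients of shape `[heads, N, relations * N]`.
        """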
        h, r, n = self.heads, self.relations, self._nodes

        attn_h_n_rm = self._attn_h_n_rm
        attn_h_n_r_m = tf.sparse_reshape(attn_h_n_rm, [h, n, r, n])
        attn_r_n_m_h = tf.sparse_transpose(attn_h_n_r_m, [2, 1, 3, 0])

        return attn_r_n_m_h

    def call(self, inputs, support, edge_features=None, training=False,
             attention_off=False):
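        """Run the layer on `inputs` using the graph structure in `support`.

        Args:
            inputs: Dense `Tensor` of node features with shape `[N, F]`. A
                `SparseTensor` may also be passed, in which case the features
                are assumed to be one-hot and the projection kernel is used
                directly.
            support: `SparseTensor` encoding the relational graph structure,
                with dense shape `[N, R * N]`.
            edge_features: Optional edge features forwarded to the attention
                logits layer.
            training: Forwarded to the attention logits and batch
                normalisation layers.
            attention_off: If `True`, the attention logits are replaced with
                zeros so every neighbour in a softmax group receives equal
                weight.

        Returns:
            The updated node representations.
        """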
        with tf.name_scope("project_features"):
            if not isinstance(inputs, tf.SparseTensor):
                inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
                self.pre_aggregation = outputs_n_rhu = self.dense_layer(inputs)
            else:
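                # Sparse inputs are assumed to be one-hot node indicators, so
                # projecting them is equivalent to taking the kernel directly.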
                self.pre_aggregation = outputs_n_rhu = self.dense_layer.kernel

        # support is a SparseTensor with dense shape [N, R * N]

        # shorthands
        self._nodes = n = rgat_math_ops.get_shape(outputs_n_rhu)[0]
        r, h, u = self.relations, self.heads, self.units

        # n corresponds to `to nodes`   (of size n)
        # m corresponds to `from nodes` (of size n)

        self.logits = logits_h_n_rm = self.attention_logits(
            outputs_n_rhu, support, edge_features=edge_features,
            as_sparse_tensor=True, training=training)

        if attention_off:
            tf.logging.warning(
                "You are zeroing the attention mechanism, are you sure?")
            logits_h_n_rm = tf.SparseTensor(self.logits.indices,
                                            tf.zeros_like(logits_h_n_rm.values),
                                            self.logits.dense_shape)

        with tf.name_scope("attention_coefficients_{}".format(
                self.attention_mode)):
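            # ARGAT normalises the logits over the flattened
            # (relation, from-node) axis, so coefficients compete across
            # relations; WIRGAT normalises over from-nodes within each
            # relation separately.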
            if self.attention_mode == AttentionModes.ARGAT:
                attn_h_n_rm = tf.sparse_softmax(logits_h_n_rm)
            else:
                logits_h_n_r_m = tf.sparse_reshape(logits_h_n_rm,
                                                   [h, n, r, n])
                attn_h_n_r_m = tf.sparse_softmax(logits_h_n_r_m)
                attn_h_n_rm = tf.sparse_reshape(attn_h_n_r_m, [h, n, r * n])

        self._attn_h_n_rm = attn_h_n_rm

        with tf.name_scope("transform_project_features_h_rm_u"):
            outputs_m_r_h_u = tf.reshape(outputs_n_rhu, [n, r, h, u])
            outputs_h_r_m_u = tf.transpose(outputs_m_r_h_u, [2, 1, 0, 3])
            outputs_h_rm_u = self.pre_attn = tf.reshape(outputs_h_r_m_u,
                                                        [h, r * n, u])

        if self.support_dropout is not None:
            attn_values = tf.nn.dropout(attn_h_n_rm.values,
                                        keep_prob=1 - self.support_dropout,
                                        name="support_dropout")
            attn_h_n_rm = tf.SparseTensor(attn_h_n_rm.indices, attn_values,
                                          attn_h_n_rm.dense_shape)

        if self.feature_dropout is not None:
            outputs_h_rm_u = tf.nn.dropout(outputs_h_rm_u,
                                           keep_prob=1 - self.feature_dropout,
                                           name="feature_dropout")

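        # Attention-weighted aggregation over all relations and neighbours:
        # for every head, outputs[h, n] is the sum over (r, m) of
        # attn[h, n, r * N + m] * features[h, r * N + m].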
        outputs_h_n_u = rgat_math_ops.batched_sparse_dense_matmul(
            sparse_tensor=attn_h_n_rm, dense_tensor=outputs_h_rm_u)

        with tf.name_scope("head_aggregation_{}".format(self.head_aggregation)):
            if self.head_aggregation == HeadAggregation.MEAN:
                outputs = tf.reduce_mean(outputs_h_n_u, axis=0)
            elif self.head_aggregation == HeadAggregation.SUM:
                outputs = tf.reduce_sum(outputs_h_n_u, axis=0)
            elif self.head_aggregation == HeadAggregation.CONCAT:
                outputs_n_h_u = tf.transpose(outputs_h_n_u, [1, 0, 2])
                outputs = tf.reshape(outputs_n_h_u, [n, h * u])
            elif self.head_aggregation == HeadAggregation.PROJECTION:
                outputs_n_h_u = tf.transpose(outputs_h_n_u, [1, 0, 2])
                outputs_n_hu = tf.reshape(outputs_n_h_u, [n, h * u])
                outputs = self.projection_layer(outputs_n_hu)

        if self.batch_normalisation:
            outputs = self.batch_normalisation_layer(outputs, training=training)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape)
        input_shape = input_shape.with_rank_at_least(2)
        if input_shape[-1].value is None:
            raise ValueError(
                'The innermost dimension of input_shape must be defined, '
                'but saw: %s'
                % input_shape)

        output_size = self.units
        if self.head_aggregation == HeadAggregation.CONCAT:
            output_size = output_size * self.heads

        return input_shape[:-1].concatenate(output_size)

    def get_config(self):
        config = {
            'units': self.units,
            'relations': self.relations,
            'heads': self.heads,
            'head_aggregation': self.head_aggregation,
            'attention_mode': self.attention_mode,
            'attention_style': self.attention_style,
            'attention_units': self.attention_units,
            'attn_use_edge_features': self.attn_use_edge_features,
            'kernel_basis_size': self.kernel_basis_size,
            'attn_kernel_basis_size': self.attn_kernel_basis_size,
            'activation': activations.serialize(self.activation),
            'attn_activation': activations.serialize(self.attn_activation),
            'use_bias': self.use_bias,
            'batch_normalisation': self.batch_normalisation,
            'kernel_initializer': initializers.serialize(
                self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'attn_kernel_initializer': initializers.serialize(
                self.attn_kernel_initializer),
            'kernel_regularizer': regularizers.serialize(
                self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'attn_kernel_regularizer': regularizers.serialize(
                self.attn_kernel_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'feature_dropout': self.feature_dropout,
            'support_dropout': self.support_dropout,
            'edge_feature_dropout': self.edge_feature_dropout
        }
        base_config = super(RelationalGraphAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Aliases

RGAT = RelationalGraphAttention