# Copyright 2019 Babylon Partners. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Relational graph convolution layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import layers as keras_layers
from tensorflow.python.keras.engine import InputSpec

from rgat import layers as rgat_layers


class RelationalGraphConv(keras_layers.Layer):
    """Layer implementing Relational Graph Convolutions with sparse supports.
    Based upon https://arxiv.org/abs/1703.06103.

    Must be called using both inputs and support:
        inputs = get_inputs()
        support = get_support()

        rgc_layer = RelationalGraphConv(...)
        outputs = rgc_layer(inputs=inputs, support=support)
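
    From the reshapes in `call`, `support` is expected to be a single
    `tf.SparseTensor` of dense shape `(N, relations * N)`: the adjacency
    matrix of each relation (with whatever normalisation the caller wants)
    concatenated along the column axis. A minimal construction sketch,
    assuming `adjacencies` is a list of per-relation `scipy.sparse`
    matrices (illustrative, not part of this module):

        import numpy as np
        import scipy.sparse as sp

        coo = sp.hstack(adjacencies).tocoo()      # N, R * N
        support = tf.SparseTensor(
            indices=np.stack([coo.row, coo.col], axis=1).astype(np.int64),
            values=coo.data.astype(np.float32),
            dense_shape=coo.shape)
        support = tf.sparse_reorder(support)      # canonical index order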

    Has aliases of `RGC` and `RelationalGraphConvolution`.

    Arguments:
        units (int): The dimensionality of the output space.
        relations (int): The number of relation types the layer will handle.
        kernel_basis_size (int): The number of basis kernels to create the
            relational kernels from, i.e. W_r = sum_i c_{i,r} W'_i, where
            r = 1, 2, ..., relations, and i = 1, 2, ..., kernel_basis_size.
            The layer then stores `kernel_basis_size` basis kernels plus
            `relations * kernel_basis_size` coefficients instead of
            `relations` full kernels. If `None` (default), there is no
            basis decomposition.
        activation (callable): Activation function. Set it to `None` to maintain
            a linear activation.
        use_bias (bool): Whether the layer uses a bias. Defaults to `False`.
        batch_normalisation (bool): Whether the layer uses batch normalisation.
            Defaults to `False`.
        kernel_initializer (callable): Initializer function for the graph
            convolution weight matrix. If None (default), weights are
            initialized using the `glorot_uniform` initializer.
        bias_initializer (callable): Initializer function for the bias. Defaults
            to `zeros`.
        kernel_regularizer (callable): Regularizer function for the graph
            convolution weight matrix. Defaults to `None`.
        bias_regularizer (callable): Regularizer function for the bias. Defaults
            to `None`.
        activity_regularizer (callable): Regularizer function for the output.
            Defaults to `None`.
        kernel_constraint (callable): An optional projection function to be
            applied to the kernel after being updated by an Optimizer (e.g. used
            to implement norm constraints or value constraints for layer
            weights). The function must take as input the unprojected variable
            and must return the projected variable (which must have the same
            shape). Constraints are not safe to use when doing asynchronous
            distributed training.
        bias_constraint (callable): An optional projection function to be
            applied to the bias after being updated by an Optimizer.
        feature_dropout (float): The dropout rate for node feature
            representations, between 0 and 1. E.g. rate=0.1 would drop out 10%
            of node input units.
        support_dropout (float): The dropout rate for edges in the support,
            between 0 and 1. E.g. rate=0.1 would drop out 10%
            of the edges in the support.
        name (str): The name of the layer. Defaults to
            `relational_graph_conv`.

    """
    def __init__(self,
                 units,
                 relations,
                 kernel_basis_size=None,
                 activation=None,
                 use_bias=False,
                 batch_normalisation=False,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 feature_dropout=None,
                 support_dropout=None,
                 name='relational_graph_conv',
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)

        super(RelationalGraphConv, self).__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            name=name, **kwargs)

        self.units = int(units)
        self.relations = int(relations)
        self.kernel_basis_size = (int(kernel_basis_size)
                                  if kernel_basis_size is not None else None)
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.batch_normalisation = batch_normalisation
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.feature_dropout = feature_dropout
        self.support_dropout = support_dropout

        self.supports_masking = True
        self.input_spec = InputSpec(min_ndim=2)

        self.dense_layer = rgat_layers.BasisDecompositionDense(
            units=self.units * self.relations,
            basis_size=self.kernel_basis_size,
            coefficients_size=self.relations,
            use_bias=False,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.kernel_regularizer,
            kernel_constraint=self.kernel_constraint,
            name=name + '_basis_decomposition_dense',
            **kwargs)
        if self.batch_normalisation:
            self.batch_normalisation_layer = tf.layers.BatchNormalization()

    def build(self, input_shape):
        input_shape = tf.TensorShape(input_shape)
        if input_shape[-1].value is None:
            raise ValueError(
                'The last dimension of the inputs to `RelationalGraphConv` '
                'should be defined. Found `None`.')
        self.input_spec = InputSpec(min_ndim=2,
                                    axes={-1: input_shape[-1].value})
        self.dense_layer.build(input_shape=input_shape)
        if self.use_bias:
            self.bias = self.add_variable('bias',
                                          shape=(self.units,),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          dtype=self.dtype,
                                          trainable=True)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs, support, training=False):
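        # Shape legend: N = nodes, R = relations, F' = units. The dense
        # sub-layer applies all R relational kernels at once, giving an
        # (N, R*F') matrix that is rearranged to (R*N, F') so that a single
        # sparse matmul with the (N, R*N) support aggregates over both
        # relations and neighbours.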
        if not isinstance(inputs, tf.SparseTensor):
            inputs = tf.convert_to_tensor(inputs, dtype=self.dtype)
            outputs = self.dense_layer(inputs)                          # N,RF'
        else:
            # Sparse inputs are treated as one-hot (featureless) node
            # features, i.e. the identity, in which case X W reduces to
            # the kernel itself.
            outputs = self.dense_layer.kernel                           # N,RF'
        outputs = tf.reshape(outputs, [-1, self.relations,
                                       self.units])                     # N,R,F'
        outputs = tf.transpose(outputs, perm=[1, 0, 2])                 # R,N,F'
        outputs = tf.reshape(outputs, (-1, self.units))                 # RN,F'

        if self.feature_dropout is not None:
            outputs = tf.nn.dropout(outputs,
                                    keep_prob=1 - self.feature_dropout,
                                    name="feature_dropout")
        if self.support_dropout is not None:
            support_values = tf.nn.dropout(support.values,
                                           keep_prob=1 - self.support_dropout,
                                           name="support_dropout")
            support = tf.SparseTensor(support.indices, support_values,
                                      support.dense_shape)
        outputs = tf.sparse_tensor_dense_matmul(support, outputs)       # N,F'
        if self.batch_normalisation:
            outputs = self.batch_normalisation_layer(outputs, training=training)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs

    def compute_output_shape(self, input_shape):
        # The sparse aggregation collapses the relation dimension, so the
        # output is (N, units), not the dense sub-layer's (N, R * units).
        input_shape = tf.TensorShape(input_shape).with_rank_at_least(2)
        return input_shape[:-1].concatenate(self.units)

    def get_config(self):
        config = {
            'units': self.units,
            'relations': self.relations,
            'kernel_basis_size': self.kernel_basis_size,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'batch_normalisation': self.batch_normalisation,
            'kernel_initializer': initializers.serialize(
                self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(
                self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'feature_dropout': self.feature_dropout,
            'support_dropout': self.support_dropout
        }
        base_config = super(RelationalGraphConv, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Aliases

RGC = RelationalGraphConvolution = RelationalGraphConv
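

if __name__ == '__main__':
    # Minimal smoke test (illustrative only, not part of the library API).
    # Assumes TF 1.x graph mode. Builds a toy graph with N = 3 nodes and
    # R = 2 relations; support indices follow the (row, r * N + col)
    # layout described in the class docstring.
    import numpy as np

    N, R, F = 3, 2, 4
    features = tf.constant(np.random.randn(N, F).astype(np.float32))
    indices = np.array([[0, 1],      # relation 0: node 0 receives from 1
                        [1, N + 2],  # relation 1: node 1 receives from 2
                        [2, 0]],     # relation 0: node 2 receives from 0
                       dtype=np.int64)
    support = tf.SparseTensor(indices=indices,
                              values=tf.ones([3]),
                              dense_shape=[N, R * N])

    layer = RGC(units=8, relations=R, activation=tf.nn.relu)
    outputs = layer(inputs=features, support=support)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(outputs).shape)  # (3, 8)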