# -*- coding: utf-8 -*-
from __future__ import absolute_import
from functools import partial

from keras import backend as K
from keras_contrib import backend as KC
from keras import activations
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.layers import Layer
from keras.layers import InputSpec
from keras_contrib.utils.conv_utils import conv_output_length
from keras_contrib.utils.conv_utils import normalize_data_format
from keras_contrib.utils.test_utils import to_tuple
import numpy as np


class CosineConvolution2D(Layer):
    """Cosine Normalized Convolution operator for filtering
    windows of two-dimensional inputs.

    Instead of the plain dot product between a kernel and an input window,
    the layer outputs their cosine similarity: the convolution result is
    divided by the L2 norm of the kernel and by the L2 norm of the window.
    When `use_bias` is set, the bias is treated as one extra kernel weight
    paired with a constant input of 1, so it takes part in both norms.

    # Examples

    ```python
        # apply a 3x3 convolution with 64 output filters on a 256x256 image:
        model = Sequential()
        model.add(CosineConvolution2D(64, (3, 3),
                                      padding='same',
                                      data_format='channels_first',
                                      input_shape=(3, 256, 256)))
        # now model.output_shape == (None, 64, 256, 256)

        # add a 3x3 convolution on top, with 32 output filters:
        model.add(CosineConvolution2D(32, (3, 3), padding='same',
                                      data_format='channels_first'))
        # now model.output_shape == (None, 32, 256, 256)
    ```

    # Arguments
        filters: Number of convolution filters to use.
        kernel_size: An integer or tuple/list of 2 integers, specifying the
            dimensions (rows, cols) of the convolution window.
            A single integer is used for both dimensions.
        kernel_initializer: name of the initialization function for the
            kernel weights
            (see [initializers](https://keras.io/initializers)),
            or an initializer instance/callable.
            This parameter is only relevant if you don't pass
            a `weights` argument.
        activation: name of activation function to use
            (see [activations](https://keras.io/activations)),
            or a callable.
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: a(x) = x).
        weights: list of numpy arrays to set as initial weights.
        padding: 'valid', 'same' or 'full'
            ('full' requires the Theano backend).
        strides: An integer or tuple/list of 2 integers, specifying the
            strides of the convolution along rows and cols.
            A single integer is used for both dimensions.
        kernel_regularizer: instance of [WeightRegularizer](
            https://keras.io/regularizers)
            (eg. L1 or L2 regularization), applied to the main weights matrix.
        bias_regularizer: instance of [WeightRegularizer](
            https://keras.io/regularizers), applied to the bias.
        activity_regularizer: instance of [ActivityRegularizer](
            https://keras.io/regularizers), applied to the network output.
        kernel_constraint: instance of the [constraints](
            https://keras.io/constraints) module
            (eg. maxnorm, nonneg), applied to the main weights matrix.
        bias_constraint: instance of the [constraints](
            https://keras.io/constraints) module, applied to the bias.
        data_format: 'channels_first' or 'channels_last'.
            In 'channels_first' mode, the channels dimension
            (the depth) is at index 1, in 'channels_last' mode it is at
            index 3. When `None`, the value returned by
            `K.image_data_format()` is used.
        use_bias: whether to include a bias
            (i.e. make the layer affine rather than linear).

    # Input shape
        4D tensor with shape:
        `(samples, channels, rows, cols)` if data_format='channels_first'
        or 4D tensor with shape:
        `(samples, rows, cols, channels)` if data_format='channels_last'.

    # Output shape
        4D tensor with shape:
        `(samples, filters, new_rows, new_cols)`
        if data_format='channels_first'
        or 4D tensor with shape:
        `(samples, new_rows, new_cols, filters)`
        if data_format='channels_last'.
        `rows` and `cols` values might have changed due to padding.

    # Raises
        ValueError: if `padding` is not one of the supported modes, or if
            `kernel_size` / `strides` is neither an integer nor a
            sequence of exactly two entries.

    # References
        - [Cosine Normalization: Using Cosine Similarity Instead
           of Dot Product in Neural Networks](https://arxiv.org/pdf/1702.05870.pdf)
    """

    def __init__(self, filters, kernel_size,
                 kernel_initializer='glorot_uniform', activation=None, weights=None,
                 padding='valid', strides=(1, 1), data_format=None,
                 kernel_regularizer=None, bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 use_bias=True, **kwargs):
        if data_format is None:
            data_format = K.image_data_format()
        if padding not in {'valid', 'same', 'full'}:
            raise ValueError('Invalid border mode for CosineConvolution2D:', padding)
        self.filters = filters
        # Accept the documented "single integer" form as well as a pair.
        self.kernel_size = self._normalize_pair(kernel_size, 'kernel_size')
        self.nb_row, self.nb_col = self.kernel_size
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.activation = activations.get(activation)
        self.padding = padding
        self.strides = self._normalize_pair(strides, 'strides')
        self.data_format = normalize_data_format(data_format)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.use_bias = use_bias
        self.input_spec = [InputSpec(ndim=4)]
        # Kept until `build`, where the weights exist and can be assigned.
        self.initial_weights = weights
        super(CosineConvolution2D, self).__init__(**kwargs)

    @staticmethod
    def _normalize_pair(value, name):
        """Return `value` as a tuple of two entries.

        # Arguments
            value: an integer (used for both entries) or an iterable
                holding exactly two entries.
            name: argument name, only used in the error message.

        # Raises
            ValueError: if `value` is not an integer and cannot be turned
                into a tuple of length 2.
        """
        if isinstance(value, (int, np.integer)):
            return (value, value)
        try:
            pair = tuple(value)
        except TypeError:
            pair = None
        if pair is None or len(pair) != 2:
            raise ValueError('The `{}` argument must be an integer or a '
                             'tuple/list of 2 integers. '
                             'Received: {}'.format(name, value))
        return pair

    def build(self, input_shape):
        """Create the kernel, the all-ones norm kernel and the optional bias.

        # Arguments
            input_shape: shape of the 4D input; the channel count is read
                from index 1 (`channels_first`) or index 3 (`channels_last`).

        # Raises
            ValueError: if `self.data_format` is not a known format.
        """
        input_shape = to_tuple(input_shape)
        if self.data_format == 'channels_first':
            stack_size = input_shape[1]
            self.kernel_shape = (self.filters, stack_size, self.nb_row, self.nb_col)
            self.kernel_norm_shape = (1, stack_size, self.nb_row, self.nb_col)
        elif self.data_format == 'channels_last':
            stack_size = input_shape[3]
            self.kernel_shape = (self.nb_row, self.nb_col, stack_size, self.filters)
            self.kernel_norm_shape = (self.nb_row, self.nb_col, stack_size, 1)
        else:
            raise ValueError('Invalid data_format:', self.data_format)
        self.W = self.add_weight(shape=self.kernel_shape,
                                 initializer=partial(self.kernel_initializer),
                                 name='{}_W'.format(self.name),
                                 regularizer=self.kernel_regularizer,
                                 constraint=self.kernel_constraint)

        # Single-filter kernel of ones: convolving the squared input with it
        # sums the squares inside each window (see `call`). It is created
        # with `K.variable`, not `add_weight`, so it is not a layer weight.
        kernel_norm_name = '{}_kernel_norm'.format(self.name)
        self.kernel_norm = K.variable(np.ones(self.kernel_norm_shape),
                                      name=kernel_norm_name)

        if self.use_bias:
            self.b = self.add_weight(shape=(self.filters,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
        else:
            self.b = None

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def compute_output_shape(self, input_shape):
        """Return the output shape as a 4-tuple for the given input shape.

        The spatial sizes are obtained from `conv_output_length` using the
        kernel size, padding mode and strides of this layer.

        # Raises
            ValueError: if `self.data_format` is not a known format.
        """
        if self.data_format == 'channels_first':
            rows = input_shape[2]
            cols = input_shape[3]
        elif self.data_format == 'channels_last':
            rows = input_shape[1]
            cols = input_shape[2]
        else:
            raise ValueError('Invalid data_format:', self.data_format)

        rows = conv_output_length(rows, self.nb_row,
                                  self.padding, self.strides[0])
        cols = conv_output_length(cols, self.nb_col,
                                  self.padding, self.strides[1])

        if self.data_format == 'channels_first':
            return input_shape[0], self.filters, rows, cols
        elif self.data_format == 'channels_last':
            return input_shape[0], rows, cols, self.filters

    def call(self, x, mask=None):
        """Apply the cosine-normalized convolution to `x`.

        # Arguments
            x: 4D input tensor laid out according to `self.data_format`.
            mask: unused.

        # Returns
            `activation((conv(x, W) + b) / (Wnorm * xnorm))`, where `Wnorm`
            is the per-filter norm of the kernel (and bias) and `xnorm` the
            per-window norm of the input (plus the constant 1 paired with
            the bias).
        """
        # Without a bias these stay 0 and drop out of both norms.
        b, xb = 0., 0.
        if self.data_format == 'channels_first':
            kernel_sum_axes = [1, 2, 3]
            if self.use_bias:
                # Laid out like Wnorm, i.e. one value per filter on axis 0.
                b = K.reshape(self.b, (self.filters, 1, 1, 1))
                xb = 1.
        elif self.data_format == 'channels_last':
            kernel_sum_axes = [0, 1, 2]
            if self.use_bias:
                b = K.reshape(self.b, (1, 1, 1, self.filters))
                xb = 1.

        # Per-filter L2 norm of the kernel, with the bias as an extra weight.
        tmp = K.sum(K.square(self.W), axis=kernel_sum_axes, keepdims=True)
        Wnorm = K.sqrt(tmp + K.square(b) + K.epsilon())

        # Per-window L2 norm of the input: sum of squares inside each window,
        # computed as a convolution of x**2 with the all-ones kernel.
        tmp = KC.conv2d(K.square(x), self.kernel_norm, strides=self.strides,
                        padding=self.padding,
                        data_format=self.data_format,
                        filter_shape=self.kernel_norm_shape)
        xnorm = K.sqrt(tmp + xb + K.epsilon())

        W = self.W / Wnorm

        output = KC.conv2d(x, W, strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           filter_shape=self.kernel_shape)

        if K.backend() == 'theano':
            # xnorm has a single channel; mark that axis (and only that one)
            # as broadcastable so the divisions below broadcast over the
            # `filters` output channels. The channel axis is 1 for
            # 'channels_first' and 3 for 'channels_last'.
            if self.data_format == 'channels_first':
                broadcast_pattern = [False, True, False, False]
            else:
                broadcast_pattern = [False, False, False, True]
            xnorm = K.pattern_broadcast(xnorm, broadcast_pattern)

        output /= xnorm

        if self.use_bias:
            # Normalize the bias exactly like the kernel and the input.
            b /= Wnorm
            if self.data_format == 'channels_first':
                b = K.reshape(b, (1, self.filters, 1, 1))
            elif self.data_format == 'channels_last':
                b = K.reshape(b, (1, 1, 1, self.filters))
            else:
                raise ValueError('Invalid data_format:', self.data_format)
            b /= xnorm
            output += b
        output = self.activation(output)
        return output

    def get_config(self):
        """Return the layer configuration merged with the base `Layer` config.

        The `weights` constructor argument is not part of the configuration.
        """
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'activation': activations.serialize(self.activation),
            'padding': self.padding,
            'strides': self.strides,
            'data_format': self.data_format,
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'use_bias': self.use_bias}
        base_config = super(CosineConvolution2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# Shorter alias: `CosineConv2D` refers to the very same class object.
CosineConv2D = CosineConvolution2D