from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K

from keras.utils.generic_utils import get_custom_objects


class SwitchNormalization(Layer):
    """Switchable Normalization layer

    Switchable Normalization computes Instance Normalization, Layer Normalization and
    Batch Normalization statistics, then blends them with learned importance weights,
    allowing each layer to choose its own mixture of the three normalization schemes.

    Only the moving average variant from the paper is supported, since the
    `batch average` scheme requires dynamic graph execution to compute the mean and
    variance over several batches at runtime.
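
    As a rough sketch, with `center` and `scale` enabled the layer computes
    (per-scheme statistic names are illustrative):

        w = softmax(mean_weights)          # order: [instance, layer, batch]
        v = softmax(variance_weights)
        mean     = w[0] * mean_instance + w[1] * mean_layer + w[2] * mean_batch
        variance = v[0] * var_instance  + v[1] * var_layer  + v[2] * var_batch
        output   = gamma * (x - mean) / sqrt(variance + epsilon) + beta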

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `SwitchNormalization`.
        momentum: Momentum for the moving mean and the moving variance. The original
            implementation suggests a default momentum of `0.997`, however it is highly
            unstable and training can fail after a few epochs. To stabilise training, use
            lower values of momentum such as `0.99` or `0.98`.
        epsilon: Small float added to variance to avoid dividing by zero.
        final_gamma: Bool value to determine if this layer is the final
            normalization layer in a residual block. If True, overrides the
            initialization of the scaling weights (`gamma`) to `zeros`. Only used
            for Residual Networks, so that the forward/backward signal initially
            propagates through an identity shortcut.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        mean_weights_initializer: Initializer for the mean weights.
        variance_weights_initializer: Initializer for the variance weights.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        mean_weights_regularizer: Optional regularizer for the mean weights.
        variance_weights_regularizer: Optional regularizer for the variance weights.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
        mean_weights_constraints: Optional constraint for the mean weights.
        variance_weights_constraints: Optional constraint for the variance weights.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.
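
    # Example
        A minimal usage sketch (the input shape is illustrative):

            from keras.layers import Input, Conv2D
            from keras.models import Model

            ip = Input(shape=(32, 32, 3))
            x = Conv2D(16, (3, 3), padding='same')(ip)
            x = SwitchNormalization(axis=-1)(x)
            model = Model(ip, x)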

    # References
        - [Differentiable Learning-to-Normalize via Switchable Normalization](https://arxiv.org/abs/1806.10779)
    """

    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 final_gamma=False,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 mean_weights_initializer='ones',
                 variance_weights_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 mean_weights_regularizer=None,
                 variance_weights_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 mean_weights_constraints=None,
                 variance_weights_constraints=None,
                 **kwargs):
        super(SwitchNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        self.beta_initializer = initializers.get(beta_initializer)
        if final_gamma:
            self.gamma_initializer = initializers.get('zeros')
        else:
            self.gamma_initializer = initializers.get(gamma_initializer)
        self.mean_weights_initializer = initializers.get(mean_weights_initializer)
        self.variance_weights_initializer = initializers.get(variance_weights_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.mean_weights_regularizer = regularizers.get(mean_weights_regularizer)
        self.variance_weights_regularizer = regularizers.get(variance_weights_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.mean_weights_constraints = constraints.get(mean_weights_constraints)
        self.variance_weights_constraints = constraints.get(variance_weights_constraints)

    def build(self, input_shape):
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name='gamma',
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name='beta',
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint)
        else:
            self.beta = None

        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)

        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)

        self.mean_weights = self.add_weight(
            shape=(3,),
            name='mean_weights',
            initializer=self.mean_weights_initializer,
            regularizer=self.mean_weights_regularizer,
            constraint=self.mean_weights_constraints)

        self.variance_weights = self.add_weight(
            shape=(3,),
            name='variance_weights',
            initializer=self.variance_weights_initializer,
            regularizer=self.variance_weights_regularizer,
            constraint=self.variance_weights_constraints)

        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)

        # Axes to reduce over for the instance statistics: all axes except the
        # batch axis and the normalization axis.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        if self.axis != 0:
            del reduction_axes[0]

        # Prepare broadcasting shape.
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Instance statistics: one mean/variance per sample and per channel,
        # reducing over the remaining (e.g. spatial) axes.
        mean_instance = K.mean(inputs, reduction_axes, keepdims=True)
        variance_instance = K.var(inputs, reduction_axes, keepdims=True)

        # Layer statistics: further reduce the instance statistics over the
        # channel axis, using Var[x] = E[x^2] - (E[x])^2.
        mean_layer = K.mean(mean_instance, self.axis, keepdims=True)
        temp = variance_instance + K.square(mean_instance)
        variance_layer = K.mean(temp, self.axis, keepdims=True) - K.square(mean_layer)

        def training_phase():
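            # Batch statistics: reduce the instance statistics over the batch
            # axis, again using Var[x] = E[x^2] - (E[x])^2.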
            mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
            variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)

            mean_batch_reshaped = K.flatten(mean_batch)
            variance_batch_reshaped = K.flatten(variance_batch)

            if K.backend() != 'cntk':
                sample_size = K.prod([K.shape(inputs)[axis]
                                      for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

                # sample variance - unbiased estimator of population variance
                variance_batch_reshaped *= sample_size / (sample_size - (1.0 + self.epsilon))

            # Update the moving averages with the current batch statistics.
            self.add_update([K.moving_average_update(self.moving_mean,
                                                     mean_batch_reshaped,
                                                     self.momentum),
                             K.moving_average_update(self.moving_variance,
                                                     variance_batch_reshaped,
                                                     self.momentum)],
                            inputs)

            return normalize_func(mean_batch, variance_batch)

        def inference_phase():
            mean_batch = self.moving_mean
            variance_batch = self.moving_variance

            return normalize_func(mean_batch, variance_batch)

        def normalize_func(mean_batch, variance_batch):
            mean_batch = K.reshape(mean_batch, broadcast_shape)
            variance_batch = K.reshape(variance_batch, broadcast_shape)

            # Softmax so each set of importance weights sums to one; the order
            # is [instance, layer, batch], matching the blending below.
            mean_weights = K.softmax(self.mean_weights, axis=0)
            variance_weights = K.softmax(self.variance_weights, axis=0)

            mean = (mean_weights[0] * mean_instance +
                    mean_weights[1] * mean_layer +
                    mean_weights[2] * mean_batch)

            variance = (variance_weights[0] * variance_instance +
                        variance_weights[1] * variance_layer +
                        variance_weights[2] * variance_batch)

            outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                outputs = outputs * broadcast_gamma

            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                outputs = outputs + broadcast_beta

            return outputs

        # If inference is requested explicitly, use the moving averages;
        # otherwise defer to the backend's learning phase flag.
        if training in {0, False}:
            return inference_phase()

        return K.in_train_phase(training_phase,
                                inference_phase,
                                training=training)

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'mean_weights_initializer': initializers.serialize(self.mean_weights_initializer),
            'variance_weights_initializer': initializers.serialize(self.variance_weights_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'mean_weights_regularizer': regularizers.serialize(self.mean_weights_regularizer),
            'variance_weights_regularizer': regularizers.serialize(self.variance_weights_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            'mean_weights_constraints': constraints.serialize(self.mean_weights_constraints),
            'variance_weights_constraints': constraints.serialize(self.variance_weights_constraints),
        }
        base_config = super(SwitchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape


# Register the layer by name so that models using it can be deserialized
# (e.g. via `load_model`) without explicitly passing `custom_objects`.
get_custom_objects().update({'SwitchNormalization': SwitchNormalization})


if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model
    ip = Input(shape=(None, None, 4))
    #ip = Input(batch_shape=(100, None, None, 2))
    x = SwitchNormalization(axis=-1)(ip)
    model = Model(ip, x)
    model.compile('adam', 'mse')
    model.summary()

    import numpy as np

    # Quick smoke test on random data.
    data = np.random.normal(0.0, 1.0, size=(10, 8, 8, 4))
    model.fit(data, data, epochs=5)
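
    # Optional round trip to illustrate the `get_custom_objects` registration above:
    # since `SwitchNormalization` is registered by name, the saved model can be
    # reloaded without passing `custom_objects` (the file name here is illustrative).
    from keras.models import load_model

    model.save('switchnorm_test.h5')
    restored = load_model('switchnorm_test.h5')
    restored.summary()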