import functools
import warnings

import numpy

import chainer
from chainer import configuration
from chainer import functions
from chainer import link
import chainer.serializer as serializer_mod
from chainer.utils import argument


class DecorrelatedBatchNormalization(link.Link):

    """Decorrelated batch normalization layer.

    This link wraps the
    :func:`~chainer.functions.decorrelated_batch_normalization` and
    :func:`~chainer.functions.fixed_decorrelated_batch_normalization`
    functions. It works on outputs of linear or convolution functions.

    It runs in three modes: training mode, fine-tuning mode, and testing mode.

    In training mode, it normalizes the input by *batch statistics*. It also
    maintains approximated population statistics by moving averages, which can
    be used for instant evaluation in testing mode.

    In fine-tuning mode, it accumulates the input to compute *population
    statistics*. In order to correctly compute the population statistics, a
    user must use this mode to feed mini-batches running through the whole
    training dataset.

    In testing mode, it uses pre-computed population statistics to normalize
    the input variable. The population statistics are approximated if they
    were computed in training mode, or accurate if they were correctly
    computed in fine-tuning mode.

    Args:
        size (int): Size of the channel dimension. It should be divisible
            by ``groups``.
        groups (int): Number of groups to use for group whitening.
        decay (float): Decay rate of moving average
            which is used during training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Type to use in computing.

    See: `Decorrelated Batch Normalization <https://arxiv.org/abs/1804.08450>`_

    .. seealso::
       :func:`~chainer.functions.decorrelated_batch_normalization`,
       :func:`~chainer.functions.fixed_decorrelated_batch_normalization`
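
    .. admonition:: Example

        A minimal usage sketch of the three modes (the sizes and shapes
        below are illustrative, not required by this link):

        >>> import numpy
        >>> import chainer
        >>> import chainer.links as L
        >>> dbn = L.DecorrelatedBatchNormalization(32, groups=4)
        >>> x = numpy.random.randn(8, 32, 5, 5).astype(numpy.float32)
        >>> y = dbn(x)  # training mode: whitening with batch statistics
        >>> with chainer.using_config('train', False):
        ...     y = dbn(x)  # testing mode: uses population statistics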

    Attributes:
        avg_mean (:ref:`ndarray`): Population mean.
        avg_projection (:ref:`ndarray`): Population
            projection.
        groups (int): Number of groups to use for group whitening.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of moving average
            which is used during training.
        ~DecorrelatedBatchNormalization.eps (float): Epsilon value for
            numerical stability. This value is added to the diagonal of the
            batch covariance matrices.

    """

    def __init__(self, size, groups=16, decay=0.9, eps=2e-5,
                 dtype=numpy.float32):
        super(DecorrelatedBatchNormalization, self).__init__()
        C = size // groups  # number of channels per whitening group
        self.avg_mean = numpy.zeros((groups, C), dtype=dtype)
        self.register_persistent('avg_mean')
        # Initialize each group's projection to the identity matrix.
        avg_projection = numpy.zeros((groups, C, C), dtype=dtype)
        arange_C = numpy.arange(C)
        avg_projection[:, arange_C, arange_C] = 1
        self.avg_projection = avg_projection
        self.register_persistent('avg_projection')
        self.N = 0
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps
        self.groups = groups

    def serialize(self, serializer):
        if isinstance(serializer, serializer_mod.Deserializer):
            serializer = _PatchedDeserializer(serializer, {
                'avg_mean': functools.partial(
                    fix_avg_mean, groups=self.groups),
                'avg_projection': functools.partial(
                    fix_avg_projection, groups=self.groups),
            })
        super(DecorrelatedBatchNormalization, self).serialize(serializer)

    def forward(self, x, **kwargs):
        """forward(self, x, *, finetune=False)

        Invokes the forward propagation of DecorrelatedBatchNormalization.

        In training mode, DecorrelatedBatchNormalization normalizes the input
        using batch statistics, and also updates moving averages of the mean
        and projection, which can be used for evaluation during training.

        Args:
            x (:class:`~chainer.Variable`): Input variable.
            finetune (bool): If it is in training mode and ``finetune`` is
                ``True``, DecorrelatedBatchNormalization runs in fine-tuning
                mode; it accumulates the input array to compute population
                statistics for normalization, and normalizes the input using
                batch statistics.
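
        Returns:
            ~chainer.Variable: Output variable with the same shape as ``x``.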

        """
        finetune, = argument.parse_kwargs(kwargs, ('finetune', False))

        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            avg_mean = self.avg_mean
            avg_projection = self.avg_projection

            if configuration.config.in_recomputing:
                # Do not update statistics when extra forward computation is
                # called.
                if finetune:
                    self.N -= 1
                avg_mean = None
                avg_projection = None

            ret = functions.decorrelated_batch_normalization(
                x, groups=self.groups, eps=self.eps,
                running_mean=avg_mean, running_projection=avg_projection,
                decay=decay)
        else:
            # Use running average statistics or fine-tuned statistics.
            mean = self.avg_mean
            projection = self.avg_projection
            ret = functions.fixed_decorrelated_batch_normalization(
                x, mean, projection, groups=self.groups)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if fine-tuning mode is used for the first
        time. Otherwise, it should be called before starting fine-tuning
        again.

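        A minimal sketch of a fine-tuning pass, assuming ``dbn`` is this
        link and ``train_batches`` is a hypothetical iterable covering the
        whole training dataset:

        .. code-block:: python

            dbn.start_finetuning()
            with chainer.using_config('train', True):
                for x in train_batches:
                    dbn(x, finetune=True)
            # Test-mode calls now use the accumulated statistics.
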
        """
        self.N = 0


class _PatchedDeserializer(serializer_mod.Deserializer):
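    """Deserializer wrapper that rewrites selected values on load.

    ``patches`` maps a key to a function applied to the loaded array
    before it is returned or copied into the target value. It is used to
    upgrade statistics saved by old versions of
    DecorrelatedBatchNormalization (see issue #7706).
    """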

    def __init__(self, base, patches):
        self.base = base
        self.patches = patches

    def __repr__(self):
        return '_PatchedDeserializer({}, {})'.format(
            repr(self.base), repr(self.patches))

    def __call__(self, key, value):
        if key not in self.patches:
            return self.base(key, value)
        arr = self.base(key, None)
        arr = self.patches[key](arr)
        if value is None:
            return arr
        chainer.backend.copyto(value, arr)
        return value


def _warn_old_model():
    msg = (
        'Found moving statistics of old DecorrelatedBatchNormalization, whose '
        'algorithm was different from the paper.')
    warnings.warn(msg)


def fix_avg_mean(avg_mean, groups):
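    """Converts an old 1-d moving mean to the current 2-d layout.

    A minimal doctest sketch (the length 8 is illustrative):

    >>> import numpy
    >>> old = numpy.zeros(8, dtype=numpy.float32)  # pre-#7706 layout
    >>> fix_avg_mean(old, groups=1).shape
    (1, 8)
    """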
    if avg_mean.ndim == 2:  # OK
        return avg_mean
    elif avg_mean.ndim == 1:  # Issue #7706
        if groups != 1:
            _warn_old_model()
        return _broadcast_to(avg_mean, (groups,) + avg_mean.shape)
    raise ValueError(
        'unexpected shape of avg_mean: {}'.format(avg_mean.shape))


def fix_avg_projection(avg_projection, groups):
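    """Converts an old 2-d moving projection to the current 3-d layout."""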
    if avg_projection.ndim == 3:  # OK
        return avg_projection
    elif avg_projection.ndim == 2:  # Issue #7706
        if groups != 1:
            _warn_old_model()
        return _broadcast_to(
            avg_projection, (groups,) + avg_projection.shape)
    raise ValueError(
        'unexpected shape of avg_projection: {}'.format(
            avg_projection.shape))


def _broadcast_to(array, shape):
    if hasattr(numpy, 'broadcast_to'):
        return numpy.broadcast_to(array, shape)
    else:
        # numpy 1.9 does not provide numpy.broadcast_to
        dummy = numpy.empty(shape)
        bx, _ = numpy.broadcast_arrays(array, dummy)
        return bx