# -*- coding: utf-8 -*-
# @Author  : DevinYang(pistonyang@gmail.com)
__all__ = ['SwitchNorm2d', 'SwitchNorm3d', 'EvoNormB0', 'EvoNormS0']

import torch
from torch import nn
from . import functional as F


class _SwitchNorm(nn.Module):
    """Switchable Normalization (https://arxiv.org/abs/1806.10779).

    Mixes instance, layer and batch statistics with learned softmax
    weights. Avoid feeding 1xCxHxW or NxCx1x1 inputs: the batch or
    instance statistics degenerate for such shapes.
    """
    _version = 2

    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True):
        super(_SwitchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        # Softmax-mixed weights over {instance, layer, batch} statistics.
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        # weight/bias are allocated uninitialized above; init them here.
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def _check_input_dim(self, x):
        raise NotImplementedError

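    # For reference, F.switch_norm (implemented in .functional) is expected,
    # following the SwitchNorm paper, to compute, with w = softmax(mean_weight)
    # and w' = softmax(var_weight) over k in {instance, layer, batch}:
    #
    #   y = weight * (x - sum_k w_k*mu_k) / sqrt(sum_k w'_k*var_k + eps) + bias
    #
    # using running_mean/running_var for the batch statistics at eval time.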
    def forward(self, x):
        self._check_input_dim(x)
        return F.switch_norm(
            x,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.mean_weight,
            self.var_weight,
            self.training,
            self.momentum,
            self.eps)


class SwitchNorm2d(_SwitchNorm):
    def _check_input_dim(self, x):
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))


class SwitchNorm3d(_SwitchNorm):
    def _check_input_dim(self, x):
        if x.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(x.dim()))
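

# A minimal usage sketch (illustrative, not part of the module): the
# SwitchNorm layers are drop-in replacements for BatchNorm2d/3d. Keep
# N > 1 and H*W > 1, per the warning in _SwitchNorm's docstring.
#
#   sn2 = SwitchNorm2d(64)
#   y2 = sn2(torch.randn(8, 64, 32, 32))     # -> (8, 64, 32, 32)
#   sn3 = SwitchNorm3d(64)
#   y3 = sn3(torch.randn(8, 64, 4, 32, 32))  # -> (8, 64, 4, 32, 32)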


class _EvoNorm(nn.Module):
    """EvoNorm-B0/S0 (https://arxiv.org/abs/2004.02967).

    `v` is the extra per-channel parameter used by the EvoNorm
    nonlinearity; the 'b0' variant additionally tracks a running variance.
    """

    def __init__(self, prefix, num_features, eps=1e-5, momentum=0.9, groups=32,
                 affine=True):
        super(_EvoNorm, self).__init__()
        assert prefix in ('s0', 'b0'), "prefix must be 's0' or 'b0'"
        self.prefix = prefix
        self.groups = groups
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            self.register_parameter('v', None)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)
            nn.init.ones_(self.v)

    def _check_input_dim(self, x):
        if x.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(x.dim()))

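    # For reference, per the EvoNorm paper, F.evo_norm is expected to compute
    # (var_batch uses running_var at eval time; var_inst is the per-sample
    # spatial variance; var_group the per-group variance):
    #
    #   b0: y = x / max(sqrt(var_batch + eps), v*x + sqrt(var_inst + eps)) * weight + bias
    #   s0: y = x * sigmoid(v*x) / sqrt(var_group + eps) * weight + bias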
    def forward(self, x):
        self._check_input_dim(x)
        return F.evo_norm(x, self.prefix, self.running_var, self.v,
                          self.weight, self.bias, self.training,
                          self.momentum, self.eps, self.groups)


class EvoNormB0(_EvoNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True):
        super(EvoNormB0, self).__init__('b0', num_features, eps, momentum,
                                        affine=affine)


class EvoNormS0(_EvoNorm):
    def __init__(self, num_features, groups=32, affine=True):
        super(EvoNormS0, self).__init__('s0', num_features, groups=groups,
                                        affine=affine)
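

if __name__ == '__main__':
    # A minimal smoke test, not part of the library API: checks that each
    # layer preserves the input shape. EvoNormS0 uses grouped statistics,
    # so num_features must be divisible by `groups`. Because of the
    # relative import above, run this as a module (package name depends
    # on the repo layout), e.g. `python -m <package>.norm`.
    x = torch.randn(8, 64, 32, 32)
    for layer in (SwitchNorm2d(64), EvoNormB0(64), EvoNormS0(64, groups=32)):
        assert layer(x).shape == x.shape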