import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from .utils import _pair
from .mixed_lipschitz import InducedNormLinear, InducedNormConv2d

__all__ = ['SpectralNormLinear', 'SpectralNormConv2d', 'LopLinear', 'LopConv2d', 'get_linear', 'get_conv2d']


class SpectralNormLinear(nn.Module):
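    """Linear layer with a soft spectral-norm constraint.

    The spectral norm of the weight is estimated with power iteration, and the
    weight is rescaled whenever that estimate exceeds `coeff`, making the layer
    (approximately) `coeff`-Lipschitz w.r.t. the l2 norm.
    """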

    def __init__(
        self, in_features, out_features, bias=True, coeff=0.97, n_iterations=None, atol=None, rtol=None, **unused_kwargs
    ):
        del unused_kwargs
        super(SpectralNormLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.coeff = coeff
        self.n_iterations = n_iterations
        self.atol = atol
        self.rtol = rtol
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

        h, w = self.weight.shape
        self.register_buffer('scale', torch.tensor(0.))
        self.register_buffer('u', F.normalize(self.weight.new_empty(h).normal_(0, 1), dim=0))
        self.register_buffer('v', F.normalize(self.weight.new_empty(w).normal_(0, 1), dim=0))
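        # Warm-start the power iteration with a generous number of steps at construction time.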
        self.compute_weight(True, 200)

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def compute_weight(self, update=True, n_iterations=None, atol=None, rtol=None):
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iterations or (atol, rtol).')

        if n_iterations is None:
            n_iterations = 20000

        u = self.u
        v = self.v
        weight = self.weight
        if update:
            with torch.no_grad():
                itrs_used = 0
                for _ in range(n_iterations):
                    old_v = v.clone()
                    old_u = u.clone()
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
                    u = F.normalize(torch.mv(weight, v), dim=0, out=u)
                    itrs_used = itrs_used + 1
                    if atol is not None and rtol is not None:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight, v))
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        weight = weight / factor
        return weight

    def forward(self, input):
        weight = self.compute_weight(update=self.training)
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, coeff={}, n_iters={}, atol={}, rtol={}'.format(
            self.in_features, self.out_features, self.bias is not None, self.coeff, self.n_iterations, self.atol,
            self.rtol
        )


class SpectralNormConv2d(nn.Module):
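    """Conv2d layer with a soft spectral-norm constraint.

    For 1x1 kernels the weight is treated as a plain matrix; otherwise power
    iteration runs on the convolution operator itself (via conv2d and
    conv_transpose2d). The weight is rescaled whenever the estimated spectral
    norm exceeds `coeff`.
    """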

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, n_iterations=None,
        atol=None, rtol=None, **unused_kwargs
    ):
        del unused_kwargs
        super(SpectralNormConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.coeff = coeff
        self.n_iterations = n_iterations
        self.atol = atol
        self.rtol = rtol
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
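        # The u/v power-iteration buffers are created lazily in _initialize_u_v,
        # once the input spatial size is known (recorded by forward into spatial_dims).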
        self.initialized = False
        self.register_buffer('spatial_dims', torch.tensor([1., 1.]))
        self.register_buffer('scale', torch.tensor(0.))

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _initialize_u_v(self):
        if self.kernel_size == (1, 1):
            self.register_buffer('u', F.normalize(self.weight.new_empty(self.out_channels).normal_(0, 1), dim=0))
            self.register_buffer('v', F.normalize(self.weight.new_empty(self.in_channels).normal_(0, 1), dim=0))
        else:
            c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
            with torch.no_grad():
                num_input_dim = c * h * w
                v = F.normalize(torch.randn(num_input_dim).to(self.weight), dim=0, eps=1e-12)
                # forward call to infer the shape
                u = F.conv2d(v.view(1, c, h, w), self.weight, stride=self.stride, padding=self.padding, bias=None)
                num_output_dim = u.shape[0] * u.shape[1] * u.shape[2] * u.shape[3]
                self.out_shape = u.shape
                # overwrite u with random init
                u = F.normalize(torch.randn(num_output_dim).to(self.weight), dim=0, eps=1e-12)

                self.register_buffer('u', u)
                self.register_buffer('v', v)

    def compute_weight(self, update=True, n_iterations=None):
        if not self.initialized:
            self._initialize_u_v()
            self.initialized = True

        if self.kernel_size == (1, 1):
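            # A 1x1 convolution is just a matrix-vector product per pixel, so the
            # cheaper matrix power iteration on the reshaped weight suffices.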
            return self._compute_weight_1x1(update, n_iterations)
        else:
            return self._compute_weight_kxk(update, n_iterations)

    def _compute_weight_1x1(self, update=True, n_iterations=None, atol=None, rtol=None):
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iterations or (atol, rtol).')

        if n_iterations is None:
            n_iterations = 20000

        u = self.u
        v = self.v
        weight = self.weight.view(self.out_channels, self.in_channels)
        if update:
            with torch.no_grad():
                itrs_used = 0
                for _ in range(n_iterations):
                    old_v = v.clone()
                    old_u = u.clone()
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = F.normalize(torch.mv(weight.t(), u), dim=0, out=v)
                    u = F.normalize(torch.mv(weight, v), dim=0, out=u)
                    itrs_used = itrs_used + 1
                    if atol is not None and rtol is not None:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight, v))
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        weight = weight / factor
        return weight.view(self.out_channels, self.in_channels, 1, 1)

    def _compute_weight_kxk(self, update=True, n_iterations=None, atol=None, rtol=None):
        n_iterations = self.n_iterations if n_iterations is None else n_iterations
        atol = self.atol if atol is None else atol
        rtol = self.rtol if rtol is None else rtol

        if n_iterations is None and (atol is None or rtol is None):
            raise ValueError('Need one of n_iterations or (atol, rtol).')

        if n_iterations is None:
            n_iterations = 20000

        u = self.u
        v = self.v
        weight = self.weight
        c, h, w = self.in_channels, int(self.spatial_dims[0].item()), int(self.spatial_dims[1].item())
        if update:
            with torch.no_grad():
                itrs_used = 0
                for _ in range(n_iterations):
                    old_u = u.clone()
                    old_v = v.clone()
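                    # One power-iteration step on the convolution operator itself:
                    # conv_transpose2d applies the adjoint (W^T) to u, conv2d applies W to v.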
                    v_s = F.conv_transpose2d(
                        u.view(self.out_shape), weight, stride=self.stride, padding=self.padding, output_padding=0
                    )
                    v = F.normalize(v_s.view(-1), dim=0, out=v)
                    u_s = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
                    u = F.normalize(u_s.view(-1), dim=0, out=u)
                    itrs_used = itrs_used + 1
                    if atol is not None and rtol is not None:
                        err_u = torch.norm(u - old_u) / (u.nelement()**0.5)
                        err_v = torch.norm(v - old_v) / (v.nelement()**0.5)
                        tol_u = atol + rtol * torch.max(u)
                        tol_v = atol + rtol * torch.max(v)
                        if err_u < tol_u and err_v < tol_v:
                            break
                if itrs_used > 0:
                    u = u.clone()
                    v = v.clone()

        weight_v = F.conv2d(v.view(1, c, h, w), weight, stride=self.stride, padding=self.padding, bias=None)
        weight_v = weight_v.view(-1)
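        # sigma = u^T (W v); evaluated outside no_grad so gradients flow to the
        # weight through the soft normalization below.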
        sigma = torch.dot(u.view(-1), weight_v)
        with torch.no_grad():
            self.scale.copy_(sigma)
        # soft normalization: only when sigma larger than coeff
        factor = torch.max(torch.ones(1).to(weight.device), sigma / self.coeff)
        weight = weight / factor
        return weight

    def forward(self, input):
        if not self.initialized: self.spatial_dims.copy_(torch.tensor(input.shape[2:4]).to(self.spatial_dims))
        weight = self.compute_weight(update=self.training)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.bias is None:
            s += ', bias=False'
        s += ', coeff={}, n_iters={}, atol={}, rtol={}'.format(self.coeff, self.n_iterations, self.atol, self.rtol)
        return s.format(**self.__dict__)


class LopLinear(nn.Linear):
    """Lipschitz constant defined using operator norms."""

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        coeff=0.97,
        domain=float('inf'),
        codomain=float('inf'),
        local_constraint=True,
        **unused_kwargs,
    ):
        del unused_kwargs
        super(LopLinear, self).__init__(in_features, out_features, bias)
        self.coeff = coeff
        self.domain = domain
        self.codomain = codomain
        self.local_constraint = local_constraint
        max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
        self.max_across_dim = 1 if max_across_input_dims else 0
        self.register_buffer('scale', torch.tensor(0.))

    def compute_weight(self):
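        # Per-unit norm of the weight (over rows or columns depending on
        # domain/codomain); collapsed to a single global value when
        # local_constraint is disabled.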
        scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
        if not self.local_constraint: scale = scale.max()
        with torch.no_grad():
            self.scale.copy_(scale.max())

        # soft normalization
        factor = torch.max(torch.ones(1).to(self.weight), scale / self.coeff)

        return self.weight / factor

    def forward(self, input):
        weight = self.compute_weight()
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        s = super(LopLinear, self).extra_repr()
        return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
            self.coeff, self.domain, self.codomain, self.local_constraint
        )


class LopConv2d(nn.Conv2d):
    """Lipschitz constant defined using operator norms."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias=True,
        coeff=0.97,
        domain=float('inf'),
        codomain=float('inf'),
        local_constraint=True,
        **unused_kwargs,
    ):
        del unused_kwargs
        super(LopConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
        self.coeff = coeff
        self.domain = domain
        self.codomain = codomain
        self.local_constraint = local_constraint
        max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
        self.max_across_dim = 1 if max_across_input_dims else 0
        self.register_buffer('scale', torch.tensor(0.))

    def compute_weight(self):
        scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
        if not self.local_constraint: scale = scale.max()
        with torch.no_grad():
            self.scale.copy_(scale.max())

        # soft normalization
        factor = torch.max(torch.ones(1).to(self.weight.device), scale / self.coeff)

        return self.weight / factor

    def forward(self, input):
        weight = self.compute_weight()
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)

    def extra_repr(self):
        s = super(LopConv2d, self).extra_repr()
        return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
            self.coeff, self.domain, self.codomain, self.local_constraint
        )


class LipNormLinear(nn.Linear):
    """Lipschitz constant defined using operator norms."""

    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        coeff=0.97,
        domain=float('inf'),
        codomain=float('inf'),
        local_constraint=True,
        **unused_kwargs,
    ):
        del unused_kwargs
        super(LipNormLinear, self).__init__(in_features, out_features, bias)
        self.coeff = coeff
        self.domain = domain
        self.codomain = codomain
        self.local_constraint = local_constraint
        max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
        self.max_across_dim = 1 if max_across_input_dims else 0

        # Initialize scale so that sigmoid(scale) * coeff approximately recovers the initial operator norm.
        with torch.no_grad():
            w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
            if not self.local_constraint: w_scale = w_scale.max()
            self.scale = nn.Parameter(_logit(w_scale / self.coeff))

    def compute_weight(self):
        w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
        if not self.local_constraint: w_scale = w_scale.max()
        return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff

    def forward(self, input):
        weight = self.compute_weight()
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        s = super(LipNormLinear, self).extra_repr()
        return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
            self.coeff, self.domain, self.codomain, self.local_constraint
        )


class LipNormConv2d(nn.Conv2d):
    """Lipschitz constant defined using operator norms."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias=True,
        coeff=0.97,
        domain=float('inf'),
        codomain=float('inf'),
        local_constraint=True,
        **unused_kwargs,
    ):
        del unused_kwargs
        super(LipNormConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, bias)
        self.coeff = coeff
        self.domain = domain
        self.codomain = codomain
        self.local_constraint = local_constraint
        max_across_input_dims, self.norm_type = operator_norm_settings(self.domain, self.codomain)
        self.max_across_dim = 1 if max_across_input_dims else 0

        # Initialize scale so that sigmoid(scale) * coeff approximately recovers the initial operator norm.
        with torch.no_grad():
            w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
            if not self.local_constraint: w_scale = w_scale.max()
            self.scale = nn.Parameter(_logit(w_scale / self.coeff))

    def compute_weight(self):
        w_scale = _norm_except_dim(self.weight, self.norm_type, dim=self.max_across_dim)
        if not self.local_constraint: w_scale = w_scale.max()
        return self.weight / w_scale * torch.sigmoid(self.scale) * self.coeff

    def forward(self, input):
        weight = self.compute_weight()
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, 1, 1)

    def extra_repr(self):
        s = super(LipNormConv2d, self).extra_repr()
        return s + ', coeff={}, domain={}, codomain={}, local={}'.format(
            self.coeff, self.domain, self.codomain, self.local_constraint
        )


def _logit(p):
    # Clamp p away from 0 and 1 so the logit stays well-behaved.
    p = torch.max(torch.ones(1) * 0.1, torch.min(torch.ones(1) * 0.9, p))
    return torch.log(p + 1e-10) - torch.log(1 - p + 1e-10)


def _norm_except_dim(w, norm_type, dim):
    if norm_type == 1 or norm_type == 2:
        return torch.norm_except_dim(w, norm_type, dim)
    elif norm_type == float('inf'):
        # The l-inf norm is the maximum absolute entry.
        return _max_except_dim(w.abs(), dim)
    else:
        raise ValueError('Unknown norm type {}'.format(norm_type))


def _max_except_dim(input, dim):
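    """Max-reduce `input` over every dimension except `dim`, keeping dims for broadcasting."""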
    maxed = input
    for axis in range(input.ndimension() - 1, dim, -1):
        maxed, _ = maxed.max(axis, keepdim=True)
    for axis in range(dim - 1, -1, -1):
        maxed, _ = maxed.max(axis, keepdim=True)
    return maxed


def operator_norm_settings(domain, codomain):
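    """Map a (domain, codomain) pair to how the induced operator norm is computed.

    Returns `(max_across_input_dims, norm_type)`: which vector norm to take over
    the weight's rows/columns, and whether the maximum is taken across input
    dimensions (columns) or output dimensions (rows).
    """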
    if domain == 1 and codomain == 1:
        # maximum l1-norm of column
        max_across_input_dims = True
        norm_type = 1
    elif domain == 1 and codomain == 2:
        # maximum l2-norm of column
        max_across_input_dims = True
        norm_type = 2
    elif domain == 1 and codomain == float("inf"):
        # maximum l-inf norm of column
        max_across_input_dims = True
        norm_type = float("inf")
    elif domain == 2 and codomain == float("inf"):
        # maximum l2-norm of row
        max_across_input_dims = False
        norm_type = 2
    elif domain == float("inf") and codomain == float("inf"):
        # maximum l1-norm of row
        max_across_input_dims = False
        norm_type = 1
    else:
        raise ValueError('Unknown combination of domain "{}" and codomain "{}"'.format(domain, codomain))

    return max_across_input_dims, norm_type


def get_linear(in_features, out_features, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs):
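    """Build a Lipschitz-constrained linear layer for the given domain/codomain norms.

    Combinations whose operator norm has a closed form dispatch to LopLinear;
    everything else falls back to InducedNormLinear. Illustrative example
    (arbitrary sizes):

        fc = get_linear(64, 64, domain=float('inf'), codomain=float('inf'))
    """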
    _linear = InducedNormLinear
    if domain == 1:
        if codomain in [1, 2, float('inf')]:
            _linear = LopLinear
    elif codomain == float('inf'):
        if domain in [2, float('inf')]:
            _linear = LopLinear
    return _linear(in_features, out_features, bias, coeff, domain, codomain, **kwargs)


def get_conv2d(
    in_channels, out_channels, kernel_size, stride, padding, bias=True, coeff=0.97, domain=None, codomain=None, **kwargs
):
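    """Build a Lipschitz-constrained conv layer for the given domain/codomain norms.

    Dispatches to LopConv2d when the operator norm has a closed form and to
    InducedNormConv2d otherwise. Illustrative example (arbitrary sizes):

        conv = get_conv2d(16, 32, 3, 1, 1, domain=float('inf'), codomain=float('inf'))
    """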
    _conv2d = InducedNormConv2d
    if domain == 1:
        if codomain in [1, 2, float('inf')]:
            _conv2d = LopConv2d
    elif codomain == float('inf'):
        if domain in [2, float('inf')]:
            _conv2d = LopConv2d
    return _conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, coeff, domain, codomain, **kwargs)