# Implementation based on original paper:
# https://github.com/pfnet-research/sngan_projection

import torch
import torch.nn.functional as F
from torch import nn


def l2normalize(v, eps=1e-8):
    # Normalize a vector to unit L2 norm; eps guards against division by zero.
    return v / (v.norm() + eps)


def sn_weight(weight, u, height, n_power_iterations):
    """Estimate the spectral norm of `weight` via power iteration and return
    the spectrally normalized weight together with the updated vector `u`."""
    w_mat = weight.view(height, -1)
    # Power iteration: refine u (and v) without tracking gradients, so that
    # only the final sigma contributes to the backward pass through `weight`.
    with torch.no_grad():
        for _ in range(n_power_iterations):
            v = l2normalize(torch.mv(w_mat.t(), u))
            u = l2normalize(torch.mv(w_mat, v))

    # sigma ~= u^T W v, the largest singular value of the flattened weight.
    sigma = u.dot(w_mat.mv(v))
    return weight / sigma, u
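

# Minimal sanity-check sketch (an assumption for illustration, not part of the
# original sngan_projection code): with enough power iterations, the implicit
# sigma inside sn_weight should approach the true largest singular value of
# the flattened weight matrix.
def _check_sigma_estimate(weight, n_power_iterations=50):
    height = weight.shape[0]
    with torch.no_grad():
        u = l2normalize(weight.new_empty(height).normal_(0, 1))
        w_sn, _ = sn_weight(weight, u, height, n_power_iterations)
        # Recover sigma from the norm ratio, since w_sn = weight / sigma.
        estimate = weight.view(height, -1).norm() / w_sn.view(height, -1).norm()
        exact = torch.linalg.svdvals(weight.view(height, -1))[0]
    return estimate, exact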


class SNConv2d(nn.Conv2d):
    """nn.Conv2d whose weight is spectrally normalized at every forward pass."""

    def __init__(self, *args, n_power_iterations=1, **kwargs):
        super(SNConv2d, self).__init__(*args, **kwargs)
        self.n_power_iterations = n_power_iterations
        # Number of rows of the flattened weight matrix (= out_channels).
        self.height = self.weight.shape[0]
        # Persistent estimate of the leading left singular vector.
        self.register_buffer(
            'u', l2normalize(self.weight.new_empty(self.height).normal_(0, 1)))

    def forward(self, input):
        # Normalize the weight and keep the refined singular-vector estimate.
        w_sn, self.u = sn_weight(self.weight, self.u, self.height,
                                 self.n_power_iterations)
        return F.conv2d(input, w_sn, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class SNLinear(nn.Linear):
    """nn.Linear whose weight is spectrally normalized at every forward pass."""

    def __init__(self, *args, n_power_iterations=1, **kwargs):
        super(SNLinear, self).__init__(*args, **kwargs)
        self.n_power_iterations = n_power_iterations
        # Number of rows of the weight matrix (= out_features).
        self.height = self.weight.shape[0]
        # Persistent estimate of the leading left singular vector.
        self.register_buffer(
            'u', l2normalize(self.weight.new_empty(self.height).normal_(0, 1)))

    def forward(self, input):
        # Normalize the weight and keep the refined singular-vector estimate.
        w_sn, self.u = sn_weight(
            self.weight, self.u, self.height, self.n_power_iterations)
        return F.linear(input, w_sn, self.bias)
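

if __name__ == '__main__':
    # Minimal usage sketch (shapes and layer sizes are illustrative assumptions,
    # not from the original repo): the SN layers are drop-in replacements for
    # nn.Conv2d / nn.Linear, e.g. inside a GAN discriminator.
    conv = SNConv2d(3, 64, kernel_size=3, padding=1)
    fc = SNLinear(64, 1)

    x = torch.randn(4, 3, 32, 32)
    h = conv(x)                       # (4, 64, 32, 32)
    out = fc(h.mean(dim=(2, 3)))      # global average pool -> (4, 1)
    print(out.shape)

    # The normalized weight used in forward should have spectral norm close to 1
    # (the estimate improves as u is refined over successive forward calls).
    with torch.no_grad():
        w_sn, _ = sn_weight(conv.weight, conv.u, conv.height, 1)
        print(torch.linalg.matrix_norm(w_sn.view(conv.height, -1), ord=2))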