import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.autograd import Function
import bn_lib

class BN2dFunc(Function):
    def __init__(self, running_mean, running_var, training, momentum, eps):
        self.running_mean = running_mean
        self.running_var = running_var
        self.training = training
        self.momentum = momentum
        self.eps = eps

    def forward(self, input, weight, bias):
        nB = input.size(0)
        nC = input.size(1)
        nH = input.size(2)
        nW = input.size(3)

        output = input.new(nB, nC, nH, nW) 
        self.input = input
        self.weight = weight
        self.bias = bias
        self.x = input.new(nB, nC, nH, nW) 
        self.x_norm = input.new(nB, nC, nH, nW) 
        self.mean = input.new(nB, nC) 
        self.var = input.new(nB, nC) 

        if input.is_cuda:
            bn_lib.bn_forward_gpu(input, self.x, self.x_norm, self.mean, self.running_mean, self.var, self.running_var, weight, bias, self.training, output)
        else:
            bn_lib.bn_forward(input, self.x, self.x_norm, self.mean, self.running_mean, self.var, self.running_var, weight, bias, self.training, output)
        return output

    def backward(self, grad_output):
        nB = grad_output.size(0)
        nC = grad_output.size(1)
        nH = grad_output.size(2)
        nW = grad_output.size(3)
        grad_input = grad_output.new(nB, nC, nH, nW) 
        grad_mean = grad_output.new(nC) 
        grad_var = grad_output.new(nC) 
        grad_weight = grad_output.new(nC) 
        grad_bias = grad_output.new(nC) 
        
        if grad_output.is_cuda:
            bn_lib.bn_backward_gpu(grad_output, self.input, self.x_norm, self.mean, grad_mean, self.var, grad_var, self.weight, grad_weight, self.bias, grad_bias, self.training, grad_input)
        else:
            bn_lib.bn_backward(grad_output, self.input, self.x_norm, self.mean, grad_mean, self.var, grad_var, self.weight, grad_weight, self.bias, grad_bias, self.training, grad_input)
        
        return grad_input, grad_weight, grad_bias  

class BN2d(nn.Module):
    def __init__(self, num_features, momentum=0.01, eps=1e-5):
        super(BN2d, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))
        self.momentum = momentum
        self.eps = eps

        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, input):
        #print('------------ BN2d input -------------')
        #print(input.data.storage()[0:10])
        return BN2dFunc(self.running_mean, self.running_var, self.training, self.momentum, self.eps)(input, self.weight, self.bias)

class BN2d_slow(nn.Module):
    def __init__(self, num_features, momentum=0.01):
        super(BN2d_slow, self).__init__()
        self.num_features = num_features
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))
        self.eps = 1e-5
        self.momentum = momentum

        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.weight.data.uniform_()
        self.bias.data.zero_()
    def forward(self, x): 
        nB = x.data.size(0)
        nC = x.data.size(1)
        nH = x.data.size(2)
        nW = x.data.size(3)
        samples = nB*nH*nW
        y = x.view(nB, nC, nH*nW).transpose(1,2).contiguous().view(-1,nC)
        if self.training:
            print('forward in training mode on autograd')
            m = Variable(y.mean(0).data, requires_grad=False)
            v = Variable(y.var(0).data, requires_grad=False)
            self.running_mean = (1-self.momentum)*self.running_mean + self.momentum * m.data.view(-1)
            self.running_var = (1-self.momentum)*self.running_var + self.momentum * v.data.view(-1)
            m = m.repeat(samples, 1)
            v = v.repeat(samples, 1)*(samples-1.0)/samples
        else:
            m = Variable(self.running_mean.repeat(samples, 1), requires_grad=False)
            v = Variable(self.running_var.repeat(samples, 1), requires_grad=False)
        w = self.weight.repeat(samples, 1)
        b = self.bias.repeat(samples, 1)
        y = (y - m)/(v+self.eps).sqrt() * w + b 
        y = y.view(nB, nH*nW, nC).transpose(1,2).contiguous().view(nB, nC, nH, nW) 
        return y


if __name__ == '__main__':
    nB = 64
    nC = 3
    nH = 4
    nW = 4
    samples = nB*nH*nW
    a = torch.rand(nB,nC,nH,nW)
    a = Variable(a)
    nn_model  = nn.BatchNorm2d(nC)
    dkn_model = BN2d(nC)
    atg_model = BN2d_slow(nC)

    nn_model.weight.data.fill_(1.0)
    nn_model.bias.data.zero_()
    dkn_model.weight.data.fill_(1.0)
    dkn_model.bias.data.zero_()
    atg_model.weight.data.fill_(1.0)
    atg_model.bias.data.zero_()
    nn_out_cpu = nn_model(a)
    dkn_out_cpu = dkn_model(a)
    atg_out_cpu = atg_model(a)



    a = a.cuda()
    nn_model.cuda()
    dkn_model.cuda()
    atg_model.cuda()

    nn_out_gpu = nn_model(a)
    dkn_out_gpu = dkn_model(a)
    atg_out_gpu = atg_model(a)

    print('--- nn cpu out ---')
    print(nn_out_cpu.data.storage()[0:10])
    print('--- dkn cpu out ---')
    print(dkn_out_cpu.data.storage()[0:10])
    print('--- atg cpu out ---')
    print(atg_out_cpu.data.storage()[0:10])


    print('--- nn gpu out ---')
    print(nn_out_gpu.data.storage()[0:10])
    print('--- dkn gpu out ---')
    print(dkn_out_gpu.data.storage()[0:10])
    print('--- atg gpu out ---')
    print(atg_out_gpu.data.storage()[0:10])