import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair

from .distributions import VariationalPosterior, Prior



class _ConvNd(nn.Module):
    """Base class for Bayesian convolutions: each weight (and optional bias)
    carries a Gaussian variational posterior alongside a fixed prior."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
                 dilation, transposed, output_padding, groups, use_bias, args):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.use_bias = use_bias
        # Hyperparameters carried on args: sig1, sig2 and pi parameterise the
        # prior; rho sets the initial value of the posterior scale parameter.
        self.sig1 = args.sig1
        self.sig2 = args.sig2
        self.pi = args.pi
        self.rho = args.rho
        self.device = args.device


        # Weight posterior parameters: mean mu and scale parameter rho, both
        # initialised with small Gaussian noise directly on the target device.
        if transposed:
            weight_shape = (in_channels, out_channels // groups, *kernel_size)
        else:
            weight_shape = (out_channels, in_channels // groups, *kernel_size)

        self.weight_mu = nn.Parameter(torch.empty(weight_shape, device=self.device,
                                                  dtype=torch.float32).normal_(0., 0.1))
        self.weight_rho = nn.Parameter(self.rho + torch.empty(weight_shape, device=self.device,
                                                              dtype=torch.float32).normal_(0., 0.1))

        self.weight = VariationalPosterior(self.weight_mu, self.weight_rho, self.device).to(self.device)

        
        # Bias parameters, shape [out_channels]
        if self.use_bias:
            self.bias_mu = nn.Parameter(torch.empty(self.out_channels, device=self.device,
                                                    dtype=torch.float32).normal_(0., 0.1))
            self.bias_rho = nn.Parameter(self.rho + torch.empty(self.out_channels, device=self.device,
                                                                dtype=torch.float32).normal_(0., 0.1))

            self.bias = VariationalPosterior(self.bias_mu, self.bias_rho, self.device).to(self.device)
        else:
            self.register_parameter('bias', None)
        
        # Prior distributions
        self.weight_prior = Prior(args).to(self.device)
        if self.use_bias:
            self.bias_prior = Prior(args).to(self.device)

        # Log-density accumulators, filled in during forward()
        self.log_prior = 0
        self.log_variational_posterior = 0

        # Set by prune_module(); makes forward() use the pruned means
        self.mask_flag = False
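
# The classes above assume only the following interface from `.distributions`
# (a sketch of the contract as used here, not the actual implementation):
#
#   class VariationalPosterior(nn.Module):
#       def __init__(self, mu, rho, device): ...  # typically sigma = log1p(exp(rho))
#       mu: nn.Parameter                          # posterior mean, used when not sampling
#       def sample(self): ...                     # reparameterised draw mu + sigma * eps
#       def log_prob(self, value): ...            # summed log-density q(value)
#
#   class Prior(nn.Module):
#       def __init__(self, args): ...             # e.g. scale mixture built from sig1, sig2, pi
#       def log_prob(self, value): ...            # summed log-density under the prior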


class BayesianConv2D(_ConvNd):
    """2D Bayesian convolution: samples its kernel from the variational
    posterior during training and records the log-probability terms
    needed for the KL part of the loss."""

    def __init__(self, in_channels, out_channels, kernel_size, args, stride=1,
                 padding=0, dilation=1, groups=1, use_bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(BayesianConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, use_bias, args)
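
    # How a training loop typically consumes these layers (a sketch, assuming a
    # Bayes-by-backprop style objective; `model`, `num_batches` and `nll` are
    # placeholders): sum the per-layer log terms into the complexity cost,
    #
    #   kl = sum(m.log_variational_posterior - m.log_prior
    #            for m in model.modules() if isinstance(m, BayesianConv2D))
    #   loss = kl / num_batches + nll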



    def prune_module(self, mask):
        # Zero the posterior means selected by the binary mask (in place);
        # forward() then rebuilds the posterior around the pruned means.
        self.mask_flag = True
        self.pruned_weight_mu = self.weight_mu.data.mul_(mask)
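
    # Example mask construction (a sketch; the 50% magnitude threshold is an
    # illustrative assumption, any {0, 1} tensor of the weight shape works):
    #
    #   with torch.no_grad():
    #       thresh = layer.weight_mu.abs().quantile(0.5)
    #       layer.prune_module((layer.weight_mu.abs() > thresh).float())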


    def forward(self, input, sample=False, calculate_log_probs=False):
        if self.mask_flag:
            # Rebuild the posterior around the pruned means
            self.weight = VariationalPosterior(self.pruned_weight_mu, self.weight_rho, self.device)

        if self.training or sample:
            # Stochastic pass: draw weights from the variational posterior
            weight = self.weight.sample()
            bias = self.bias.sample() if self.use_bias else None
        else:
            # Deterministic pass: use the posterior means
            weight = self.weight.mu
            bias = self.bias.mu if self.use_bias else None

        if self.training or calculate_log_probs:
            # The difference of these two terms is this layer's contribution
            # to the KL part of the variational objective
            if self.use_bias:
                self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
                self.log_variational_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)
            else:
                self.log_prior = self.weight_prior.log_prob(weight)
                self.log_variational_posterior = self.weight.log_prob(weight)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
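
# Minimal usage sketch. The hyperparameter values are illustrative assumptions;
# only the field names (sig1, sig2, pi, rho, device) are fixed by _ConvNd.
# Shown as comments since this module uses a relative import and is not meant
# to be run directly:
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(sig1=1.0, sig2=0.1, pi=0.5, rho=-3.0, device='cpu')
#   layer = BayesianConv2D(3, 16, kernel_size=3, args=args, padding=1)
#   out = layer(torch.randn(8, 3, 32, 32), sample=True)      # stochastic forward
#   kl = layer.log_variational_posterior - layer.log_prior   # layer's KL term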