import torch
import torch.nn as nn
import math
import pickle
from torch.autograd import Variable

def loss_MSE(x, y, size_average=False):
    """Squared-error loss between two tensors.

    Args:
        x: Prediction tensor; its first dimension is used as the
            normalizer when ``size_average`` is false.
        y: Target tensor, subtracted elementwise from ``x``.
        size_average: If true, return the mean of the squared differences
            over all elements. Otherwise return their sum divided by
            ``2 * x.size(0)``.

    Returns:
        A scalar tensor holding the loss value.
    """
    diff = x - y
    squared = torch.mul(diff, diff)
    if not size_average:
        return squared.sum() / (x.size(0) * 2)
    return squared.mean()
    
def loss_Textures(x, y, nc=3, alpha=1.2, margin=0):
    """Hinge-style loss comparing per-group squared magnitudes of x and y.

    Both inputs are viewed as ``(size(0), -1, nc, size(2), size(3))``, i.e.
    the channel dimension is split into consecutive groups of ``nc``
    channels, and the squares are summed inside each group. The loss is
    ``mean(relu(alpha * y_energy - x_energy + margin))``.

    Args:
        x: 4-D tensor whose channel count is a multiple of ``nc``.
        y: 4-D tensor whose channel count is a multiple of ``nc``.
        nc: Number of channels per group.
        alpha: Multiplier applied to the grouped energy of ``y``.
        margin: Constant added before the ReLU.

    Returns:
        A scalar tensor holding the mean hinge value.
    """
    def _group_energy(t):
        # Split channels into groups of ``nc`` and sum squares per group.
        grouped = t.contiguous().view(t.size(0), -1, nc, t.size(2), t.size(3))
        return (grouped * grouped).sum(dim=2)

    x_energy = _group_energy(x)
    y_energy = _group_energy(y)

    hinge = nn.functional.relu(y_energy * alpha - x_energy + margin)
    return hinge.mean()
  
  
class WaveletTransform(nn.Module):
    """Fixed-weight strided (de)convolution whose kernels come from a pickle.

    With ``dec=True`` the layer is a grouped ``Conv2d`` mapping 3 channels to
    ``3 * ks * ks`` channels with kernel size and stride ``ks = 2 ** scale``.
    With ``dec=False`` it is the matching grouped ``ConvTranspose2d`` mapping
    ``3 * ks * ks`` channels back to 3. In both cases the weights are read
    from the ``'rec<ks>'`` entry of the pickled dict and frozen.

    Args:
        scale: Exponent of the kernel size / stride (``ks = 2 ** scale``).
        dec: Build the forward ``Conv2d`` if true, else the
            ``ConvTranspose2d``.
        params_path: Path of the pickle file holding a dict of numpy weight
            arrays keyed by ``'rec<ks>'``.
        transpose: If true, reorder channels between the "3 groups of k"
            layout used by the grouped convolution and the "k groups of 3"
            layout (after the conv when ``dec`` is true, before it
            otherwise).
    """

    def __init__(self, scale=1, dec=True, params_path='wavelet_weights_c2.pkl', transpose=True):
        super(WaveletTransform, self).__init__()

        self.scale = scale
        self.dec = dec
        self.transpose = transpose

        ks = int(math.pow(2, self.scale))
        nc = 3 * ks * ks

        if dec:
            self.conv = nn.Conv2d(in_channels=3, out_channels=nc, kernel_size=ks, stride=ks,
                                  padding=0, groups=3, bias=False)
        else:
            self.conv = nn.ConvTranspose2d(in_channels=nc, out_channels=3, kernel_size=ks, stride=ks,
                                           padding=0, groups=3, bias=False)

        # SECURITY NOTE: pickle.load can execute arbitrary code; only point
        # ``params_path`` at a trusted weights file.
        # NOTE(review): if the weights file was pickled under Python 2, loading
        # it under Python 3 may need pickle.load(fh, encoding='latin1') - confirm.
        # ``open`` (instead of the Python-2-only ``file`` builtin) inside a
        # ``with`` block also guarantees the handle is closed on error.
        with open(params_path, 'rb') as fh:
            dct = pickle.load(fh)

        # self.conv is the only (transposed) convolution in this module, so the
        # weights are assigned to it directly and excluded from training.
        self.conv.weight.data = torch.from_numpy(dct['rec%d' % ks])
        self.conv.weight.requires_grad = False

    def forward(self, x):
        """Apply the fixed convolution (``dec``) or transposed convolution."""
        if self.dec:
            output = self.conv(x)
            if self.transpose:
                # (B, 3, k, H, W) -> (B, k, 3, H, W), flattened back to osz.
                osz = output.size()
                output = output.view(osz[0], 3, -1, osz[2], osz[3]).transpose(1, 2).contiguous().view(osz)
            return output

        # Default to the untouched input so the transpose=False path no longer
        # hits an undefined variable.
        xx = x
        if self.transpose:
            # (B, k, 3, H, W) -> (B, 3, k, H, W), undoing the dec-side reorder.
            xsz = x.size()
            xx = x.view(xsz[0], -1, 3, xsz[2], xsz[3]).transpose(1, 2).contiguous().view(xsz)
        return self.conv(xx)

class _Residual_Block(nn.Module):
    """Two 3x3 conv layers with a skip connection added before the second BN.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x)))) + skip(x)))``
    where ``skip`` is a bias-free 1x1 convolution when the channel count
    changes and the identity otherwise.

    Args:
        inc: Number of input channels.
        outc: Number of output channels.
        groups: ``groups`` argument of both 3x3 convolutions (the 1x1 skip
            convolution always uses ``groups=1``).
    """

    def __init__(self, inc=64, outc=64, groups=1):
        super(_Residual_Block, self).__init__()

        # Compare by value: ``is not`` on ints is an identity test and reports
        # equal channel counts as different when they are distinct objects
        # (e.g. computed values outside CPython's small-int cache).
        if inc != outc:
            self.conv_expand = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1,
                                         padding=0, groups=1, bias=False)
        else:
            self.conv_expand = None

        self.conv1 = nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=3, stride=1,
                               padding=1, groups=groups, bias=False)
        self.bn1 = nn.BatchNorm2d(outc)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, stride=1,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the block output for input ``x``."""
        if self.conv_expand is not None:
            identity_data = self.conv_expand(x)
        else:
            identity_data = x

        output = self.relu1(self.bn1(self.conv1(x)))
        output = self.conv2(output)
        # The skip path is added before the second batch norm and ReLU.
        output = self.relu2(self.bn2(torch.add(output, identity_data)))
        return output

def make_layer(block, num_of_layer, inc=64, outc=64, groups=1):
    """Stack ``block`` instances into an ``nn.Sequential``.

    The first block maps ``inc`` to ``outc`` channels; every following block
    maps ``outc`` to ``outc``. One block is always created, even when
    ``num_of_layer`` is less than 1.

    Args:
        block: Callable accepting ``inc``, ``outc`` and ``groups`` keyword
            arguments and returning a module.
        num_of_layer: Requested number of blocks in the stack.
        inc: Input channels of the first block.
        outc: Output channels of every block.
        groups: Forwarded to every block.

    Returns:
        An ``nn.Sequential`` holding the created blocks in order.
    """
    stack = [block(inc=inc, outc=outc, groups=groups)]
    stack.extend(
        block(inc=outc, outc=outc, groups=groups)
        for _ in range(num_of_layer - 1)
    )
    return nn.Sequential(*stack)

class _Interim_Block(nn.Module):
    """Residual block whose skip path always goes through a 1x1 convolution.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x)))) + conv_expand(x)))``.
    ``conv_expand`` and ``conv1`` use ``groups=1``; only ``conv2`` uses the
    ``groups`` argument.

    Args:
        inc: Number of input channels.
        outc: Number of output channels.
        groups: ``groups`` argument of the second 3x3 convolution.
    """

    def __init__(self, inc=64, outc=64, groups=1):
        super(_Interim_Block, self).__init__()

        # Skip path: bias-free 1x1 projection from inc to outc channels.
        self.conv_expand = nn.Conv2d(
            in_channels=inc, out_channels=outc,
            kernel_size=1, stride=1, padding=0, groups=1, bias=False,
        )
        # Main path: ungrouped 3x3 conv, then a (possibly grouped) 3x3 conv.
        self.conv1 = nn.Conv2d(
            in_channels=inc, out_channels=outc,
            kernel_size=3, stride=1, padding=1, groups=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(outc)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            in_channels=outc, out_channels=outc,
            kernel_size=3, stride=1, padding=1, groups=groups, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the block output for input ``x``."""
        shortcut = self.conv_expand(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu1(main)
        main = self.conv2(main)

        # Merge the two paths before the final normalization and activation.
        merged = torch.add(main, shortcut)
        return self.relu2(self.bn2(merged))
        
class NetSR(nn.Module):
    """Network with a shared residual trunk and one prediction branch per level.

    The input (3 channels) passes through an input conv/BN/ReLU and a stack
    of residual blocks widening from 64 to 1024 channels. For every level
    ``i`` in ``0 .. min(scale, 4)`` a branch ``interim_i -> wavelet_i ->
    predict_i`` with group count ``g = _LEVEL_GROUPS[i]`` produces ``3 * g``
    channels from the trunk features; branch outputs are concatenated along
    the channel dimension in level order.

    Args:
        scale: Highest branch level to build (converted with ``int``). Levels
            above 4 are not defined, so larger values build the same
            branches as ``scale=4``.
        num_layers_res: Number of residual blocks in each trunk stage.

    Raises:
        ValueError: If ``scale`` is negative (no branch would exist and
            ``forward`` would have no output to return).
    """

    # Group count of the wavelet branch at each level (index = level).
    _LEVEL_GROUPS = (1, 3, 12, 48, 192)

    def __init__(self, scale=2, num_layers_res=2):
        super(NetSR, self).__init__()

        self.scale = int(scale)
        if self.scale < 0:
            raise ValueError('scale must be >= 0, got %r' % (scale,))
        self.groups = int(math.pow(4, self.scale))
        self.wavelet_c = wavelet_c = 32

        #----------input conv-------------------
        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(64)
        self.relu_input = nn.ReLU(inplace=True)

        #----------residual-------------------
        self.residual = nn.Sequential(
            make_layer(_Residual_Block, num_layers_res, inc=64, outc=64),
            make_layer(_Residual_Block, num_layers_res, inc=64, outc=128),
            make_layer(_Residual_Block, num_layers_res, inc=128, outc=256),
            make_layer(_Residual_Block, num_layers_res, inc=256, outc=512),
            make_layer(_Residual_Block, num_layers_res, inc=512, outc=1024)
        )

        #----------wavelet conv-------------------
        inc = 1024
        layer_num = 1
        # Branches are registered as interim_<i>/wavelet_<i>/predict_<i>, in
        # level order, so parameter names and ordering match checkpoints
        # produced by the hand-unrolled version of this constructor.
        for level in range(self._num_levels()):
            g = self._LEVEL_GROUPS[level]
            setattr(self, 'interim_%d' % level, _Interim_Block(inc, wavelet_c * g, g))
            setattr(self, 'wavelet_%d' % level,
                    make_layer(_Residual_Block, layer_num, wavelet_c * g, wavelet_c * 2 * g, g))
            setattr(self, 'predict_%d' % level,
                    nn.Conv2d(in_channels=wavelet_c * 2 * g, out_channels=3 * g, kernel_size=3,
                              stride=1, padding=1, groups=g, bias=True))

        # Conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels))); BN weight 1;
        # all biases 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _num_levels(self):
        """Return how many branches exist for the configured ``scale``."""
        # Derived from self.scale on every call (rather than cached on the
        # instance) so instances restored without running __init__ still work.
        return min(self.scale, len(self._LEVEL_GROUPS) - 1) + 1

    def forward(self, x):
        """Return the channel-wise concatenation of all branch predictions."""
        f = self.relu_input(self.bn_input(self.conv_input(x)))
        f = self.residual(f)

        bands = []
        for level in range(self._num_levels()):
            band = getattr(self, 'interim_%d' % level)(f)
            band = getattr(self, 'wavelet_%d' % level)(band)
            bands.append(getattr(self, 'predict_%d' % level)(band))

        # A single branch is returned as-is; several are joined on dim 1.
        if len(bands) == 1:
            return bands[0]
        return torch.cat(bands, 1)