import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from torch.utils import model_zoo
from torchvision.models import vgg

from .models import register_model

def get_upsample_filter(size):
    """Make a ``size x size`` 2-D bilinear kernel suitable for upsampling.

    The kernel is the outer product of a 1-D triangular profile with itself,
    peaking at the kernel centre.

    Args:
        size: kernel width/height (this file passes 2 * upsampling factor).

    Returns:
        A float32 tensor of shape (size, size).
    """
    half = (size + 1) // 2
    center = half - 1 if size % 2 == 1 else half - 0.5
    profile = 1 - np.abs(np.arange(size) - center) / half
    return torch.from_numpy(np.outer(profile, profile)).float()


class Bilinear(nn.Module):
    """Fixed bilinear upsampling by an integer ``factor``.

    Implemented as a transposed convolution with stride ``factor`` and a
    ``2*factor x 2*factor`` bilinear kernel placed only on the diagonal of the
    (out, in) channel matrix, so each channel is upsampled independently.
    The kernel lives in the buffer ``w`` rather than a parameter, so it does
    not appear in ``parameters()``.
    """

    def __init__(self, factor, num_channels):
        super().__init__()
        self.factor = factor
        ksize = 2 * factor
        kernel = get_upsample_filter(ksize)
        weight = torch.zeros(num_channels, num_channels, ksize, ksize)
        channel_idx = torch.arange(num_channels)
        # Diagonal (c -> c) slots receive the kernel; cross-channel slots stay zero.
        weight[channel_idx, channel_idx] = kernel
        self.register_buffer('w', weight)

    def forward(self, x):
        """Upsample ``x``; each spatial size becomes (in - 1) * factor + 2 * factor."""
        return F.conv_transpose2d(x, Variable(self.w), stride=self.factor)


@register_model('fcn8s')
class VGG16_FCN8s(nn.Module):
    """FCN-8s segmentation network on a VGG16 backbone.

    ``vgg`` is the VGG16 convolutional trunk (config 'D', ceil-mode pooling).
    ``vgg_head`` is VGG16's fully-connected classifier expressed as
    convolutions, ending in ``num_cls`` channels. The coarse scores are refined
    with skip connections from pool4 and pool3. A final 8x bilinear upsampling
    is then cropped back to the input size.

    Args:
        num_cls: number of output classes (channels of the score map).
        pretrained: load initial weights (see ``weights_init``).
        weights_init: path to a state dict with ``vgg.*`` / ``vgg_head.*``
            entries. When ``None`` (and ``pretrained`` is true), torchvision's
            VGG16 weights are loaded instead.
        output_last_ft: if true, ``forward`` also returns the ``vgg_head``
            features feeding its last 1x1 conv.
    """

    # Per-image preprocessing: ToTensor, then per-channel mean/std normalization.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
        ])

    def __init__(self, num_cls=19, pretrained=True, weights_init=None,
                 output_last_ft=False):
        super().__init__()
        self.output_last_ft = output_last_ft
        # Newer torchvision renamed ``vgg.cfg`` to ``vgg.cfgs``; accept either.
        vgg_cfgs = getattr(vgg, 'cfgs', None) or vgg.cfg
        self.vgg = make_layers(vgg_cfgs['D'])
        self.vgg_head = nn.Sequential(
            nn.Conv2d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, num_cls, 1)
            )
        # Bilinear has no parameters, so one instance can serve both 2x stages.
        self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls)
        self.upscore8 = Bilinear(8, num_cls)
        self.score_pool4 = nn.Conv2d(512, num_cls, 1)
        self.score_pool3 = nn.Conv2d(256, num_cls, 1)
        # Zero the skip scorers so that, at initialization, the skip branches
        # add nothing and the output depends only on vgg_head's coarse scores.
        for layer in (self.score_pool4, self.score_pool3):
            for param in layer.parameters():
                param.data.zero_()

        if pretrained:
            if weights_init is not None:
                # Load onto CPU so GPU-saved checkpoints also work on CPU-only
                # hosts; load_state_dict copies the values into our parameters.
                state = torch.load(weights_init,
                                   map_location=lambda storage, loc: storage)
                self.load_weights(state)
            else:
                self.load_base_weights()

    def load_base_vgg(self, weights_state_dict):
        """Load the ``vgg.``-prefixed entries of a state dict into ``self.vgg``."""
        vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.')
        self.vgg.load_state_dict(vgg_state_dict)

    def load_vgg_head(self, weights_state_dict):
        """Load the ``vgg_head.``-prefixed entries into ``self.vgg_head``."""
        vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.')
        self.vgg_head.load_state_dict(vgg_head_state_dict)

    def get_dict_by_prefix(self, weights_state_dict, prefix):
        """Return the entries whose key starts with ``prefix``, with it stripped."""
        return {k[len(prefix):]: v
                for k, v in weights_state_dict.items()
                if k.startswith(prefix)}

    def load_weights(self, weights_state_dict):
        """Load ``vgg`` and ``vgg_head`` from a state dict.

        Only those two submodules are loaded; ``score_pool3`` and
        ``score_pool4`` keep their current values.
        """
        self.load_base_vgg(weights_state_dict)
        self.load_vgg_head(weights_state_dict)

    def split_vgg_head(self):
        """Expose ``vgg_head`` as ``vgg_head_feat`` (all but last) + ``classifier`` (last).

        The new attributes share modules with ``vgg_head``; nothing is copied.
        """
        head_modules = list(self.vgg_head.children())
        self.classifier = head_modules[-1]
        self.vgg_head_feat = nn.Sequential(*head_modules[:-1])

    def forward(self, x):
        """Compute per-pixel class scores for ``x``.

        Returns the score map cropped to ``x``'s spatial size. When
        ``output_last_ft`` is set, returns ``(score, last_ft)`` instead, where
        ``last_ft`` is the output of ``vgg_head[5]`` (the dropout before the
        final 1x1 conv).
        """
        input_ = x
        # NOTE(review): the crop offsets below (5, 9, 31) presumably depend on
        # this padding amount; keep them in sync if it changes.
        x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0)

        # Indices into self.vgg (config 'D', no batch norm) of the pool3 and
        # pool4 MaxPool2d layers.
        skip_layers = {16: 'pool3', 23: 'pool4'}
        intermediates = {}
        for i, module in enumerate(self.vgg):
            x = module(x)
            if i in skip_layers:
                intermediates[skip_layers[i]] = x

        last_ft_layer = 5  # Dropout before classifier
        last_ft = None
        for i, module in enumerate(self.vgg_head):
            x = module(x)
            if i == last_ft_layer:
                last_ft = x

        upscore2 = self.upscore2(x)
        # Fixed down-scaling of the skip features before scoring them.
        score_pool4 = self.score_pool4(0.01 * intermediates['pool4'])
        score_pool4c = _crop(score_pool4, upscore2, offset=5)
        fuse_pool4 = upscore2 + score_pool4c
        upscore_pool4 = self.upscore_pool4(fuse_pool4)
        score_pool3 = self.score_pool3(0.0001 * intermediates['pool3'])
        score_pool3c = _crop(score_pool3, upscore_pool4, offset=9)
        fuse_pool3 = upscore_pool4 + score_pool3c
        upscore8 = self.upscore8(fuse_pool3)
        score = _crop(upscore8, input_, offset=31)
        if self.output_last_ft:
            return score, last_ft
        return score

    def load_base_weights(self):
        """Initialize ``vgg`` and ``vgg_head`` from torchvision's VGG16 weights.

        This is complicated because we converted the base model to be fully
        convolutional, so some surgery needs to happen here: ``features.*``
        entries load directly into ``vgg``. The ``classifier.*`` Linear
        weights/biases (except the final ``classifier.6.*`` layer) are reshaped
        into the conv layers of ``vgg_head``, matched in state-dict order.
        """
        # ``vgg.model_urls`` was removed from newer torchvision releases.
        model_urls = getattr(vgg, 'model_urls', None) or {}
        url = model_urls.get(
            'vgg16', 'https://download.pytorch.org/models/vgg16-397923af.pth')
        base_state_dict = model_zoo.load_url(url)
        self.vgg.load_state_dict(
            self.get_dict_by_prefix(base_state_dict, 'features.'))
        head_params = self.vgg_head.parameters()
        for key, value in base_state_dict.items():
            # Skip non-classifier entries and the final classifier output.
            if not key.startswith('classifier.') or key.startswith('classifier.6.'):
                continue
            param = next(head_params)
            # Reinterpret the Linear tensor with the matching conv shape.
            param.data = value.view(param.size())


    


class VGG16_FCN8s_caffe(VGG16_FCN8s):
    """FCN-8s variant whose backbone starts from a Caffe-style VGG16 checkpoint.

    Its ``transform`` produces inputs on a 0-255 scale, with a per-channel
    mean subtracted and the channel order reversed (RGB -> BGR for RGB
    images). This is presumably what the checkpoint loaded in
    ``load_base_weights`` expects; confirm against the checkpoint's source.
    """

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # (x - mean) / (1/255): subtract the per-channel mean (listed in the
        # input's channel order, on the 0-1 scale) and rescale to 0-255.
        torchvision.transforms.Normalize(
            mean=[0.485, 0.458, 0.408],
            std=[0.00392156862745098] * 3),
        # Reverse the channel axis. ToTensor yields (C, H, W), so channels are
        # dim -3 (also correct for batched (N, C, H, W) tensors).
        # NOTE: this previously unbound dim 1, which reversed the image rows
        # (a vertical flip) and left the channel order unchanged. Models
        # trained with that version saw vertically flipped, unswapped inputs.
        torchvision.transforms.Lambda(
            lambda x: torch.stack(torch.unbind(x, -3)[::-1], -3))
        ])

    def load_base_weights(self):
        """Load the VGG16 checkpoint below into ``vgg`` and ``vgg_head``.

        ``features.*`` entries go to the conv trunk. The fully-connected
        ``classifier.*`` weights/biases (except the final ``classifier.6.*``
        layer) are reshaped into ``vgg_head``'s convs, matched in state-dict order.
        """
        base_state_dict = model_zoo.load_url(
            'https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth')
        self.vgg.load_state_dict(
            self.get_dict_by_prefix(base_state_dict, 'features.'))
        head_params = self.vgg_head.parameters()
        for key, value in base_state_dict.items():
            # Skip non-classifier entries and the final classifier output.
            if not key.startswith('classifier.') or key.startswith('classifier.6.'):
                continue
            param = next(head_params)
            # Reinterpret the Linear tensor with the matching conv shape.
            param.data = value.view(param.size())

class Discriminator(nn.Module):
    """Small fully-convolutional discriminator built from 1x1 convolutions.

    Args:
        input_dim: channels of the input feature map. 4096 selects a
            1024-channel first hidden layer; any other value selects 512.
            The second hidden layer is half as wide.
        output_dim: number of output channels (per-location scores).
        pretrained: load weights from ``weights_init`` when it is given.
        weights_init: path to a saved state dict for this module. ``None``
            or ``''`` (the default) means there is nothing to load.
    """

    def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
        super().__init__()
        dim1 = 1024 if input_dim == 4096 else 512
        dim2 = dim1 // 2
        self.D = nn.Sequential(
            nn.Conv2d(input_dim, dim1, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim1, dim2, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim2, output_dim, 1)
            )

        # Truthiness rather than ``is not None``: the default path is '', and
        # torch.load('') would fail instead of skipping loading.
        if pretrained and weights_init:
            self.load_weights(weights_init)

    def forward(self, x):
        """Return the raw (pre-activation) scores of ``self.D`` for ``x``."""
        return self.D(x)

    def load_weights(self, weights):
        """Load a state dict saved at path ``weights`` into this module."""
        print('Loading discriminator weights')
        # Map storages to CPU so GPU-saved checkpoints load anywhere;
        # load_state_dict copies values into the existing parameters.
        state = torch.load(weights, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)
   


class Transform_Module(nn.Module):
    """A 1x1 convolution followed by ReLU, with the conv initialized as identity.

    The conv weight is filled by ``init_eye`` and the bias is zeroed, so at
    construction the module computes ``relu(x)``.

    Args:
        input_dim: number of input channels (also the number of output channels).
    """

    def __init__(self, input_dim=4096):
        super().__init__()
        layers = [
            nn.Conv2d(input_dim, input_dim, 1),
            nn.ReLU(inplace=True),
        ]
        self.transform = nn.Sequential(*layers)
        convs = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
        for conv in convs:
            init_eye(conv.weight)
            conv.bias.data.zero_()

    def forward(self, x):
        """Apply the transform to ``x``."""
        return self.transform(x)


def init_eye(tensor):
    """Overwrite ``tensor`` in place with an identity matrix over its first two dims.

    ``torch.eye(size(0), size(1))`` is reshaped to ``tensor``'s full shape. The
    tensor must therefore have exactly ``size(0) * size(1)`` elements, e.g. a
    2-D weight or an ``nn.Conv2d`` weight with a 1x1 kernel.

    Args:
        tensor: the weight to overwrite (a Parameter or a plain Tensor).

    Returns:
        ``tensor`` itself.

    Raises:
        RuntimeError: if the element count does not match (e.g. kernels larger than 1x1).
    """
    # Write through .data so autograd does not record the copy. Do not
    # dispatch on isinstance(tensor, Variable): since PyTorch 0.4 every Tensor
    # passes that check, so recursing on .data never terminated.
    target = tensor.data
    eye = torch.eye(target.size(0), target.size(1))
    # Explicit reshape: an (out, in) matrix does not broadcast into an
    # (out, in, 1, 1) conv weight.
    target.copy_(eye.view(target.size()))
    return tensor


def _crop(input, shape, offset=0):
    _, _, h, w = shape.size()
    return input[:, :, offset:offset + h, offset:offset + w].contiguous()


def make_layers(cfg, batch_norm=False):
    """Build a VGG-style convolutional trunk from ``cfg``.

    This is almost verbatim from torchvision.models.vgg, except that every
    MaxPool2d is configured with ``ceil_mode=True``.

    Args:
        cfg: sequence of ints (output channels of a 3x3, padding-1 conv
            followed by ReLU) and ``'M'`` markers (2x2 max pool, stride 2).
        batch_norm: insert a BatchNorm2d between each conv and its ReLU.

    Returns:
        An ``nn.Sequential`` containing the layers in order.
    """
    layers = []
    channels = 3  # the first conv takes 3 input channels
    for item in cfg:
        if item == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
            continue
        layers.append(nn.Conv2d(channels, item, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(item))
        layers.append(nn.ReLU(inplace=True))
        channels = item
    return nn.Sequential(*layers)