import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from torch.utils import model_zoo
from torchvision.models import vgg

from .models import register_model


def get_upsample_filter(size):
	"""Return a ``size`` x ``size`` bilinear interpolation kernel.

	The kernel is the outer product of a 1-D triangular ramp that peaks at
	the centre of the window.  It is returned as a float32 tensor, suitable
	as the weight of a transposed convolution used for upsampling.
	"""
	half = (size + 1) // 2
	# Odd sizes have an integer centre tap; even sizes centre between taps.
	mid = half - 1 if size % 2 == 1 else half - 0.5
	ramp = 1 - np.abs(np.arange(size) - mid) / half
	kernel = np.outer(ramp, ramp)
	return torch.from_numpy(kernel).float()


class Bilinear(nn.Module):
	"""Fixed (non-learned) upsampling by an integer ``factor``.

	Implemented as a stride-``factor`` transposed convolution whose
	``2*factor`` x ``2*factor`` kernel comes from ``get_upsample_filter``.
	The weight is diagonal across channels, so each channel is upsampled
	independently.  The weight is a registered buffer: it follows
	``.to()``/``.cuda()`` and appears in ``state_dict`` but is not a
	trainable parameter.
	"""

	def __init__(self, factor, num_channels):
		super().__init__()
		self.factor = factor
		kernel = get_upsample_filter(factor * 2)
		w = torch.zeros(num_channels, num_channels, factor * 2, factor * 2)
		# Input channel c only contributes to output channel c.
		idx = torch.arange(num_channels)
		w[idx, idx] = kernel
		self.register_buffer('w', w)

	def forward(self, x):
		"""Upsample ``x``; spatial sizes become ``(H + 1) * factor`` and ``(W + 1) * factor``."""
		# Buffers never require grad; the legacy Variable wrapper is a no-op
		# since PyTorch 0.4 and is not needed.
		return F.conv_transpose2d(x, self.w, stride=self.factor)


@register_model('fcn8s')
class VGG16_FCN8s(nn.Module):
	"""FCN-8s semantic segmentation network on a VGG16 backbone.

	``vgg`` is the VGG16 convolutional trunk (no batch norm).  ``vgg_head`` is
	the VGG16 classifier rewritten as convolutions, ending in a ``num_cls``
	scoring conv.  The coarse scores are refined with skip connections from
	pool4 and pool3 and upsampled back to the input resolution with fixed
	bilinear upsamplers.
	"""

	# ImageNet RGB normalisation applied to input images.
	transform = torchvision.transforms.Compose([
		torchvision.transforms.ToTensor(),
		torchvision.transforms.Normalize(
			mean=[0.485, 0.456, 0.406],
			std=[0.229, 0.224, 0.225]),
	])

	def __init__(self, num_cls=19, pretrained=True, weights_init=None,
	             output_last_ft=False):
		super().__init__()
		self.output_last_ft = output_last_ft
		# torchvision renamed ``vgg.cfg`` to ``vgg.cfgs``; accept either.
		vgg_cfgs = getattr(vgg, 'cfgs', None) or vgg.cfg
		# Always built without batch norm: the skip-layer indices used in
		# forward() and the torchvision ``features.*`` weights assume it.
		self.vgg = make_layers(vgg_cfgs['D'], batch_norm=False)
		self.vgg_head = nn.Sequential(
			nn.Conv2d(512, 4096, 7),
			nn.ReLU(inplace=True),
			nn.Dropout2d(p=0.5),
			nn.Conv2d(4096, 4096, 1),
			nn.ReLU(inplace=True),
			nn.Dropout2d(p=0.5),
			nn.Conv2d(4096, num_cls, 1)
		)
		# upscore2 and upscore_pool4 are the same (parameter-free) module.
		self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls)
		self.upscore8 = Bilinear(8, num_cls)
		# Skip scorers start at zero, so initially the skip branches add
		# nothing to the upsampled coarse scores.
		self.score_pool4 = nn.Conv2d(512, num_cls, 1)
		self.score_pool3 = nn.Conv2d(256, num_cls, 1)
		for scorer in (self.score_pool4, self.score_pool3):
			for param in scorer.parameters():
				init.constant_(param, 0)

		if pretrained:
			if weights_init is not None:
				# map_location='cpu' lets GPU-saved checkpoints load on
				# CPU-only hosts; load_state_dict copies into our params.
				self.load_weights(torch.load(weights_init, map_location='cpu'))
			else:
				self.load_base_weights()

	def load_base_vgg(self, weights_state_dict):
		"""Load the ``vgg.*`` entries of ``weights_state_dict`` into ``self.vgg``."""
		vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.')
		self.vgg.load_state_dict(vgg_state_dict)

	def load_vgg_head(self, weights_state_dict):
		"""Load the ``vgg_head.*`` entries of ``weights_state_dict`` into ``self.vgg_head``."""
		vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.')
		self.vgg_head.load_state_dict(vgg_head_state_dict)

	def get_dict_by_prefix(self, weights_state_dict, prefix):
		"""Return the entries whose key starts with ``prefix``, with the prefix removed."""
		return {k[len(prefix):]: v
		        for k, v in weights_state_dict.items()
		        if k.startswith(prefix)}

	def load_weights(self, weights_state_dict):
		"""Load backbone and head weights from a state_dict using this module's keys.

		Only the ``vgg.*`` and ``vgg_head.*`` entries are loaded; ``score_pool3``
		and ``score_pool4`` keep their current values.
		"""
		self.load_base_vgg(weights_state_dict)
		self.load_vgg_head(weights_state_dict)

	def split_vgg_head(self):
		"""Expose ``vgg_head`` as ``vgg_head_feat`` (all but its last layer) and
		``classifier`` (its last layer).  The modules are shared, not copied."""
		self.classifier = list(self.vgg_head.children())[-1]
		self.vgg_head_feat = nn.Sequential(*list(self.vgg_head.children())[:-1])

	def forward(self, x):
		"""Return per-pixel class scores cropped to the spatial size of ``x``.

		If ``output_last_ft`` is set, returns ``(score, last_ft)``.  Here
		``last_ft`` is the ``vgg_head`` activation feeding its final scoring
		conv, or ``None`` if the head is shorter than expected.
		"""
		inp = x
		# 99 px of zero padding plus conv1_1's padding=1 gives a 100 px
		# border; the crop offsets below (5, 9, 31) depend on it.
		x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0)
		# Indices of the pool3 / pool4 max-pools in self.vgg (cfg 'D', no BN).
		skip_layers = {16: 'pool3', 23: 'pool4'}
		intermediates = {}
		for i, module in enumerate(self.vgg):
			x = module(x)
			if i in skip_layers:
				intermediates[skip_layers[i]] = x

		# Index 5 of vgg_head is the dropout feeding the final scoring conv.
		last_ft_idx = 5
		last_ft = None
		for i, module in enumerate(self.vgg_head):
			x = module(x)
			if i == last_ft_idx:
				last_ft = x

		upscore2 = self.upscore2(x)
		# Skip inputs are scaled down before scoring (0.01 / 0.0001).
		score_pool4 = self.score_pool4(0.01 * intermediates['pool4'])
		fuse_pool4 = upscore2 + _crop(score_pool4, upscore2, offset=5)
		upscore_pool4 = self.upscore_pool4(fuse_pool4)
		score_pool3 = self.score_pool3(0.0001 * intermediates['pool3'])
		fuse_pool3 = upscore_pool4 + _crop(score_pool3, upscore_pool4, offset=9)
		upscore8 = self.upscore8(fuse_pool3)
		score = _crop(upscore8, inp, offset=31)
		if self.output_last_ft:
			return score, last_ft
		return score

	def load_base_weights(self):
		"""Initialise ``vgg`` and ``vgg_head`` from torchvision's ImageNet VGG16.

		This needs some surgery because the base model was made fully
		convolutional.  The fully-connected ``classifier.0`` / ``classifier.3``
		weights are reshaped into the first two convs of ``vgg_head``.  The
		1000-way ``classifier.6`` layer is skipped, so the scoring conv keeps
		its own initialisation.
		"""
		# ``vgg.model_urls`` no longer exists in recent torchvision releases.
		url = getattr(vgg, 'model_urls', {}).get(
			'vgg16', 'https://download.pytorch.org/models/vgg16-397923af.pth')
		base_state_dict = model_zoo.load_url(url)
		self.vgg.load_state_dict(
			self.get_dict_by_prefix(base_state_dict, 'features.'))
		vgg_head_params = self.vgg_head.parameters()
		with torch.no_grad():
			for k, v in base_state_dict.items():
				if not k.startswith('classifier.') or k.startswith('classifier.6.'):
					continue
				# Parameters are consumed in order: conv0 w/b, then conv3 w/b.
				vgg_head_param = next(vgg_head_params)
				vgg_head_param.copy_(v.view(vgg_head_param.size()))


class VGG16_FCN8s_caffe(VGG16_FCN8s):
	"""FCN-8s variant initialised from a Caffe-converted VGG16 checkpoint.

	Its transform produces mean-subtracted, 0-255 scaled images with the
	channel order reversed (RGB -> BGR).  Presumably this is the input
	convention of the Caffe-converted weights — TODO confirm.
	"""
	transform = torchvision.transforms.Compose([
		torchvision.transforms.ToTensor(),
		# (x - mean) / (1/255) == 255*x - 255*mean: subtract the RGB means
		# and rescale to the 0-255 range.
		torchvision.transforms.Normalize(
			mean=[0.485, 0.458, 0.408],
			std=[0.00392156862745098] * 3),
		# Reverse the channel axis.  ToTensor yields (C, H, W), so channels
		# are dim -3 (also right for batched (N, C, H, W) input).  Using dim
		# 1 would reverse image rows, a vertical flip, instead.
		torchvision.transforms.Lambda(
			lambda x: torch.stack(torch.unbind(x, -3)[::-1], -3))
	])

	def load_base_weights(self):
		"""Initialise ``vgg``/``vgg_head`` from the Caffe-converted VGG16 weights.

		The surgery is the same as in the parent class.  ``features.*`` goes
		into ``vgg``.  ``classifier.0`` / ``classifier.3`` are reshaped into
		the first two convs of ``vgg_head``.  ``classifier.6`` is skipped.
		"""
		base_state_dict = model_zoo.load_url('https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth')
		self.vgg.load_state_dict(
			self.get_dict_by_prefix(base_state_dict, 'features.'))
		vgg_head_params = self.vgg_head.parameters()
		with torch.no_grad():
			for k, v in base_state_dict.items():
				if not k.startswith('classifier.') or k.startswith('classifier.6.'):
					continue
				vgg_head_param = next(vgg_head_params)
				vgg_head_param.copy_(v.view(vgg_head_param.size()))


class Discriminator(nn.Module):
	"""Per-location classifier over a feature map, built from 1x1 convolutions.

	Args:
		input_dim: channels of the incoming feature map.  ``4096`` uses hidden
			widths 1024/512; any other value uses 512/256.
		output_dim: number of output logits per spatial location.
		pretrained: if true and ``weights_init`` is a non-empty path, load
			that saved ``state_dict``.
		weights_init: path to a ``state_dict`` saved for this module.
	"""
	def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
		super().__init__()
		dim1 = 1024 if input_dim == 4096 else 512
		dim2 = dim1 // 2
		self.D = nn.Sequential(
			nn.Conv2d(input_dim, dim1, 1),
			nn.Dropout2d(p=0.5),
			nn.ReLU(inplace=True),
			nn.Conv2d(dim1, dim2, 1),
			nn.Dropout2d(p=0.5),
			nn.ReLU(inplace=True),
			nn.Conv2d(dim2, output_dim, 1)
		)

		# The default weights_init is '' (not None), so test truthiness;
		# `is not None` would attempt torch.load('') when pretrained=True.
		if pretrained and weights_init:
			self.load_weights(weights_init)

	def forward(self, x):
		"""Return logits with ``output_dim`` channels at every location of ``x``."""
		return self.D(x)

	def load_weights(self, weights):
		"""Load a ``state_dict`` saved with ``torch.save`` at ``weights``."""
		print('Loading discriminator weights')
		# map_location='cpu' allows GPU-saved checkpoints on CPU-only hosts.
		self.load_state_dict(torch.load(weights, map_location='cpu'))


class Transform_Module(nn.Module):
	"""Learnable 1x1 conv + ReLU applied to a feature map, initialised to identity.

	Right after construction the conv is the identity map with zero bias, so
	the module computes ``relu(x)`` until it is trained.
	"""
	def __init__(self, input_dim=4096):
		super().__init__()
		self.transform = nn.Sequential(
			nn.Conv2d(input_dim, input_dim, 1),
			nn.ReLU(inplace=True),
		)
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				# dirac_ on a square 1x1 kernel yields the identity mapping.
				nn.init.dirac_(m.weight)
				nn.init.constant_(m.bias, 0)

	def forward(self, x):
		"""Apply the conv + ReLU transform to ``x``."""
		return self.transform(x)


def init_eye(tensor):
	"""Fill ``tensor`` in place with an identity matrix and return it.

	Accepts a 2-D ``(rows, cols)`` tensor, or a tensor whose dimensions after
	the first two are all 1.  An example is a 1x1 ``Conv2d`` weight of shape
	``(out, in, 1, 1)``.  The copy runs under ``torch.no_grad`` so that leaf
	parameters requiring grad can be initialised.

	Raises:
		ValueError: if the trailing dimensions are not all 1, so an identity
			matrix cannot be laid out in ``tensor``'s shape.
	"""
	rows, cols = tensor.size(0), tensor.size(1)
	if tensor.numel() != rows * cols:
		raise ValueError(
			'init_eye expects a 2-D tensor or one whose trailing dims are 1, '
			'got shape {}'.format(tuple(tensor.shape)))
	# No Variable special case: since PyTorch 0.4 every tensor satisfies
	# isinstance(t, Variable), which made the old branch recurse forever.
	with torch.no_grad():
		eye = torch.eye(rows, cols, dtype=tensor.dtype, device=tensor.device)
		# Reshape rather than rely on copy_ broadcasting, which cannot expand
		# (rows, cols) into (rows, cols, 1, 1).
		tensor.copy_(eye.view(tensor.shape))
	return tensor


def _crop(input, shape, offset=0):
	_, _, h, w = shape.size()
	return input[:, :, offset:offset + h, offset:offset + w].contiguous()


def make_layers(cfg, batch_norm=False):
	"""Build a VGG-style feature extractor from a layer configuration.

	This is almost verbatim from ``torchvision.models.vgg``, except that the
	MaxPool2d modules are configured with ``ceil_mode=True``.  Each integer in
	``cfg`` becomes a 3x3 conv (padding 1), optionally followed by
	BatchNorm2d, and then an in-place ReLU.  Each ``'M'`` becomes a 2x2,
	stride-2 max-pool.  The input is assumed to have 3 channels.
	"""
	modules = []
	channels = 3
	for entry in cfg:
		if entry == 'M':
			modules.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
			continue
		modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
		if batch_norm:
			modules.append(nn.BatchNorm2d(entry))
		modules.append(nn.ReLU(inplace=True))
		channels = entry
	return nn.Sequential(*modules)