# -*- coding: utf-8 -*-
"""
@author : Haoran You

"""
import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
# from torch.optim import lr_scheduler
import itertools
import util.util as util
from util.image_pool import ImagePool
from . import networks
import sys
import random
from PIL import Image
import torchvision.transforms as transforms

# Shared preprocessing pipeline: PIL image -> tensor, then map each channel
# from [0, 1] to [-1, 1] (mean 0.5, std 0.5 per channel).
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])


class CycleGAN():
    """Bayesian CycleGAN model.

    Both generators consume their image concatenated with one extra noise
    channel ``z``; several Monte-Carlo draws of ``z`` (``opt.mc_x`` /
    ``opt.mc_y``) are stacked along the batch dimension.  All losses include
    a Gaussian prior over the network weights (see :meth:`get_prior`).

    NOTE(review): this file is written against the pre-0.4 PyTorch API
    (``Variable``, ``.data[0]``, legacy ``Tensor(size)`` constructors); that
    style is kept for consistency with the rest of the project.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'Bayesian CycleGAN Model'

    def initialize(self, opt):
        """Build networks, losses and optimizers from the option object.

        Must be called before any other method; everything is stored on
        ``self``.  Chooses CPU or GPU tensors once (``self.Tensor``) and
        fixes the RNG seed for reproducibility.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        if torch.cuda.is_available():
            print('cuda is available, we will use gpu!')
            self.Tensor = torch.cuda.FloatTensor
            torch.cuda.manual_seed_all(100)
        else:
            self.Tensor = torch.FloatTensor
            torch.manual_seed(100)
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # Load networks.  Generators take one extra input channel to carry
        # the sampled noise map z.
        netG_input_nc = opt.input_nc + 1
        netG_output_nc = opt.output_nc + 1
        self.netG_A = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG_A,
                                        opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)
        self.netG_B = networks.define_G(netG_output_nc, opt.input_nc, opt.ngf, opt.netG_B,
                                        opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)

        if self.isTrain:
            # LSGAN uses a raw linear output; the sigmoid is only needed for
            # the vanilla (cross-entropy) GAN loss.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.num_D_A).type(self.Tensor)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.num_D_B).type(self.Tensor)

        # Resume from checkpoints when testing or continuing training.
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG_A, 'G_A', opt.which_epoch, self.save_dir)
            self.load_network(self.netG_B, 'G_B', opt.which_epoch, self.save_dir)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', opt.which_epoch, self.save_dir)
                self.load_network(self.netD_B, 'D_B', opt.which_epoch, self.save_dir)

        # Set loss functions and optimizers.
        if self.isTrain:
            self.old_lr = opt.lr
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            # One optimizer for both generators, one per discriminator.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('----------Network initialized!-----------')
        self.print_network(self.netG_A)
        self.print_network(self.netG_B)
        if self.isTrain:
            self.print_network(self.netD_A)
            self.print_network(self.netD_B)
        print('-----------------------------------------')

    def set_input(self, input):
        """Store the current data batch; swaps A/B when direction is BtoA."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.input_A = input['A' if AtoB else 'B']
        self.input_B = input['B' if AtoB else 'A']
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Prepare noise-augmented inputs for both generators.

        For each Monte-Carlo draw a fresh single-channel noise map is
        sampled and concatenated onto every image in the batch; the draws
        are collected in ``self.real_B_zx`` / ``self.real_A_zy``.  Note that
        only the LAST sampled ``self.z_x`` / ``self.z_y`` is kept and later
        reused by :meth:`backward_G` for the cycle pass.
        """
        self.real_A = Variable(self.input_A).type(self.Tensor)
        self.real_B = Variable(self.input_B).type(self.Tensor)

        # Combine input image with random noise z (mc_x draws for B->A).
        self.real_B_zx = []
        for _ in range(0, self.opt.mc_x):
            # real_B[0, 1].size() is the spatial (H, W) size of one channel.
            self.z_x = self.get_z_random(self.real_B[0, 1].size(), 'gauss')
            self.z_x = torch.unsqueeze(self.z_x, 0)
            self.z_x = torch.unsqueeze(self.z_x, 0)
            real_B_zx = []
            for i in range(0, self.opt.batchSize):
                _real = torch.unsqueeze(self.real_B[i], 0)
                _real = torch.cat([_real, self.z_x], dim=1)
                real_B_zx.append(_real)
            real_B_zx = torch.cat(real_B_zx)
            self.real_B_zx.append(real_B_zx)

        # Same for A->B with mc_y draws.
        self.real_A_zy = []
        for _ in range(0, self.opt.mc_y):
            self.z_y = self.get_z_random(self.real_A[0, 1].size(), 'gauss')
            self.z_y = torch.unsqueeze(self.z_y, 0)
            self.z_y = torch.unsqueeze(self.z_y, 0)
            real_A_zy = []
            for i in range(0, self.opt.batchSize):
                _real = torch.unsqueeze(self.real_A[i], 0)
                _real = torch.cat([_real, self.z_y], dim=1)
                real_A_zy.append(_real)
            real_A_zy = torch.cat(real_A_zy)
            self.real_A_zy.append(real_A_zy)

    def inference(self):
        """Run a full test-time pass: fake images plus their reconstructions.

        A single noise draw per direction is used.  Results are stored as
        raw tensors in ``self.fake_B`` / ``self.rec_A`` / ``self.fake_A`` /
        ``self.rec_B``.
        """
        real_A = Variable(self.input_A).type(self.Tensor)
        real_B = Variable(self.input_B).type(self.Tensor)

        # Combine input image with a single random noise map per direction.
        real_B_zx = []
        z_x = self.get_z_random(real_B[0, 1].size(), 'gauss')
        z_x = torch.unsqueeze(z_x, 0)
        z_x = torch.unsqueeze(z_x, 0)
        for i in range(0, self.opt.batchSize):
            _real = torch.cat((real_B[i:i + 1], z_x), dim=1)
            real_B_zx.append(_real)
        real_B_zx = torch.cat(real_B_zx)

        real_A_zy = []
        z_y = self.get_z_random(real_A[0, 1].size(), 'gauss')
        z_y = torch.unsqueeze(z_y, 0)
        z_y = torch.unsqueeze(z_y, 0)
        for i in range(0, self.opt.batchSize):
            _real = torch.cat((real_A[i:i + 1], z_y), dim=1)
            real_A_zy.append(_real)
        real_A_zy = torch.cat(real_A_zy)

        # Inference: A -> fake_B -> rec_A and B -> fake_A -> rec_B.
        fake_B = self.netG_A(real_A_zy)
        fake_B_next = torch.cat((fake_B, z_x), dim=1)
        self.rec_A = self.netG_B(fake_B_next).data
        self.fake_B = fake_B.data

        fake_A = self.netG_B(real_B_zx)
        fake_A_next = torch.cat((fake_A, z_y), dim=1)
        self.rec_B = self.netG_A(fake_A_next).data
        self.fake_A = fake_A.data

    def get_image_paths(self):
        """Return the file paths of the current input batch."""
        return self.image_paths

    def img_resize(self, img, target_width):
        """Resize a PIL image to ``target_width`` keeping the aspect ratio."""
        ow, oh = img.size
        if (ow == target_width):
            return img
        else:
            w = target_width
            h = int(target_width * oh / ow)
        return img.resize((w, h), Image.BICUBIC)

    def get_z_random(self, nz, random_type='gauss'):
        """Sample a random tensor of size ``nz``.

        ``random_type`` is 'uni' (uniform in [-1, 1)) or 'gauss' (standard
        normal).  Returned as a ``Variable`` of ``self.Tensor`` type.
        """
        z = self.Tensor(nz)
        if random_type == 'uni':
            z.copy_(torch.rand(nz) * 2.0 - 1.0)
        elif random_type == 'gauss':
            z.copy_(torch.randn(nz))
        z = Variable(z)
        return z

    def backward_G(self):
        """Generator update for the unpaired (cycle-consistency) setting.

        Relies on :meth:`forward` having populated ``self.real_A_zy`` /
        ``self.real_B_zx`` and the last noise draws ``self.z_x`` /
        ``self.z_y``.  Fake batches are grouped per MC draw in chunks of
        ``opt.batchSize`` along dim 0.
        """
        # GAN loss D_A(G_A(A))
        fake_B = []
        for real_A in self.real_A_zy:
            _fake = self.netG_A(real_A)
            fake_B.append(_fake)
        fake_B = torch.cat(fake_B)

        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = []
        for real_B in self.real_B_zx:
            _fake = self.netG_B(real_B)
            fake_A.append(_fake)
        fake_A = torch.cat(fake_A)

        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Cycle-consistency loss weights.
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Forward cycle loss: A -> fake_B -> rec_A, one L1 term per MC draw.
        # Only the last z_x drawn in forward() is reused here.
        fake_B_next = []
        for i in range(0, self.opt.mc_y):
            _fake = fake_B[i * self.opt.batchSize:(i + 1) * self.opt.batchSize]
            _fake = torch.cat((_fake, self.z_x), dim=1)
            fake_B_next.append(_fake)
        fake_B_next = torch.cat(fake_B_next)

        rec_A = self.netG_B(fake_B_next)
        loss_cycle_A = 0
        for i in range(0, self.opt.mc_y):
            loss_cycle_A += self.criterionCycle(rec_A[i * self.opt.batchSize:(i + 1) * self.opt.batchSize],
                                                self.real_A) * lambda_A
        pred_cycle_G_A = self.netD_B(rec_A)
        loss_cycle_G_A = self.criterionGAN(pred_cycle_G_A, True)

        # Backward cycle loss: B -> fake_A -> rec_B.
        fake_A_next = []
        for i in range(0, self.opt.mc_x):
            _fake = fake_A[i * self.opt.batchSize:(i + 1) * self.opt.batchSize]
            _fake = torch.cat((_fake, self.z_y), dim=1)
            fake_A_next.append(_fake)
        fake_A_next = torch.cat(fake_A_next)

        rec_B = self.netG_A(fake_A_next)
        loss_cycle_B = 0
        for i in range(0, self.opt.mc_x):
            loss_cycle_B += self.criterionCycle(rec_B[i * self.opt.batchSize:(i + 1) * self.opt.batchSize],
                                                self.real_B) * lambda_B
        pred_cycle_G_B = self.netD_A(rec_B)
        loss_cycle_G_B = self.criterionGAN(pred_cycle_G_B, True)

        # Gaussian prior over generator weights (Bayesian regularizer).
        prior_loss_G_A = self.get_prior(self.netG_A.parameters(), self.opt.batchSize)
        prior_loss_G_B = self.get_prior(self.netG_B.parameters(), self.opt.batchSize)

        # Total loss
        loss_G = loss_G_A + loss_G_B + (prior_loss_G_A + prior_loss_G_B) + (loss_cycle_G_A + loss_cycle_G_B) * self.opt.gamma + (loss_cycle_A + loss_cycle_B)
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0] + loss_cycle_G_A.data[0] * self.opt.gamma + prior_loss_G_A.data[0]
        # BUGFIX: used to add prior_loss_G_A here instead of prior_loss_G_B.
        self.loss_G_B = loss_G_B.data[0] + loss_cycle_G_B.data[0] * self.opt.gamma + prior_loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def backward_D_A(self):
        """Discriminator A update (unpaired): fake/cycle-fake vs real B."""
        fake_B = Variable(self.fake_B).type(self.Tensor)
        rec_B = Variable(self.rec_B).type(self.Tensor)
        # How well it classifies fake images (detach: no generator grads).
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_A(rec_B.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # How well it classifies real images; scaled by mc_y to balance the
        # mc_y stacked fake batches.
        pred_real = self.netD_A(self.real_B)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_y

        # Gaussian prior over discriminator weights.
        prior_loss_D_A = self.get_prior(self.netD_A.parameters(), self.opt.batchSize)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5 + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma + prior_loss_D_A
        loss_D_A.backward()
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        """Discriminator B update (unpaired): fake/cycle-fake vs real A."""
        fake_A = Variable(self.fake_A).type(self.Tensor)
        rec_A = Variable(self.rec_A).type(self.Tensor)
        # How well it classifies fake images.
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_B(rec_A.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # How well it classifies real images.
        pred_real = self.netD_B(self.real_A)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_x

        # Gaussian prior over discriminator weights.
        prior_loss_D_B = self.get_prior(self.netD_B.parameters(), self.opt.batchSize)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5 + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma + prior_loss_D_B
        loss_D_B.backward()
        self.loss_D_B = loss_D_B.data[0]

    def backward_G_pair(self):
        """Generator update for the paired setting (GAN + L1 to the target)."""
        # GAN loss D_A(G_A(A)) and L1 loss
        fake_B = []
        loss_G_A_L1 = 0
        for real_A in self.real_A_zy:
            _fake = self.netG_A(real_A)
            loss_G_A_L1 += self.criterionL1(_fake, self.real_B) * self.opt.lambda_A
            fake_B.append(_fake)
        fake_B = torch.cat(fake_B)

        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = []
        loss_G_B_L1 = 0
        for real_B in self.real_B_zx:
            _fake = self.netG_B(real_B)
            loss_G_B_L1 += self.criterionL1(_fake, self.real_A) * self.opt.lambda_B
            fake_A.append(_fake)
        fake_A = torch.cat(fake_A)

        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Gaussian prior over generator weights.
        prior_loss_G_A = self.get_prior(self.netG_A.parameters(), self.opt.batchSize)
        prior_loss_G_B = self.get_prior(self.netG_B.parameters(), self.opt.batchSize)

        # Total loss
        loss_G = loss_G_A + loss_G_B + (prior_loss_G_A + prior_loss_G_B) + (loss_G_A_L1 + loss_G_B_L1)
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        # BUGFIX: record the L1 losses so get_stye_loss() can report them
        # (previously never stored, so get_stye_loss raised AttributeError).
        self.loss_G_A_L1 = loss_G_A_L1.data[0]
        self.loss_G_B_L1 = loss_G_B_L1.data[0]

    def backward_D_A_pair(self):
        """Discriminator A update for the paired setting."""
        fake_B = Variable(self.fake_B).type(self.Tensor)
        # How well it classifies fake images.
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # How well it classifies real images.
        pred_real = self.netD_A(self.real_B)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_y

        # Gaussian prior over discriminator weights.
        prior_loss_D_A = self.get_prior(self.netD_A.parameters(), self.opt.batchSize)

        # Total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5 + prior_loss_D_A
        loss_D_A.backward()
        # Record the scalar for logging, consistent with backward_D_A().
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B_pair(self):
        """Discriminator B update for the paired setting."""
        fake_A = Variable(self.fake_A).type(self.Tensor)
        # How well it classifies fake images.
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # How well it classifies real images.
        pred_real = self.netD_B(self.real_A)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_x

        # Gaussian prior over discriminator weights.
        prior_loss_D_B = self.get_prior(self.netD_B.parameters(), self.opt.batchSize)

        # Total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5 + prior_loss_D_B
        loss_D_B.backward()
        # Record the scalar for logging, consistent with backward_D_B().
        self.loss_D_B = loss_D_B.data[0]

    def optimize(self, pair=False):
        """Run one optimization step: G update, then D_A, then D_B.

        ``pair`` selects the paired (L1-supervised) losses instead of the
        unpaired cycle-consistency losses.
        """
        # Forward pass prepares the noise-augmented inputs.
        self.forward()
        # G_A and G_B (single optimizer over both generators).
        self.optimizer_G.zero_grad()
        if pair:
            self.backward_G_pair()
        else:
            self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        if pair:
            self.backward_D_A_pair()
        else:
            self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        if pair:
            self.backward_D_B_pair()
        else:
            self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_loss(self):
        """Return an ordered dict of the latest scalar losses for logging.

        Only valid after an unpaired optimize() step (cycle losses are not
        produced in paired mode).  NOTE(review): both branches below report
        the same stored values under different labels depending on gamma.
        """
        loss = OrderedDict([
            ('D_A', self.loss_D_A),
            ('D_B', self.loss_D_B),
            ('G_A', self.loss_G_A),
            ('G_B', self.loss_G_B)
        ])
        if self.opt.gamma == 0:
            loss['cyc_A'] = self.loss_cycle_A
            loss['cyc_B'] = self.loss_cycle_B
        elif self.opt.gamma > 0:
            loss['cyc_G_A'] = self.loss_cycle_A
            loss['cyc_G_B'] = self.loss_cycle_B
        return loss

    def get_stye_loss(self):
        """Return the latest paired L1 losses.

        Name kept (typo and all) for backward compatibility with callers.
        Only valid after a paired optimize() step.
        """
        loss = OrderedDict([
            ('L1_A', self.loss_G_A_L1),
            ('L1_B', self.loss_G_B_L1)
        ])
        return loss

    def get_current_visuals(self):
        """Return an ordered dict of the current images converted for display."""
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        visuals = OrderedDict([
            ('real_A', real_A),
            ('fake_B', fake_B),
            ('rec_A', rec_A),
            ('real_B', real_B),
            ('fake_A', fake_A),
            ('rec_B', rec_B)
        ])
        return visuals

    def get_prior(self, parameters, dataset_size):
        """Gaussian prior (mean squared weight) over ``parameters``.

        Returns sum_p mean(p^2) / dataset_size as a 1-element Variable.
        BUGFIX: allocate via self.Tensor instead of an unconditional
        .cuda() call, which crashed on CPU-only machines despite the CPU
        fallback in initialize().
        """
        prior_loss = Variable(torch.zeros(1).type(self.Tensor))
        for param in parameters:
            prior_loss += torch.mean(param * param)
        return prior_loss / dataset_size

    # def get_noise(self, parameters, alpha, dataset_size):
    #     noise_loss = Variable(torch.zeros((1))).cuda()
    #     noise_std = np.sqrt(2 * alpha)
    #     for param in parameters:
    #         noise = Variable(torch.normal(std=torch.ones(param.size()))).cuda()
    #         noise_loss += torch.sum(param*noise*noise_std)
    #     return noise_loss / dataset_size

    def save_model(self, label):
        """Save all four networks under the given epoch label."""
        self.save_network(self.netG_A, 'G_A', label)
        self.save_network(self.netG_B, 'G_B', label)
        self.save_network(self.netD_A, 'D_A', label)
        self.save_network(self.netD_B, 'D_B', label)

    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load '<epoch>_net_<label>.pth' into ``network``.

        Falls back to partial loading when the checkpoint has extra layers,
        and to size-matched loading (reporting the uninitialized modules)
        when it has fewer/mismatched layers.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        try:
            network.load_state_dict(torch.load(save_path))
        except Exception:
            pretrained_dict = torch.load(save_path)
            model_dict = network.state_dict()
            try:
                # Keep only the keys the model actually has.
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                network.load_state_dict(pretrained_dict)
                print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
            except Exception:
                print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                if sys.version_info >= (3, 0):
                    not_initialized = set()
                else:
                    from sets import Set
                    not_initialized = Set()
                # Copy over only size-compatible tensors.
                for k, v in pretrained_dict.items():
                    if v.size() == model_dict[k].size():
                        model_dict[k] = v

                for k, v in model_dict.items():
                    if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                        not_initialized.add(k.split('.')[0])
                print(sorted(not_initialized))
                network.load_state_dict(model_dict)

    def save_network(self, network, network_label, epoch_label):
        """Save ``network`` to '<epoch>_net_<label>.pth' (CPU tensors)."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        # .cpu() moved the network; move it back for continued training.
        if torch.cuda.is_available():
            network.cuda()

    def print_network(self, net):
        """Print the network structure and its total parameter count."""
        num_params = 0
        for param in net.parameters():
            num_params += param.numel()
        print(net)
        print('Total number of parameters: %d' % num_params)

    # update learning rate (called once every iter)
    def update_learning_rate(self, epoch, epoch_iter, dataset_size):
        """Exponentially decay the LR once past opt.niter warm-up epochs.

        Decay factor exp(-min(1, progress)) is computed from the progress
        through the current epoch; before opt.niter the LR is unchanged.
        """
        if epoch > self.opt.niter:
            lr = self.opt.lr * np.exp(-1.0 * min(1.0, epoch_iter / float(dataset_size)))
            for param_group in self.optimizer_D_A.param_groups:
                param_group['lr'] = lr
            for param_group in self.optimizer_D_B.param_groups:
                param_group['lr'] = lr
            for param_group in self.optimizer_G.param_groups:
                param_group['lr'] = lr
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
            self.old_lr = lr
        else:
            lr = self.old_lr