#!/usr/bin/env python
# -*- coding: utf-8 -*-
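"""Wasserstein GAN (WGAN) training script.

Trains a DCGAN- or MLP-based generator/critic pair with the WGAN objective
(Arjovsky et al., 2017): RMSprop by default, weight clipping on the critic,
and several critic updates per generator update.

Example invocation (script name and paths are illustrative):
    python main.py --dataset cifar10 --dataroot ./data --cuda 0 --niter 25
"""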

from __future__ import print_function
import argparse
import os, sys, random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable  # legacy (pre-0.4) autograd API; this script uses the old Variable/volatile idioms
from tqdm import tqdm

import utils
import models.dcgan as dcgan
import models.mlp as mlp

isDebug = True     # verbose prints; overridden by the --debug flag when run as a script
USE_CUDA = torch.cuda.is_available()
NUM_WORKERS = 4 if USE_CUDA else 2          # rough heuristic: ~4 loader workers per GPU, else 2

# default parameter values
DATASET = 'cifar10'
NETG_CIFAR10 = './samples/cifar10/netG_epoch_24.pth'
NETD_CIFAR10 = './samples/cifar10/netD_epoch_24.pth'
NETG_MNIST = './samples/mnist/netG_epoch_24.pth'
NETD_MNIST = './samples/mnist/netD_epoch_24.pth'


NUM_EPOCHS = 25
BATCH_SIZE = 128
IMG_SIZE = 64
IMG_CHANNELS = 3
NGF = 64                # generator feature-map width (DCGAN default; unrelated to batch size)
NDF = 64                # critic feature-map width (DCGAN default; unrelated to batch size)
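
# WGAN paper defaults (Arjovsky et al., 2017): RMSprop at lr 5e-5 and five
# critic updates per generator update; weight clipping defaults to +/-0.01
# via --clamp_lower/--clamp_upper.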
LR_D = 0.00005
LR_G = 0.00005
D_ITERS = 5             # Number of D iterations per G iteration


def getOptimizers(opt, netG, netD):
    '''
    Build optimizers for the generator and the critic.
    :param opt: parsed options (uses opt.adam, opt.lrD, opt.lrG, opt.beta1)
    :param netG: generator network
    :param netD: discriminator/critic network
    :return: (optimizerG, optimizerD) -- Adam if --adam is set, RMSprop otherwise
    '''
    if opt.adam:
        if isDebug: print("Using ADAM Optimizer")
        optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
    else:
        if isDebug: print("Using RMSProp Optimizer")
        optimizerD = optim.RMSprop(netD.parameters(), lr=opt.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=opt.lrG)

    return optimizerG, optimizerD

def getNetworks(opt):
    '''
    Returns G, D
    :param opt: hyper-param options
    :return: (netG, netD)
    '''
    ngpu = int(opt.ngpu)
    nz = int(opt.nz)
    ngf = int(opt.ngf)
    ndf = int(opt.ndf)
    nc = int(opt.nc)
    n_extra_layers = int(opt.n_extra_layers)

    netG = __getGenerator(opt, ngpu, nz, ngf, ndf, nc, n_extra_layers)
    netD = __getDiscriminator(opt, ngpu, nz, ngf, ndf, nc, n_extra_layers)

    return netG, netD


def __getGenerator(opt, ngpu, nz, ngf, ndf, nc, n_extra_layers):
    if opt.noBN:
        if isDebug: print("Using No Batch Norm (DCGAN_G_nobn) for Generator")
        netG = dcgan.DCGAN_G_nobn(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)
    elif opt.mlp_G:
        if isDebug: print("Using MLP_G for Generator")
        netG = mlp.MLP_G(opt.imageSize, nz, nc, ngf, ngpu)
    else:
        if isDebug: print("Using DCGAN_G for Generator")
        netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers, bias=False)

    netG.apply(weights_init)
    if opt.netG != '': # load checkpoint if needed
        netG.load_state_dict(torch.load(opt.netG))
    print("netG:\n {0}".format(netG))

    return netG

def __getDiscriminator(opt, ngpu, nz, ngf, ndf, nc, n_extra_layers):
    if opt.mlp_D:
        if isDebug: print("Using MLP_D for Discriminator/Critic")
        netD = mlp.MLP_D(opt.imageSize, nz, nc, ndf, ngpu)
    else:
        if isDebug: print("Using DCGAN_D for Discriminator/Critic")
        netD = dcgan.DCGAN_D(opt.imageSize, nz, nc, ndf, ngpu, n_extra_layers, False)

    netD.apply(weights_init)  # no-op for MLP_D (no Conv/BatchNorm layers); applied unconditionally for symmetry with netG

    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print("netD:\n {0}".format(netD))

    return netD

def __getDataSet(opt):
    if isDebug: print(f"Getting dataset: {opt.dataset} ... ")

    dataset = None
    if opt.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=opt.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(opt.imageSize),
                                       transforms.CenterCrop(opt.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                   ]))
    elif opt.dataset == 'lsun':
        dataset = dset.LSUN(db_path=opt.dataroot, classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(opt.imageSize),
                                transforms.CenterCrop(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    elif opt.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
        # Load pre-trained state dict
        if opt.load_dict:
            opt.netD = NETD_CIFAR10
            opt.netG = NETG_CIFAR10
    elif opt.dataset == 'mnist':
        opt.nc = 1          # MNIST is single-channel
        opt.imageSize = 32
        dataset = dset.MNIST(root=opt.dataroot, download=True,
                             transform=transforms.Compose([
                                 transforms.Resize(opt.imageSize),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,), (0.5,)),  # single-channel mean/std
                             ]))
        # Load pre-trained state dicts
        if opt.load_dict:
            opt.netD = NETD_MNIST
            opt.netG = NETG_MNIST

    return dataset


def weights_init(m):
    '''
    Custom weight initialization applied to netG and netD: conv weights are
    drawn from N(0, 0.02) and batch-norm scales from N(1, 0.02) with zero
    bias, following the DCGAN paper (Radford et al., 2015).
    '''
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def main(opt):
    cuda = opt.cuda
    visualize = opt.visualize
    print(f"cuda = {cuda}, visualize = {opt.visualize}")
    if visualize:
        netD_loss_logger = VisdomPlotLogger('line', opts={'title': 'Discriminator (NetD) Loss'})
        netG_loss_logger = VisdomPlotLogger('line', opts={'title': 'Generator (NetG) Loss'})

    cudnn.benchmark = True
    opt.manualSeed = random.randint(1, 10000)  # pick a fresh seed, then seed every RNG with it
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    ## Path where generated samples and checkpoints are stored
    if opt.experiment is None:
        opt.experiment = 'samples'
    # create {experiment}/{dataset}/ up front; save_image and torch.save below write into it
    os.makedirs(os.path.join(opt.experiment, opt.dataset), exist_ok=True)

    if USE_CUDA and not opt.cuda:
        utils.eprint("WARNING: a CUDA device is available; run with --cuda <gpu_id> to use it")

    dataset = __getDataSet(opt)
    assert dataset is not None, f"unsupported dataset: {opt.dataset}"

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))
    nz = int(opt.nz)
    nc = int(opt.nc)


    input = torch.FloatTensor(opt.batchSize, nc, opt.imageSize, opt.imageSize)
    noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
    fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
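    # fixed_noise is reused for every sample grid below so that generated
    # images are directly comparable across training.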
    one = torch.FloatTensor([1])
    mone = one * -1
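    # one/mone are gradient seeds for backward(): seeding the critic's scalar
    # output with +1 vs. -1 pushes it in opposite directions for real vs. fake
    # batches (sign conventions follow the original WGAN reference code).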

    ## Get Networks
    netG, netD = getNetworks(opt)
    if opt.cuda:
        if isDebug: print("Using CUDA")
        netD.cuda()
        netG.cuda()
        input = input.cuda()
        one, mone = one.cuda(), mone.cuda()
        noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

    ## Setup Optimizers
    optimizerG, optimizerD = getOptimizers(opt, netG, netD)

    gen_iterations = 0
    for epoch in tqdm(range(opt.niter)):
        data_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
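            # The critic consumes batches straight from data_iter below, so
            # `i` tracks progress through the epoch across critic updates.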
            ############################
            # (1) Update D network
            ###########################
            for p in netD.parameters():  # reset requires_grad
                p.requires_grad = True  # they are set to False below in netG update

            # train the discriminator Diters times
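            # (per the original WGAN recipe: 100 critic iterations early in
            # training and on every 500th generator update keep the critic
            # near-optimal; otherwise opt.Diters is used)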
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                Diters = 100
            else:
                Diters = opt.Diters
            j = 0
            while j < Diters and i < len(dataloader):
                j += 1

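                # Weight clipping keeps the critic approximately K-Lipschitz,
                # as the WGAN objective requires.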
                # clamp parameters to a cube
                for p in netD.parameters():
                    p.data.clamp_(opt.clamp_lower, opt.clamp_upper)

                data = next(data_iter)   # Python 3: use the next() builtin
                i += 1

                # train with real
                real_cpu, _ = data
                netD.zero_grad()
                batch_size = real_cpu.size(0)

                if opt.cuda:
                    real_cpu = real_cpu.cuda()
                input.resize_as_(real_cpu).copy_(real_cpu)
                inputv = Variable(input)

                errD_real = netD(inputv)
                errD_real.backward(one)

                # train with fake
                noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
                noisev = Variable(noise, volatile=True)  # totally freeze netG
                fake = Variable(netG(noisev).data)
                inputv = fake
                errD_fake = netD(inputv)
                errD_fake.backward(mone)
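                # errD is the critic loss; its magnitude tracks the (scaled)
                # Wasserstein-distance estimate between real and fake samples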
                errD = errD_real - errD_fake
                optimizerD.step()

            ############################
            # (2) Update G network
            ###########################
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation
            netG.zero_grad()
            # in case our last batch was the tail batch of the dataloader,
            # make sure we feed a full batch of noise
            noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
            noisev = Variable(noise)
            fake = netG(noisev)
            errG = netD(fake)
            errG.backward(one)
            optimizerG.step()
            gen_iterations += 1

            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                  % (epoch, opt.niter, i, len(dataloader), gen_iterations,
                     errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

            if visualize:
                netD_loss_logger.log(epoch, errD.data[0])
                netG_loss_logger.log(epoch, errG.data[0])

            if gen_iterations % 500 == 0 or ((gen_iterations % 100 == 0) and (opt.dataset == 'mnist')):
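                # un-normalize from [-1, 1] back to [0, 1] before saving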
                real_cpu = real_cpu.mul(0.5).add(0.5)
                vutils.save_image(real_cpu, '{0}/{1}/real_samples.png'.format(opt.experiment, opt.dataset))
                fake = netG(Variable(fixed_noise, volatile=True))
                fake.data = fake.data.mul(0.5).add(0.5)
                vutils.save_image(fake.data, '{0}/{1}/fake_samples_{2}.png'.format(opt.experiment, opt.dataset, gen_iterations))

        # do checkpointing
        if opt.niter > 25:
            if epoch % 10 == 0:
                torch.save(netG.state_dict(), '{0}/{1}/netG_epoch_{2}.pth'.format(opt.experiment, opt.dataset, epoch))
                torch.save(netD.state_dict(), '{0}/{1}/netD_epoch_{2}.pth'.format(opt.experiment, opt.dataset, epoch))
        else:
            torch.save(netG.state_dict(), '{0}/{1}/netG_epoch_{2}.pth'.format(opt.experiment, opt.dataset, epoch))
            torch.save(netD.state_dict(), '{0}/{1}/netD_epoch_{2}.pth'.format(opt.experiment, opt.dataset, epoch))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Pass configurations here")
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--dataset', type=str, default=DATASET, help='cifar10 | lsun | mnist | imagenet | folder | lfw')
    parser.add_argument('--debug', action='store_true', help='enable verbose debug prints')
    parser.add_argument('--workers', type=int, default=NUM_WORKERS, help='number of data loading workers')
    parser.add_argument('--batchSize', type=int, default=BATCH_SIZE, help='input batch size')
    parser.add_argument('--imageSize', type=int, default=IMG_SIZE, help='the height / width of the input image to network')
    parser.add_argument('--nc', type=int, default=IMG_CHANNELS, help='input image channels')
    parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
    parser.add_argument('--ngf', type=int, default=NGF, help='number of generator features')
    parser.add_argument('--ndf', type=int, default=NDF, help='number of discriminator features')
    parser.add_argument('--niter', type=int, default=NUM_EPOCHS, help='number of epochs to train for')
    parser.add_argument('--lrD', type=float, default=LR_D, help='learning rate for Critic, default=0.00005')
    parser.add_argument('--lrG', type=float, default=LR_G, help='learning rate for Generator, default=0.00005')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--visualize', action='store_true', help='Enables Visdom')
    parser.add_argument('--cuda', type=int, default=None, metavar='GPU_ID', help='GPU index to use (enables CUDA)')
    parser.add_argument('--load_dict', action='store_true', help='Loads saved state dicts')
    parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
    parser.add_argument('--netG', default='', help="path to netG (to continue training)")
    parser.add_argument('--netD', default='', help="path to netD (to continue training)")
    parser.add_argument('--clamp_lower', type=float, default=-0.01)
    parser.add_argument('--clamp_upper', type=float, default=0.01)
    parser.add_argument('--Diters', type=int, default=D_ITERS, help='number of D iters per each G iter')
    parser.add_argument('--noBN', action='store_true', help='use batchnorm or not (only for DCGAN)')
    parser.add_argument('--mlp_G', action='store_true', help='use MLP for G')
    parser.add_argument('--mlp_D', action='store_true', help='use MLP for D')
    parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
    parser.add_argument('--experiment', default=None, help='Where to store samples and models')
    parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
    opt = parser.parse_args()
    isDebug = opt.debug     # wire the --debug flag to the module-level switch

    # --cuda takes a GPU index; normalize it to a boolean flag for the rest of the script
    if opt.cuda is not None and opt.cuda >= 0 and torch.cuda.is_available():
        torch.cuda.set_device(opt.cuda)
        opt.cuda = True
    else:
        opt.cuda = False

    try:
        from visdom import Visdom
        import torchnet as tnt
        from torchnet.engine import Engine
        from torchnet.logger import VisdomPlotLogger, VisdomTextLogger, VisdomLogger
        canVisualize = True
    except ImportError as ie:
        print(f"Could not import visualization packages: {ie}", file=sys.stderr)
        canVisualize = False
    opt.visualize = bool(opt.visualize and canVisualize)

    main(opt)