from __future__ import print_function, division

import os
import random
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from skimage import io, transform
from imageio import imwrite as imsave  # scipy.misc.imsave was removed in SciPy >= 1.2

import plot  # local plotting helper

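# Two-stage image inpainting with a GAN. Net_G chains two U-Nets through an
# LSTM cell: the first U-Net shrinks the missing region (a 64x64 hole becomes
# a 32x32 one), the second fills it completely. Net_D scores each stage's
# output with its own 31x31 PatchGAN-style CNN. Training runs on 128x128
# random crops of the images listed in paris.csv.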

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class Param:
    unet_channel = 64
    cnn_channel = 64
    batch_size = 16
    image_size = 128
    n_critic = 1
    gan_weight = 0.001
    tv_weight = 1.0
    weight_decay = 0.00
    G_learning_rate = 0.0002
    D_learning_rate = 0.0002
    out_path = '/data/haoran/unet-gan/gan_lstm_2/'
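
# Note: Param.out_path must already exist; checkpoints, loss plots, and sample
# grids are all written beneath it. gan_weight deliberately keeps the
# adversarial term small next to the unweighted reconstruction loss (see the
# G step in main()).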


def conv_down(dim_in, dim_out):
    return nn.Sequential(
        nn.LeakyReLU(0.2),
        nn.Conv2d(dim_in, dim_out, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(dim_out)
    )


def conv_up(dim_in, dim_out):
    return nn.Sequential(
        nn.ReLU(),
        nn.ConvTranspose2d(dim_in, dim_out, 4, 2, 1, bias=False),
        nn.BatchNorm2d(dim_out)
    )


class Unet(nn.Module):
    def __init__(self, unet_input_channel=3, hidden_channel=Param.unet_channel * 8):
        super(Unet, self).__init__()
        self.start = nn.Conv2d(unet_input_channel, Param.unet_channel, 3, 1, 1)  # 128
        self.conv0 = conv_down(Param.unet_channel, Param.unet_channel)  # 64
        self.conv1 = conv_down(Param.unet_channel, Param.unet_channel * 2)  # 32
        self.conv2 = conv_down(Param.unet_channel * 2, Param.unet_channel * 4)  # 16
        self.conv3 = conv_down(Param.unet_channel * 4, Param.unet_channel * 8)  # 8
        self.conv4 = conv_down(Param.unet_channel * 8, Param.unet_channel * 8)  # 4
        self.conv5 = conv_down(Param.unet_channel * 8, Param.unet_channel * 8)  # 2
        self.conv6 = conv_down(Param.unet_channel * 8, Param.unet_channel * 8)  # 1

        self.up5 = conv_up(hidden_channel, Param.unet_channel * 8)  # 2
        self.dp5 = nn.Dropout(p=0.5)  # note: the dropout layers are declared but never applied in forward()
        self.up4 = conv_up(Param.unet_channel * 8 * 2, Param.unet_channel * 8)  # 4
        self.dp4 = nn.Dropout(p=0.5)
        self.up3 = conv_up(Param.unet_channel * 8 * 2, Param.unet_channel * 8)  # 8
        self.dp3 = nn.Dropout(p=0.5)
        self.up2 = conv_up(Param.unet_channel * 8 * 2, Param.unet_channel * 4)  # 16
        self.up1 = conv_up(Param.unet_channel * 4 * 2, Param.unet_channel * 2)  # 32
        self.up0 = conv_up(Param.unet_channel * 2 * 2, Param.unet_channel)  # 64
        self.end = conv_up(Param.unet_channel * 2, 3)  # 128

        # weight initialization: Kaiming-normal for conv layers, unit weight and
        # zero bias for batch norm (nn.init.kaiming_normal was renamed kaiming_normal_)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, data_in, hidden_input=None):
        start_out = self.start(data_in)
        conv0_out = self.conv0(start_out)
        conv1_out = self.conv1(conv0_out)
        conv2_out = self.conv2(conv1_out)
        conv3_out = self.conv3(conv2_out)
        conv4_out = self.conv4(conv3_out)
        conv5_out = self.conv5(conv4_out)
        conv6_out = self.conv6(conv5_out)

        mid = conv6_out  # bottleneck: (batch, Param.unet_channel * 8, 1, 1)

        if hidden_input is None:
            up5_out = self.up5(conv6_out)
        else:
            hidden_input = hidden_input.view(hidden_input.size(0), hidden_input.size(1), 1, 1)
            up5_out = self.up5(torch.cat((hidden_input, conv6_out), 1))

        up4_out = self.up4(torch.cat((up5_out, conv5_out), 1))
        up3_out = self.up3(torch.cat((up4_out, conv4_out), 1))
        up2_out = self.up2(torch.cat((up3_out, conv3_out), 1))
        up1_out = self.up1(torch.cat((up2_out, conv2_out), 1))
        up0_out = self.up0(torch.cat((up1_out, conv1_out), 1))
        out = self.end(torch.cat((up0_out, conv0_out), 1))
        out = torch.sigmoid(out)  # F.sigmoid is deprecated in favour of torch.sigmoid
        return out, mid
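
# A minimal shape sanity check for the U-Net above (illustrative only, not part
# of training):
#
#   net = Unet()
#   x = torch.randn(2, 3, Param.image_size, Param.image_size)
#   out, mid = net(x)
#   assert out.shape == (2, 3, 128, 128)
#   assert mid.shape == (2, Param.unet_channel * 8, 1, 1)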


def conv_stage(dim_in, dim_out):
    return nn.Sequential(
        nn.Conv2d(dim_in, dim_out, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2),
        nn.BatchNorm2d(dim_out)
    )


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv0 = nn.Conv2d(3, Param.cnn_channel, 3, 1, 1, bias=False)
        self.conv1 = conv_stage(Param.cnn_channel, Param.cnn_channel * 2)
        self.conv2 = conv_stage(Param.cnn_channel * 2, Param.cnn_channel * 4)
        self.conv3 = nn.Conv2d(Param.cnn_channel * 4, 1, 4, 1, 1)
        self.bn0 = nn.BatchNorm2d(Param.cnn_channel)

        # same initialization scheme as the U-Net
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, data_in):            # spatial size / channels
        conv0_out = self.conv0(data_in)    # 128 / 64
        conv0_out = self.bn0(conv0_out)
        conv1_out = self.conv1(conv0_out)  # 64 / 128
        conv2_out = self.conv2(conv1_out)  # 32 / 256
        out = self.conv3(conv2_out)        # 31 / 1
        out = torch.sigmoid(out)
        return out
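
# The discriminator is PatchGAN-style: with 128x128 inputs the spatial size
# traces 128 -> 64 -> 32 -> 31, so each of the 31x31 outputs is the real/fake
# probability of one local patch rather than of the whole image.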


class ParisData(Dataset):
    def __init__(self, csv_file, trans=None):
        self.lines = pd.read_csv(csv_file)
        self.trans = trans

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        image_pos = self.lines.iloc[idx, 0]  # DataFrame.ix was removed from pandas
        image = io.imread(image_pos)
        image = image.astype(np.float64)     # np.float was removed from NumPy
        h, w = image.shape[:2]
        # rescale so the shorter side is 350 pixels, preserving the aspect ratio
        if h < w:
            factor = h / 350.0
            w = w / factor
            h = 350
        else:
            factor = w / 350.0
            h = h / factor
            w = 350
        image = transform.resize(image, (int(h), int(w), 3))
        image_id = self.lines.iloc[idx, 1]
        sample = {'image': image, 'id': image_id}
        if self.trans is not None:
            sample = self.trans(sample)
        return sample
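
# Assumed layout of paris.csv (two columns plus the header row consumed by
# pd.read_csv): column 0 holds the image path, column 1 a numeric image id.
# The paths below are illustrative only:
#
#   path,id
#   /data/paris/img_000001.jpg,1
#   /data/paris/img_000002.jpg,2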


class RandCrop(object):
    """Random 128x128 crop, random horizontal flip, and scaling to [0, 1]."""

    def __call__(self, sample):
        image = sample['image']
        image_id = sample['id']
        h, w = image.shape[:2]
        sx = random.randint(0, h - Param.image_size)
        sy = random.randint(0, w - Param.image_size)
        image = image[sx:(sx + Param.image_size), sy:(sy + Param.image_size)]
        image = image.transpose((2, 0, 1))  # HWC -> CHW
        if random.randint(0, 1):
            image = image[:, :, ::-1]  # random horizontal flip
        image /= 255.0
        image_trans = np.array(image)  # copy to drop the negative stride from the flip
        return {'image': torch.from_numpy(image_trans).float(), 'id': torch.Tensor([image_id])}


def inf_get(train):
    # cycle over the DataLoader forever, yielding image batches
    while True:
        for x in train:
            yield x['image']


def destroy(image, crop_size=64):
    """Return a copy of `image` whose central crop_size x crop_size square is
    overwritten with a fixed per-channel colour (the constants below are the
    fill values used by the original code)."""
    re = image.clone().cuda()
    lo = int((Param.image_size - crop_size) / 2)
    hi = lo + crop_size
    re[:, 0, lo:hi, lo:hi] = 0.45703125
    re[:, 1, lo:hi, lo:hi] = 0.40625
    re[:, 2, lo:hi, lo:hi] = 0.48046875
    return re
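
# Example: destroy(batch, 64) keeps a 32-pixel border of real content and
# blanks the central 64x64 square, so the generator has to reconstruct the
# centre from the surrounding context alone.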


class Net_G(nn.Module):
    def __init__(self):
        super(Net_G, self).__init__()
        self.unet_1 = Unet(3, Param.unet_channel * 8)
        self.unet_2 = Unet(6, Param.unet_channel * 8 * 3)
        self.rnn = nn.LSTMCell(Param.unet_channel * 8, Param.unet_channel * 8 * 2)

    def forward(self, data_1, data_2, h0, c0):
        # note: data_2 is accepted for interface symmetry but is not used here
        unet_out_1, unet_mid_1 = self.unet_1(data_1)
        h1, c1 = self.rnn(unet_mid_1.view(Param.batch_size, -1), (h0, c0))
        unet_out_2, unet_mid_2 = self.unet_2(torch.cat((data_1, unet_out_1), 1), h1)
        return unet_out_1, unet_out_2
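
# Data flow of the two-stage generator: unet_1 maps the 64-hole input to a
# 32-hole estimate; its 1x1 bottleneck feeds the LSTM cell, whose hidden state
# conditions unet_2; unet_2 sees the original input concatenated with the
# stage-1 output and produces the fully inpainted image.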


class Net_D(nn.Module):
    def __init__(self):
        super(Net_D, self).__init__()
        self.cnn1 = CNN()
        self.cnn2 = CNN()

    def forward(self, data_32, data_0):
        out1 = self.cnn1(data_32)
        out2 = self.cnn2(data_0)
        return out1, out2


def save_image_plus(x, save_path):
    """Tile a (rows, cols, C, H, W) float batch in [0, 1] into one uint8 image grid."""
    x = (255.99 * x).astype('uint8')
    x = x.transpose(0, 1, 3, 4, 2)  # -> (rows, cols, H, W, C)
    nh, nw = x.shape[:2]
    h, w = x.shape[2], x.shape[3]
    img = np.zeros((h * nh, w * nw, 3), dtype='uint8')
    for i in range(nh):
        for j in range(nw):
            img[i * h:i * h + h, j * w:j * w + w] = x[i][j]
    imsave(save_path, img)

def cal_tv(image):
    # total-variation penalty: mean squared difference between vertically and
    # horizontally adjacent pixels (boundary rows/columns contribute zero,
    # matching the original shifted-copy formulation)
    dh = image[:, :, 1:, :] - image[:, :, :-1, :]
    dw = image[:, :, :, 1:] - image[:, :, :, :-1]
    return (dh ** 2).sum() / image.numel() + (dw ** 2).sum() / image.numel()

def main():
    ones_31 = torch.ones(Param.batch_size, 1, 31, 31).cuda()    # "real" targets for the 31x31 patch outputs
    zeros_31 = torch.zeros(Param.batch_size, 1, 31, 31).cuda()  # "fake" targets

    # binary mask with the central 64x64 hole zeroed out (defined in the
    # original code but not used anywhere below)
    mask = torch.ones(Param.batch_size, 3, 128, 128)
    mask[:, :, 32:32 + 64, 32:32 + 64] = 0
    mask = Variable(mask.cuda(), requires_grad=False)

    # initial hidden and cell state for the LSTM cell linking the two U-Nets
    h0 = torch.zeros(Param.batch_size, Param.unet_channel * 8 * 2).cuda()
    c0 = torch.zeros(Param.batch_size, Param.unet_channel * 8 * 2).cuda()

    netG = Net_G().cuda()
    netD = Net_D().cuda()

    #netG = nn.DataParallel(netG, device_ids=[0, 1])
    #netD = nn.DataParallel(netD, device_ids=[0, 1])

    # resume from a previous checkpoint
    netG.load_state_dict(torch.load('/data/haoran/unet-gan/gan_lstm_2/netG_59999.pickle'))
    netD.load_state_dict(torch.load('/data/haoran/unet-gan/gan_lstm_2/netD_59999.pickle'))


    opt_G = optim.Adam(netG.parameters(), lr=Param.G_learning_rate, betas=(0.5, 0.999), weight_decay=Param.weight_decay)
    opt_D = optim.Adam(netD.parameters(), lr=Param.D_learning_rate, betas=(0.5, 0.999), weight_decay=Param.weight_decay)

    trainset = ParisData('paris.csv', RandCrop())
    train_loader = DataLoader(trainset, batch_size=Param.batch_size, shuffle=True,
                              num_workers=2, drop_last=True)
    train_data = inf_get(train_loader)

    epoch = 0  # note: each "epoch" below is a single training iteration, not a full data pass
    maxepoch = 200000
    bce_loss = nn.BCELoss()
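
    # Each iteration: (1) update D on a real batch plus detached fakes with the
    # standard BCE GAN loss, then (2) update G with the adversarial term
    # (scaled by Param.gan_weight) plus reconstruction and TV losses.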

    while epoch < maxepoch:
        start_time = time.time()

        ###################################
        ######   D step   #################
        ###################################
        # (a single D update per iteration; Param.n_critic is kept for experiments)
        for p in netD.parameters():
            p.requires_grad = True

        real_data = next(train_data)  # generator.next() is Python 2 only
        real_data = real_data.cuda()
        real_data_64 = destroy(real_data, 64)  # stage-1 input: 64x64 hole
        real_data_32 = destroy(real_data, 32)  # stage-1 target: 32x32 hole
        real_data_0 = real_data                # stage-2 target: clean image

        real_data_64 = Variable(real_data_64)
        real_data_32 = Variable(real_data_32)
        real_data_0 = Variable(real_data_0)

        netD.zero_grad()
        p_real_32, p_real_0 = netD(real_data_32, real_data_0)
        target = Variable(ones_31)
        real_loss_32 = bce_loss(p_real_32, target)
        real_loss_0 = bce_loss(p_real_0, target)

        fake_data_32, fake_data_0 = netG(real_data_64, real_data_32, Variable(h0), Variable(c0))

        # detach the fakes so this backward pass does not propagate into G
        p_fake_32, p_fake_0 = netD(fake_data_32.detach(), fake_data_0.detach())
        target = Variable(zeros_31)
        fake_loss_32 = bce_loss(p_fake_32, target)
        fake_loss_0 = bce_loss(p_fake_0, target)

        gan_loss = real_loss_32 + real_loss_0 + fake_loss_32 + fake_loss_0
        gan_loss.backward()

        D_cost = gan_loss.item()  # .data[0] was removed from PyTorch

        opt_D.step()
        ##################
        ## G step ########
        ##################
        for p in netD.parameters():
            p.requires_grad = False
        netG.zero_grad()

        # reconstruction loss (mean squared error, despite the "l1" name)
        l1_loss = ((fake_data_32 - real_data_32) ** 2).mean() + ((fake_data_0 - real_data_0) ** 2).mean()

        tv_loss = (cal_tv(fake_data_32) + cal_tv(fake_data_0)) * Param.tv_weight

        p_fake_32, p_fake_0 = netD(fake_data_32, fake_data_0)
        target = Variable(ones_31)
        fake_loss_32 = bce_loss(p_fake_32, target)
        fake_loss_0 = bce_loss(p_fake_0, target)

        gan_loss = (fake_loss_32 + fake_loss_0) * Param.gan_weight

        # a single combined backward pass replaces the original three
        # retain_graph=True calls
        (gan_loss + l1_loss + tv_loss).backward()

        G_cost = fake_loss_32.item() + fake_loss_0.item() + l1_loss.item()
        opt_G.step()
        print('epoch: ' + str(epoch) + ' l1_loss: ' + str(l1_loss.item()))

        # write logs and save samples (plot writes into the current directory)
        os.chdir(Param.out_path)
        plot.plot('train D cost', D_cost)
        plot.plot('time', time.time() - start_time)
        plot.plot('train G cost', G_cost)
        plot.plot('train l1 loss', l1_loss.item())

        if epoch % 1000 == 999:
            # save a grid: stage-1 output, stage-2 output, 64-hole input,
            # 32-hole target, clean target (one row per batch element)
            out_image = torch.cat(
                (
                    fake_data_32.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    fake_data_0.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_64.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_32.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size),
                    real_data_0.data.view(Param.batch_size, 1, 3, Param.image_size, Param.image_size)
                ),
                1
            )
            save_image_plus(out_image.cpu().numpy(), Param.out_path + 'train_image_{}.jpg'.format(epoch))

        if (epoch < 5) or (epoch % 100 == 99):
            plot.flush()
        plot.tick()

        if epoch % 20000 == 19999:
            torch.save(netD.state_dict(), Param.out_path + 'netD_{}.pickle'.format(epoch))
            torch.save(netG.state_dict(), Param.out_path + 'netG_{}.pickle'.format(epoch))
            #opt_D.param_groups[0]['lr'] /= 10.0
            #opt_G.param_groups[0]['lr'] /= 10.0
        epoch += 1


if __name__ == '__main__':
    main()