import os, sys
sys.path.append(os.getcwd())

import time
import functools

import numpy as np
import tensorflow as tf
import scipy.misc

import tflib as lib
import tflib.ops.linear
import tflib.ops.conv2d
import tflib.ops.batchnorm
import tflib.ops.deconv2d
import tflib.save_images
import tflib.celebA_64x64
import tflib.small_imagenet
import tflib.ops.layernorm
import tflib.plot

FLAGS = tf.app.flags.FLAGS

# Configurations
tf.app.flags.DEFINE_string('mode', 'wgan-gp',
                           "loss function option. [wgan-gp | dcgan | wgan | lsgan]")
tf.app.flags.DEFINE_string('data_dir', 'data/celebA_64x64', "data directory")
tf.app.flags.DEFINE_string('train_dir', 'train', "image output directory")
tf.app.flags.DEFINE_string('summary_dir', 'summary', "tensorboard summary directory")
tf.app.flags.DEFINE_integer('max_runtime', 20, "maximum run time in minutes")
tf.app.flags.DEFINE_integer('max_iter', 500, "maximum mini-batch iterations")
tf.app.flags.DEFINE_float('LAMBDA', 10., "gradient penalty lambda parameter")
tf.app.flags.DEFINE_float('gen_l1_weight', 0.9, "weight of L1 difference in generator loss")
tf.app.flags.DEFINE_integer('architecture', 0, "index of the generator/discriminator architecture (see get_architectures)")

# To train on 64x64 ImageNet instead, download it at
# http://image-net.org/small/download.php and pass the path to the extracted
# files via --data_dir.
DATA_DIR = FLAGS.data_dir
SUMMARY_DIR = FLAGS.summary_dir
GEN_L1_WEIGHT = FLAGS.gen_l1_weight # Weighting factor for L1 difference in generator loss
TRAIN_DIR = FLAGS.train_dir # Directory to output image
MODE = FLAGS.mode # dcgan, wgan, wgan-gp, lsgan
ITERS = FLAGS.max_iter # How many iterations to train for
LAMBDA = FLAGS.LAMBDA # Gradient penalty lambda hyperparameter

if len(DATA_DIR) == 0:
    raise Exception('Please specify the data directory via --data_dir!')

DIM = 64 # Model dimensionality (also the image height/width)
K = 4 # Downsampling factor for the generator input
CRITIC_ITERS = 5 # How many iterations to train the critic for
N_GPUS = 1 # Number of GPUs
BATCH_SIZE = 16 # Batch size. Must be a multiple of N_GPUS
INPUT_DIM = 16*16*3 # Number of pixels in each downsampled input (64/K = 16)
OUTPUT_DIM = 64*64*3 # Number of pixels in each image
DELETE_TRAIN_DIR = True
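
# Sanity checks: the rest of the script assumes these constraints (the batch
# is split evenly across the GPUs, and DIM x DIM images are downsampled by K).
assert BATCH_SIZE % N_GPUS == 0, 'BATCH_SIZE must be a multiple of N_GPUS'
assert DIM % K == 0, 'DIM must be divisible by the downsampling factor K'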

lib.print_model_settings(locals().copy())

# create summary dir
if not tf.gfile.Exists(FLAGS.summary_dir):
    tf.gfile.MakeDirs(FLAGS.summary_dir)

# clean and (re)create the image output directory
if DELETE_TRAIN_DIR and tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
if not tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

# architecture dictionary
def get_architectures():
    ARCHITECTURE_TABLE = {
        # Baseline (G: DCGAN, D: DCGAN)
        0: (DCGANGenerator, DCGANDiscriminator),

        # No BN and constant number of filts in G
        1: (WGANPaper_CrippledDCGANGenerator, DCGANDiscriminator),

        # 512-dim 4-layer ReLU MLP G
        2: (FCGenerator, DCGANDiscriminator),

        # No normalization anywhere
        3: (functools.partial(DCGANGenerator, bn=False),
            functools.partial(DCGANDiscriminator, bn=False)),

        # Gated multiplicative nonlinearities everywhere
        4: (MultiplicativeDCGANGenerator, MultiplicativeDCGANDiscriminator),

        # tanh nonlinearities everywhere
        5: (functools.partial(DCGANGenerator, bn=True, nonlinearity=tf.tanh),
            functools.partial(DCGANDiscriminator, bn=True, nonlinearity=tf.tanh)),

        # 101-layer ResNet G and D
        6: (ResnetGenerator, ResnetDiscriminator)
    }
    return ARCHITECTURE_TABLE

def GeneratorAndDiscriminator():
    """
    Return the (generator, discriminator) pair selected by the
    --architecture flag; see get_architectures() for the available options.
    """
    table = get_architectures()
    if FLAGS.architecture in table:
        return table[FLAGS.architecture]

    raise Exception('Unknown architecture index: {}'.format(FLAGS.architecture))

DEVICES = ['/gpu:{}'.format(i) for i in range(N_GPUS)]

def LeakyReLU(x, alpha=0.2):
    return tf.maximum(alpha*x, x)

def ReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(name+'.Linear', n_in, n_out, inputs, initialization='he')
    return tf.nn.relu(output)

def LeakyReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(name+'.Linear', n_in, n_out, inputs, initialization='he')
    return LeakyReLU(output)

def Batchnorm(name, axes, inputs):
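    """Batch normalization, except in the critic under WGAN-GP, where layer
    normalization is used instead (the gradient penalty is computed per
    sample, so normalizing across the batch is avoided in the critic)."""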
    if ('Discriminator' in name) and (MODE == 'wgan-gp'):
        if axes != [0,2,3]:
            raise Exception('Layernorm over non-standard axes is unsupported')
        return lib.ops.layernorm.Layernorm(name,[1,2,3],inputs)
    else:
        return lib.ops.batchnorm.Batchnorm(name,axes,inputs,fused=True)

def pixcnn_gated_nonlinearity(a, b):
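    """Gated activation from the PixelCNN family: a sigmoid gate multiplied by
    a tanh feature. Callers pass the even/odd channel halves of a layer that
    produces twice the usual number of channels."""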
    return tf.sigmoid(a) * tf.tanh(b)

def SubpixelConv2D(*args, **kwargs):
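    """2x spatial upsampling via subpixel convolution: a regular convolution
    produces 4*output_dim channels, which depth_to_space rearranges into a
    feature map with twice the height and width. Tensors are NCHW, so we
    transpose to NHWC for depth_to_space and back again."""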
    kwargs['output_dim'] = 4*kwargs['output_dim']
    output = lib.ops.conv2d.Conv2D(*args, **kwargs)
    output = tf.transpose(output, [0,2,3,1])
    output = tf.depth_to_space(output, 2)
    output = tf.transpose(output, [0,3,1,2])
    return output

def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, resample=None, he_init=True):
    """
    resample: None, 'down', or 'up'
    """
    if resample=='down':
        conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, stride=2)
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim//2, output_dim=output_dim//2, stride=2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim//2, output_dim=output_dim)
    elif resample=='up':
        conv_shortcut = SubpixelConv2D
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.deconv2d.Deconv2D, input_dim=input_dim//2, output_dim=output_dim//2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim//2, output_dim=output_dim)
    elif resample is None:
        conv_shortcut = lib.ops.conv2d.Conv2D
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim,  output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim//2, output_dim=output_dim//2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim//2, output_dim=output_dim)

    else:
        raise Exception('invalid resample value')

    if output_dim==input_dim and resample is None:
        shortcut = inputs # Identity skip-connection
    else:
        shortcut = conv_shortcut(name+'.Shortcut', input_dim=input_dim, output_dim=output_dim, filter_size=1,
                                 he_init=False, biases=True, inputs=inputs)

    output = inputs
    output = tf.nn.relu(output)
    output = conv_1(name+'.Conv1', filter_size=1, inputs=output, he_init=he_init, weightnorm=False)
    output = tf.nn.relu(output)
    output = conv_1b(name+'.Conv1B', filter_size=filter_size, inputs=output, he_init=he_init, weightnorm=False)
    output = tf.nn.relu(output)
    output = conv_2(name+'.Conv2', filter_size=1, inputs=output, he_init=he_init, weightnorm=False, biases=False)
    output = Batchnorm(name+'.BN', [0,2,3], output)

    return shortcut + (0.3*output)

# ! Generators

def FCGenerator(n_samples, noise=None, FC_DIM=512, input_dim=INPUT_DIM):
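    """512-dim 4-layer ReLU MLP generator: maps a flat input (noise, or a
    downsampled image) to a flattened 64x64x3 image in [-1, 1]."""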
    if noise is None:
        noise = tf.random_normal([n_samples, input_dim])

    output = ReLULayer('Generator.1', input_dim, FC_DIM, noise)
    output = ReLULayer('Generator.2', FC_DIM, FC_DIM, output)
    output = ReLULayer('Generator.3', FC_DIM, FC_DIM, output)
    output = ReLULayer('Generator.4', FC_DIM, FC_DIM, output)
    output = lib.ops.linear.Linear('Generator.Out', FC_DIM, OUTPUT_DIM, output)

    output = tf.tanh(output)

    return output

def DCGANGenerator(
        n_samples, noise=None, dim=DIM, input_dim=INPUT_DIM,
        k=K, bn=True, nonlinearity=tf.nn.relu):
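    """DCGAN-style generator. If `noise` is None, Gaussian noise is projected
    to an 8*dim x 4 x 4 feature map; otherwise `noise` is treated as a
    flattened image downsampled by `k` and encoded with two stride-2
    convolutions to the same shape. Four transposed convolutions then decode
    to a 64x64 RGB image in [-1, 1], returned flattened."""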

    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)
    
    if noise is None:
        noise = tf.random_normal([n_samples, input_dim])
        output = lib.ops.linear.Linear(
            'Generator.Input', input_dim, 4*4*8*dim, noise)
        output = tf.reshape(output, [-1, 8*dim, 4, 4])
        if bn:
            output = Batchnorm('Generator.BN1', [0,2,3], output)
        output = nonlinearity(output)
    else:
        # downsampled data is used as the generator input ("noise")
        # input dimension: [batch_size, 3*(dim/K)*(dim/K)]
        # encode twice (stride-2 convs) to a tensor of [batch_size, 8*dim, 4, 4]
        output = tf.reshape(noise, [-1, 3, dim//k, dim//k])
        output = lib.ops.conv2d.Conv2D(
            'Generator.Encoder1.1', 3, 4*dim, 5, output, stride=2)
        if bn:
            output = Batchnorm('Generator.BN1.1', [0,2,3], output)
        output = nonlinearity(output)

        output = lib.ops.conv2d.Conv2D(
            'Generator.Encoder1.2', 4*dim, 8*dim, 5, output, stride=2)
        if bn:
            output = Batchnorm('Generator.BN1.2', [0, 2, 3], output)
        output = nonlinearity(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.2', 8*dim, 4*dim, 5, output)
    if bn:
        output = Batchnorm('Generator.BN2', [0,2,3], output)
    output = nonlinearity(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.3', 4*dim, 2*dim, 5, output)
    if bn:
        output = Batchnorm('Generator.BN3', [0,2,3], output)
    output = nonlinearity(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.4', 2*dim, dim, 5, output)
    if bn:
        output = Batchnorm('Generator.BN4', [0,2,3], output)
    output = nonlinearity(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.5', dim, 3, 5, output)
    output = tf.tanh(output)
    
    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return tf.reshape(output, [-1, OUTPUT_DIM])

def WGANPaper_CrippledDCGANGenerator(
        n_samples, noise=None, dim=DIM, input_dim=INPUT_DIM):
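    """Generator from the WGAN paper's deliberately weakened ("crippled")
    configuration: no batch normalization and a constant number of filters
    in every layer."""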
    if noise is None:
        noise = tf.random_normal([n_samples, input_dim])

    output = lib.ops.linear.Linear('Generator.Input', input_dim, 4*4*dim, noise)
    output = tf.nn.relu(output)
    output = tf.reshape(output, [-1, dim, 4, 4])

    output = lib.ops.deconv2d.Deconv2D('Generator.2', dim, dim, 5, output)
    output = tf.nn.relu(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.3', dim, dim, 5, output)
    output = tf.nn.relu(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.4', dim, dim, 5, output)
    output = tf.nn.relu(output)

    output = lib.ops.deconv2d.Deconv2D('Generator.5', dim, 3, 5, output)
    output = tf.tanh(output)

    return tf.reshape(output, [-1, OUTPUT_DIM])

def ResnetGenerator(n_samples, noise=None, dim=DIM, input_dim=INPUT_DIM):
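    """ResNet generator: a linear projection to a 4x4 feature map, stacks of
    residual blocks with four 2x upsampling stages (4 -> 8 -> 16 -> 32 -> 64),
    then a 1x1 convolution and tanh."""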
    if noise is None:
        noise = tf.random_normal([n_samples, input_dim])

    output = lib.ops.linear.Linear('Generator.Input', input_dim, 4*4*8*dim, noise)
    output = tf.reshape(output, [-1, 8*dim, 4, 4])

    for i in range(6):
        output = ResidualBlock('Generator.4x4_{}'.format(i), 8*dim, 8*dim, 3, output, resample=None)
    output = ResidualBlock('Generator.Up1', 8*dim, 4*dim, 3, output, resample='up')
    for i in range(6):
        output = ResidualBlock('Generator.8x8_{}'.format(i), 4*dim, 4*dim, 3, output, resample=None)
    output = ResidualBlock('Generator.Up2', 4*dim, 2*dim, 3, output, resample='up')
    for i in range(6):
        output = ResidualBlock('Generator.16x16_{}'.format(i), 2*dim, 2*dim, 3, output, resample=None)
    output = ResidualBlock('Generator.Up3', 2*dim, 1*dim, 3, output, resample='up')
    for i in range(6):
        output = ResidualBlock('Generator.32x32_{}'.format(i), 1*dim, 1*dim, 3, output, resample=None)
    output = ResidualBlock('Generator.Up4', 1*dim, dim//2, 3, output, resample='up')
    for i in range(5):
        output = ResidualBlock('Generator.64x64_{}'.format(i), dim//2, dim//2, 3, output, resample=None)

    output = lib.ops.conv2d.Conv2D('Generator.Out', dim//2, 3, 1, output, he_init=False)
    output = tf.tanh(output / 5.)

    return tf.reshape(output, [-1, OUTPUT_DIM])


def MultiplicativeDCGANGenerator(n_samples, noise=None, dim=DIM, bn=True, input_dim=INPUT_DIM):
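    """DCGAN generator with gated multiplicative nonlinearities: each layer
    emits twice the channels and the even/odd halves are combined as
    sigmoid(a) * tanh(b)."""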
    if noise is None:
        noise = tf.random_normal([n_samples, input_dim])

    output = lib.ops.linear.Linear('Generator.Input', input_dim, 4*4*8*dim*2, noise)
    output = tf.reshape(output, [-1, 8*dim*2, 4, 4])
    if bn:
        output = Batchnorm('Generator.BN1', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.deconv2d.Deconv2D('Generator.2', 8*dim, 4*dim*2, 5, output)
    if bn:
        output = Batchnorm('Generator.BN2', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.deconv2d.Deconv2D('Generator.3', 4*dim, 2*dim*2, 5, output)
    if bn:
        output = Batchnorm('Generator.BN3', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.deconv2d.Deconv2D('Generator.4', 2*dim, dim*2, 5, output)
    if bn:
        output = Batchnorm('Generator.BN4', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.deconv2d.Deconv2D('Generator.5', dim, 3, 5, output)
    output = tf.tanh(output)

    return tf.reshape(output, [-1, OUTPUT_DIM])

# ! Discriminators

def MultiplicativeDCGANDiscriminator(inputs, dim=DIM, bn=True):
    output = tf.reshape(inputs, [-1, 3, 64, 64])

    output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, dim*2, 5, output, stride=2)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.conv2d.Conv2D('Discriminator.2', dim, 2*dim*2, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN2', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*dim, 4*dim*2, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN3', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*dim, 8*dim*2, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN4', [0,2,3], output)
    output = pixcnn_gated_nonlinearity(output[:,::2], output[:,1::2])

    output = tf.reshape(output, [-1, 4*4*8*dim])
    output = lib.ops.linear.Linear('Discriminator.Output', 4*4*8*dim, 1, output)

    return tf.reshape(output, [-1])


def ResnetDiscriminator(inputs, dim=DIM):
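    """ResNet critic: a 1x1 convolution, residual stacks with four 2x
    downsampling stages (64 -> 32 -> 16 -> 8 -> 4), then a linear layer
    producing one (scaled) scalar score per sample."""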
    output = tf.reshape(inputs, [-1, 3, 64, 64])
    output = lib.ops.conv2d.Conv2D('Discriminator.In', 3, dim//2, 1, output, he_init=False)

    for i in range(5):
        output = ResidualBlock('Discriminator.64x64_{}'.format(i), dim//2, dim//2, 3, output, resample=None)
    output = ResidualBlock('Discriminator.Down1', dim//2, dim*1, 3, output, resample='down')
    for i in range(6):
        output = ResidualBlock('Discriminator.32x32_{}'.format(i), dim*1, dim*1, 3, output, resample=None)
    output = ResidualBlock('Discriminator.Down2', dim*1, dim*2, 3, output, resample='down')
    for i in range(6):
        output = ResidualBlock('Discriminator.16x16_{}'.format(i), dim*2, dim*2, 3, output, resample=None)
    output = ResidualBlock('Discriminator.Down3', dim*2, dim*4, 3, output, resample='down')
    for i in range(6):
        output = ResidualBlock('Discriminator.8x8_{}'.format(i), dim*4, dim*4, 3, output, resample=None)
    output = ResidualBlock('Discriminator.Down4', dim*4, dim*8, 3, output, resample='down')
    for i in range(6):
        output = ResidualBlock('Discriminator.4x4_{}'.format(i), dim*8, dim*8, 3, output, resample=None)

    output = tf.reshape(output, [-1, 4*4*8*dim])
    output = lib.ops.linear.Linear('Discriminator.Output', 4*4*8*dim, 1, output)

    return tf.reshape(output / 5., [-1])


def FCDiscriminator(inputs, FC_DIM=512, n_layers=3):
    output = LeakyReLULayer('Discriminator.Input', OUTPUT_DIM, FC_DIM, inputs)
    for i in range(n_layers):
        output = LeakyReLULayer('Discriminator.{}'.format(i), FC_DIM, FC_DIM, output)
    output = lib.ops.linear.Linear('Discriminator.Out', FC_DIM, 1, output)

    return tf.reshape(output, [-1])

def DCGANDiscriminator(inputs, dim=DIM, bn=True, nonlinearity=LeakyReLU):
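    """Standard DCGAN critic: four stride-2 convolutions with optional
    normalization, followed by a linear layer producing one scalar score
    per sample."""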
    output = tf.reshape(inputs, [-1, 3, 64, 64])

    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)

    output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, dim, 5, output, stride=2)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D('Discriminator.2', dim, 2*dim, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN2', [0,2,3], output)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*dim, 4*dim, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN3', [0,2,3], output)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*dim, 8*dim, 5, output, stride=2)
    if bn:
        output = Batchnorm('Discriminator.BN4', [0,2,3], output)
    output = nonlinearity(output)

    output = tf.reshape(output, [-1, 4*4*8*dim])
    output = lib.ops.linear.Linear('Discriminator.Output', 4*4*8*dim, 1, output)

    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return tf.reshape(output, [-1])

# Kernel for downsampling: a K x K box filter applied to each RGB channel
# independently (block-averaging when convolved with stride K).
arr = np.zeros([K, K, 3, 3])
arr[:,:,0,0] = 1.0/(K*K)
arr[:,:,1,1] = 1.0/(K*K)
arr[:,:,2,2] = 1.0/(K*K)
_downsample_weight = tf.constant(arr, dtype=tf.float32)

def downsample(data, method='conv'):
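    """Downsample a batch of flattened DIM x DIM RGB images by a factor of K,
    either with a stride-K box-filter convolution ('conv') or with area
    resizing ('area'); returns images flattened to 3*(DIM/K)*(DIM/K)."""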
    data = tf.reshape(data, [-1, 3, DIM, DIM])
    # BCHW -> BHWC
    data = tf.transpose(data, [0, 2, 3, 1])
    if method == 'conv':
        data = tf.nn.conv2d(data, _downsample_weight,
                            strides=[1, K, K, 1], padding='SAME')
    elif method == 'area':
        data = tf.image.resize_area(data, [DIM//K, DIM//K])
    # BHWC -> BCHW
    data = tf.transpose(data, [0, 3, 1, 2])
    data = tf.reshape(data, [-1, 3 * (DIM//K) * (DIM//K)])
    return data

Generator, Discriminator = GeneratorAndDiscriminator()

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:

    all_real_data_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, 3, 64, 64])
    if tf.__version__.startswith('1.'):
        split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
    else:
        split_real_data_conv = tf.split(0, len(DEVICES), all_real_data_conv)

    gen_l1_costs, gen_gan_costs = [], []
    gen_costs, disc_costs = [],[]

    for device_index, (device, real_data_conv) in enumerate(zip(DEVICES, split_real_data_conv)):
        with tf.device(device):
            real_data = 2*((tf.cast(real_data_conv, tf.float32)/255.)-.5)
            real_data = tf.reshape(real_data, [BATCH_SIZE//len(DEVICES), OUTPUT_DIM])
            # downsampled (by K) as generator input
            real_data_downsampled = downsample(real_data)
            fake_data = Generator(BATCH_SIZE//len(DEVICES), noise=real_data_downsampled)
            
            disc_real = Discriminator(real_data)
            disc_fake = Discriminator(fake_data)

            if MODE == 'wgan':
                gen_cost = tf.reduce_mean(disc_fake)
                disc_cost = tf.reduce_mean(disc_real) - tf.reduce_mean(disc_fake)

            elif MODE == 'wgan-gp':
                gen_cost = tf.reduce_mean(disc_fake)
                disc_cost = tf.reduce_mean(disc_real) - tf.reduce_mean(disc_fake)

                # WGAN-GP gradient penalty: sample points uniformly along the
                # line between each real sample and its generated counterpart,
                # and penalize deviations of the critic's gradient norm from 1.
                alpha = tf.random_uniform(
                    shape=[BATCH_SIZE//len(DEVICES), 1],
                    minval=0.,
                    maxval=1.
                )
                differences = fake_data - real_data
                interpolates = real_data + (alpha*differences)
                gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
                slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
                gradient_penalty = tf.reduce_mean((slopes-1.)**2)
                disc_cost += LAMBDA*gradient_penalty

            elif MODE == 'dcgan':
                try: # tf pre-1.0 (bottom) vs 1.0 (top)
                    gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake,
                                                                                      labels=tf.ones_like(disc_fake)))
                    disc_cost =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake,
                                                                                        labels=tf.zeros_like(disc_fake)))
                    disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real,
                                                                                        labels=tf.ones_like(disc_real)))                    
                except Exception as e:
                    gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake, tf.ones_like(disc_fake)))
                    disc_cost =  tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_fake, tf.zeros_like(disc_fake)))
                    disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(disc_real, tf.ones_like(disc_real)))                    
                disc_cost /= 2.

            elif MODE == 'lsgan':
                gen_cost = tf.reduce_mean((disc_fake - 1)**2)
                disc_cost = (tf.reduce_mean((disc_real - 1)**2) + tf.reduce_mean((disc_fake - 0)**2))/2.

            else:
                raise Exception('Unknown MODE: {}'.format(MODE))

            # L1 reconstruction term: the generator output, once downsampled,
            # should match the downsampled input; the final generator loss
            # blends this with the GAN loss using GEN_L1_WEIGHT.
            fake_data_downsampled = downsample(fake_data)
            gen_l1_cost = tf.reduce_mean(
                tf.abs(fake_data_downsampled - real_data_downsampled))

            gen_l1_costs.append(gen_l1_cost)
            gen_gan_costs.append(gen_cost)

            gen_cost = GEN_L1_WEIGHT * gen_l1_cost + (1 - GEN_L1_WEIGHT) * gen_cost

            gen_costs.append(gen_cost)
            disc_costs.append(disc_cost)

    gen_cost = tf.add_n(gen_costs) / len(DEVICES)
    disc_cost = tf.add_n(disc_costs) / len(DEVICES)
    gen_gan_cost = tf.add_n(gen_gan_costs) / len(DEVICES)
    gen_l1_cost = tf.add_n(gen_l1_costs) / len(DEVICES)
    tf.summary.scalar('gen gan loss', gen_gan_cost, collections=['scalars'])
    tf.summary.scalar('gen l1 diff', gen_l1_cost, collections=['scalars'])
    tf.summary.scalar('gen loss', gen_cost, collections=['scalars'])
    tf.summary.scalar('disc loss', disc_cost, collections=['scalars'])

    if MODE == 'wgan':
        gen_train_op = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(
            gen_cost, var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)
        disc_train_op = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(disc_cost,
                                             var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)

        clip_ops = []
        for var in lib.params_with_name('Discriminator'):
            clip_bounds = [-.01, .01]
            clip_ops.append(tf.assign(var, tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])))
        clip_disc_weights = tf.group(*clip_ops)

    elif MODE == 'wgan-gp':
        gen_train_op = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(
                gen_cost,var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)
        disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost,
                                           var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)

    elif MODE == 'dcgan':
        gen_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(gen_cost,
                                          var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)
        disc_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(disc_cost,
                                           var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)

    elif MODE == 'lsgan':
        gen_train_op = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(gen_cost,
                                             var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)
        disc_train_op = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(disc_cost,
                                              var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)

    else:
        raise Exception('Unknown MODE: {}'.format(MODE))

#     # For generating samples
#     fixed_noise = tf.constant(np.random.normal(size=(BATCH_SIZE, INPUT_DIM)).astype('float32'))
#     all_fixed_noise_samples = []
#     for device_index, device in enumerate(DEVICES):
#         n_samples = BATCH_SIZE // len(DEVICES)
#         all_fixed_noise_samples.append(Generator(n_samples, noise=fixed_noise[device_index*n_samples:(device_index+1)*n_samples]))
#     if tf.__version__.startswith('1.'):
#         all_fixed_noise_samples = tf.concat(all_fixed_noise_samples, axis=0)
#     else:
#         all_fixed_noise_samples = tf.concat(0, all_fixed_noise_samples)

#     def generate_image(iteration):
#         # add image to summary
#         samples_reshaped = tf.reshape(
#             all_fixed_noise_samples, (BATCH_SIZE, 3, DIM, DIM))
#         samples_reshaped = tf.transpose(samples_reshaped, [0, 2, 3, 1])
#         image_op = tf.summary.image(
#             'generator output', samples_reshaped)
#         image_summary = session.run(image_op)
#         summary_writer.add_summary(image_summary, iteration)

#         samples = session.run(all_fixed_noise_samples)
#         samples = ((samples+1.)*(255.99/2)).astype('int32')
#         lib.save_images.save_images(samples.reshape((BATCH_SIZE, 3, 64, 64)), 'samples_{}.png'.format(iteration))

    
    def generate_test_image(iteration, real_data, fake_data, max_samples=10):
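        """Build a comparison montage for the test batch: nearest-neighbour
        upsampling, bicubic upsampling, the generator output, and the ground
        truth, concatenated along the width. The montage is logged as an image
        summary and saved to TRAIN_DIR, together with a row of generator
        outputs only."""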
        feature = tf.reshape(real_data_downsampled, [-1, 3, DIM//K, DIM//K])
        # BCHW -> BHWC
        feature = (tf.transpose(feature, [0, 2, 3, 1]) + 1)/2.
        nearest = tf.image.resize_nearest_neighbor(feature, [DIM, DIM])
        nearest = tf.maximum(tf.minimum(nearest, 1.), 0.)
        bicubic = tf.image.resize_bicubic(feature, [DIM, DIM])
        bicubic = tf.maximum(tf.minimum(bicubic, 1.), 0.)
        fake_data = (tf.reshape(fake_data, [-1, 3, DIM, DIM]) + 1.)/2.
        fake_data = tf.transpose(fake_data, [0, 2, 3, 1])
        real_data = tf.reshape(real_data, [-1, 3, DIM, DIM])
        real_data = tf.transpose(real_data, [0, 2, 3, 1])
        real_data = (real_data + 1.) / 2.
        clipped = tf.maximum(tf.minimum(fake_data, 1.), 0.)
        image = tf.concat([nearest, bicubic, clipped, real_data], 2)

        feed_dict = {real_data_conv: test_data}
        image_col = tf.summary.image('generator output', image, max_samples)
        image_summary = session.run(image_col, feed_dict=feed_dict)
        summary_writer.add_summary(image_summary, iteration)

        image = image[0:max_samples,:,:,:]
        image = tf.concat([image[i,:,:,:] for i in range(max_samples)], 0)
        clipped = clipped[0:max_samples, :, :, :]
        clipped = tf.concat([clipped[i, :, :, :] for i in range(max_samples)], 1)

        image, clipped = session.run([image, clipped], feed_dict=feed_dict)
        
        filename_1 = 'batch%06d_image.png' % iteration
        filename_2 = 'batch%06d_row.png' % iteration
        filename_1 = os.path.join(TRAIN_DIR, filename_1)
        filename_2 = os.path.join(TRAIN_DIR, filename_2)
        scipy.misc.toimage(image, cmin=0., cmax=1.).save(filename_1)
        scipy.misc.toimage(clipped, cmin=0., cmax=1.).save(filename_2)
        print("Saved %s %s" % (filename_1, filename_2))

        



    # Dataset iterator and test set (for visualization) 
    train_gen, test_data = lib.celebA_64x64.load(BATCH_SIZE, data_dir=DATA_DIR)
    #train_gen, dev_gen = lib.small_imagenet.load(BATCH_SIZE, data_dir=DATA_DIR)

    def inf_train_gen():
        while True:
            for (images,) in train_gen():
                yield images

    # Save a batch of ground-truth samples
    _x = next(inf_train_gen())
    _x_r = session.run(real_data, feed_dict={real_data_conv: _x})
    _x_r = ((_x_r+1.)*(255.99/2)).astype('int32')
    lib.save_images.save_images(_x_r.reshape((BATCH_SIZE, 3, 64, 64)), 'samples_groundtruth.png')

    # Train loop
    merged_scalars = tf.summary.merge_all(key='scalars')
    summary_writer = tf.summary.FileWriter(SUMMARY_DIR, session.graph)

    session.run(tf.global_variables_initializer())
    gen = inf_train_gen()
    all_start_time = time.time()
    for iteration in range(ITERS):
        start_time = time.time()
        # stop once the maximum runtime (in minutes) has been exceeded
        total_elapsed = (start_time - all_start_time) / 60.
        if total_elapsed > FLAGS.max_runtime:
            break

        # Train generator
        if iteration > 0:
            _ = session.run(gen_train_op, feed_dict={all_real_data_conv: _data})

        # Train critic (CRITIC_ITERS steps per generator step for the WGAN
        # variants; a single step for DCGAN and LSGAN)
        if (MODE == 'dcgan') or (MODE == 'lsgan'):
            disc_iters = 1
        else:
            disc_iters = CRITIC_ITERS
        for i in range(disc_iters):
            _data = next(gen)
            _disc_cost, _ = session.run([disc_cost, disc_train_op], feed_dict={all_real_data_conv: _data})
            if MODE == 'wgan':
                _ = session.run([clip_disc_weights])

        lib.plot.plot('train disc cost', _disc_cost)
        lib.plot.plot('time', time.time() - start_time)
        #print('iter={0} disc_loss={1:.3g} time={2:.2g}'.format(
        #    iteration, _disc_cost, time.time() - start_time))

        if iteration % 10 == 0:
            merged_summary = session.run(merged_scalars, feed_dict={all_real_data_conv: _data})
            summary_writer.add_summary(merged_summary, iteration)

        if iteration % 200 == 9:
            t = time.time()
            #dev_disc_costs = []
            #for (images,) in dev_gen():
            #    _dev_disc_cost = session.run(disc_cost, feed_dict={all_real_data_conv: _data}) 
            #    dev_disc_costs.append(_dev_disc_cost)
            #lib.plot.plot('dev disc cost', np.mean(dev_disc_costs))
            generate_test_image(iteration, real_data, fake_data)

        if (iteration < 5) or (iteration % 200 == 199):
            lib.plot.flush()

        lib.plot.tick()


if __name__ == '__main__':
    # The training loop above runs at module level, so there is no main() for
    # tf.app.run() to dispatch to; calling it here would raise an error.
    pass