""" Code adapted from https://github.com/watsonyanghx/GAN_Lib_Tensorflow SNGAN with projection ResNet for conditional generation of CIFAR-10 """ from datetime import datetime import os import sys import logging sys.path.append(os.getcwd()) import numpy as np import tensorflow as tf import time import functools import locale import common.misc import common.inception.inception_score_ import common as lib import common.ops.linear import common.ops.conv2d import common.ops.embedding import common.ops.normalization import common.plot # Download CIFAR-10 (Python version) at # https://www.cs.toronto.edu/~kriz/cifar.html and fill in the path to the extracted files here! DATA_DIR = '../data/cifar10/cifar-10-batches-py/' if len(DATA_DIR) == 0: raise Exception('Please specify path to data directory in gan_cifar.py!') flags = tf.app.flags flags.DEFINE_string("dataset", 'cifar', "Dataset") flags.DEFINE_string("algorithm", 'rcgan', "Algorithm [rcgan, rcgan-u, biased, unbiased]") flags.DEFINE_float("alpha", 0.8, "1 - noise level") flags.DEFINE_string("run", '0', "run name") flags.DEFINE_string("log_file", None, "logging file") flags.DEFINE_string("parent_dir", '.', "parent directory for checkpoints") flags.DEFINE_string("expt_dir", None, "directory for expts") flags.DEFINE_integer("inception_freq", 2500, "frequncy of inception score calculation") flags.DEFINE_integer("sample_freq", 2500, "frequncy of dev cost calc. and sample pics") flags.DEFINE_integer("generated_label_accuracy_freq", 2500, "frequncy of generated label accruacy") flags.DEFINE_integer("sample_save_freq", 0, "frequncy of saving samples") flags.DEFINE_integer("batch_size", 64, "batch size") flags.DEFINE_integer("niters", 50000, "no. of batches") flags.DEFINE_float("lr", 2.0e-4, "learning rate") flags.DEFINE_integer("ngpus", 2, "no. of gpus") flags.DEFINE_boolean("multi_gpu_multi_batch", True, 'whether to multiply batch_size with number of gpus' 'and divide nof. iterations by nof. gpus') flags.DEFINE_boolean("confuse_init", False, "whether to initialize confusion matrix with identity") flags.DEFINE_float("confuse_init_diag", 0.2, "intial confusion matrix with diagonal entry") flags.DEFINE_float("confuse_multiplier", 1.0, "learning rate multiplier for learnable confusion matrix ") flags.DEFINE_boolean("confuse_lr_decay", False, 'whether to decay confusion matrix estimation learning rate') flags.DEFINE_boolean("perm_classifier", False, 'whether to real fake classifier or not.') flags.DEFINE_float("perm_multiplier", 1.0, 'whether to real fake classifier or not.') flags.DEFINE_string("perm_type", 'linear', 'type of real fake classifier to use [linear, 2layer].') flags.DEFINE_boolean("restore", True, 'whether to restore from past checkpoint') flags.DEFINE_boolean( "perm_gen_label_acc", False, 'whether to calculate generated label accuracy' 'by taking min. value over all permutation of labels') flags.DEFINE_string("log_level", 'info', 'logging level [info, debug]') FLAGS = flags.FLAGS if FLAGS.log_file is None: raise ValueError('flag log_file is required') # dataset = str(sys.argv[1]) # ALGORITHM = str(sys.argv[2]) # ALPHA = float(sys.argv[3]) # run = str(sys.argv[4]) # log_file = str(sys.argv[5]) dataset = FLAGS.dataset ALGORITHM = FLAGS.algorithm ALPHA = FLAGS.alpha run = FLAGS.run log_file = FLAGS.log_file if FLAGS.log_level == 'debug': log_level = logging.DEBUG elif FLAGS.log_level == 'info': log_level = logging.INFO logging.basicConfig( filename=log_file, level=log_level, format='%(asctime)s %(levelname)-8s %(message)s') logging.info('alpha = {}'.format(ALPHA)) C_ALPHA = ((1-ALPHA)/9.0)*np.ones((10,10)) + (ALPHA - (1-ALPHA)/9.0)*np.eye((10)) if dataset == "cifar": import common.data.cifar10 as dataset_ OUTPUT_DIM = 3072 # Number of pixels in CIFAR10 (32*32*3) IMG_SIZE = 32 IMG_DIM = 3 INCEPTION_FREQUENCY = 5000 # 1000 # How frequently to calculate Inception score SAMPLE_FREQUENCY = 100 SAMPLE_SAVE_FREQUENCY = 0 # 5000 GENERATED_LABEL_ACCURACY_FREQ = 5000 DIR = os.path.join(FLAGS.parent_dir, ALGORITHM + '_alpha' + str(ALPHA)+ '_run-' + run + '_' + datetime.now().strftime("%Y%m%d-%H%M%S")) if dataset == "mnist": import common.data.mnist10 as dataset_ OUTPUT_DIM = 1024 # Number of pixels in CIFAR10 (32*32*3) IMG_SIZE = 32 IMG_DIM = 1 INCEPTION_FREQUENCY = 10000000 # How frequently to calculate Inception score SAMPLE_FREQUENCY = 50 SAMPLE_SAVE_FREQUENCY = 1000 DIR = './run_mnist_' + ALGORITHM + '_' + str(ALPHA) + '_' + run if FLAGS.expt_dir is not None: DIR = '{}/{}'.format(FLAGS.parent_dir, FLAGS.expt_dir) INCEPTION_FREQUENCY = FLAGS.inception_freq SAMPLE_FREQUENCY = FLAGS.sample_freq SAMPLE_SAVE_FREQUENCY = FLAGS.sample_save_freq GENERATED_LABEL_ACCURACY_FREQ = FLAGS.generated_label_accuracy_freq if not os.path.exists(DIR): os.mkdir(DIR) DIR = DIR + '/' BATCH_SIZE = 64 # Critic batch size GEN_BS_MULTIPLE = 2 # Generator batch size, as a multiple of BATCH_SIZE ITERS = 100000 # How many iterations to train for ITERS = 50000 # How many iterations to train for # ITERS = 3000 # DEBUG # ITERS = 20 # DEBUG BATCH_SIZE = FLAGS.batch_size ITERS = FLAGS.niters Z_DIM = 128 # dimension of the noise input to generator DIM_G = 128 # Generator dimensionality DIM_D = 128 # Critic dimensionality NORMALIZATION_G = True # Use batchnorm in generator? NORMALIZATION_D = False # Use batchnorm (or layernorm) in critic? LR = 0.0002 # 2e-4 # Initial learning rate DECAY = True # Whether to decay LR over learning N_CRITIC = 5 # 5 # Critic steps per generator steps LR = FLAGS.lr CONDITIONAL = True # Whether to train a conditional or unconditional model ACGAN = False # If CONDITIONAL, whether to use ACGAN or "vanilla" conditioning ACGAN_SCALE = 1. # How to scale the critic's ACGAN loss relative to WGAN loss ACGAN_SCALE_G = 0.1 # How to scale generator's ACGAN loss relative to WGAN loss # SPECTRAL_NORM_UPDATE_OPS = "spectral_norm_update_ops" # WORD2VEC_FILE = np.load(os.path.join(DATA_DIR, 'glove_y.npy')).astype('float32') WORD2VEC_FILE = None VOCAB_SIZE = 10 EMBEDDING_DIM = 300 # 620 CHECKPOINT_DIR = os.path.join(DIR, 'checkpoint') LOSS_TYPE = 'HINGE' # 'Goodfellow', 'HINGE', 'WGAN', 'WGAN-GP' SOFT_PLUS = False RESTORE = True CONCAT_LABEL = False # whether concat label to 'z' in Generator. RESTORE = FLAGS.restore if CONDITIONAL and (not ACGAN) and (not NORMALIZATION_D): logging.warning("WARNING! Conditional model without normalization in D might be effectively unconditional!") N_GPUS = FLAGS.ngpus if N_GPUS not in [1, 2]: raise Exception('Only 1 or 2 GPUs supported!') DEVICES = ['/gpu:{}'.format(i) for i in range(N_GPUS)] if len(DEVICES) == 1: # Hack because the code assumes 2 GPUs DEVICES = [DEVICES[0], DEVICES[0]] if FLAGS.multi_gpu_multi_batch: BATCH_SIZE = BATCH_SIZE*N_GPUS ITERS = ITERS//N_GPUS lib.print_model_settings(locals().copy()) common.misc.record_setting(os.path.join(DIR, 'scripts')) def nonlinearity(x, activation_fn='relu', leakiness=0.2): if activation_fn == 'relu': return tf.nn.relu(x) if activation_fn == 'lrelu': assert 0 < leakiness <= 1, "leakiness must be <= 1" return tf.maximum(x, leakiness * x) def Normalize(name, inputs, labels=None): """This is messy, but basically it chooses between batchnorm, layernorm, their conditional variants, or nothing, depending on the value of `name` and the global hyperparam flags.""" with tf.variable_scope(name): if not CONDITIONAL: labels = None if CONDITIONAL and ACGAN and ('D.' in name): labels = None if ('D.' in name) and NORMALIZATION_D: return lib.ops.normalization.layer_norm(name, [1, 2, 3], inputs) elif ('G.' in name) and NORMALIZATION_G: if labels is not None: outputs = lib.ops.normalization.cond_batchnorm(name, [0, 1, 2], inputs, labels=labels, n_labels=10) return outputs else: outputs = lib.ops.normalization.batch_norm(inputs, fused=True) return outputs else: return inputs def ConvMeanPool(inputs, output_dim, filter_size=3, stride=1, name=None, spectral_normed=False, update_collection=None, inputs_norm=False, he_init=True, biases=True): output = lib.ops.conv2d.Conv2D(inputs, inputs.shape.as_list()[-1], output_dim, filter_size, stride, name, spectral_normed=spectral_normed, update_collection=update_collection, he_init=he_init, biases=biases) # output = tf.nn.avg_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID') output = tf.add_n( [output[:, ::2, ::2, :], output[:, 1::2, ::2, :], output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :]]) / 4. return output def MeanPoolConv(inputs, output_dim, filter_size=3, stride=1, name=None, spectral_normed=False, update_collection=None, inputs_norm=False, he_init=True, biases=True): output = inputs output = tf.add_n( [output[:, ::2, ::2, :], output[:, 1::2, ::2, :], output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :]]) / 4. # output = tf.nn.avg_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], padding='VALID') output = lib.ops.conv2d.Conv2D(output, output.shape.as_list()[-1], output_dim, filter_size, stride, name, spectral_normed=spectral_normed, update_collection=update_collection, he_init=he_init, biases=biases) return output def UpsampleConv(inputs, output_dim, filter_size=3, stride=1, name=None, spectral_normed=False, update_collection=None, inputs_norm=False, he_init=True, biases=True): output = inputs output = tf.concat([output, output, output, output], axis=3) output = tf.depth_to_space(output, 2) # w, h = inputs.shape.as_list()[1], inputs.shape.as_list()[2] # output = tf.image.resize_images(inputs, [w * 2, h * 2]) output = lib.ops.conv2d.Conv2D(output, output.shape.as_list()[-1], output_dim, filter_size, stride, name, spectral_normed=spectral_normed, update_collection=update_collection, he_init=he_init, biases=biases) return output def ResidualBlock(inputs, input_dim, output_dim, filter_size, name, spectral_normed=False, update_collection=None, inputs_norm=False, resample=None, labels=None, biases=True): """resample: None, 'down', or 'up'. """ if resample == 'down': conv_1 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim) conv_2 = functools.partial(ConvMeanPool, output_dim=output_dim) conv_shortcut = ConvMeanPool elif resample == 'up': conv_1 = functools.partial(UpsampleConv, output_dim=output_dim) conv_shortcut = UpsampleConv conv_2 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim) elif resample is None: conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim) conv_1 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=output_dim) conv_2 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim) else: raise Exception('invalid resample value') if output_dim == input_dim and resample is None: shortcut = inputs # Identity skip-connection else: shortcut = conv_shortcut(inputs=inputs, output_dim=output_dim, filter_size=1, name=name + '.Shortcut', spectral_normed=spectral_normed, update_collection=update_collection, he_init=False, biases=biases) output = inputs output = Normalize(name + '.N1', output, labels=labels) output = nonlinearity(output) # if resample == 'up': # output = nonlinearity(output) # else: # output = lrelu(output, leakiness=0.2) output = conv_1(inputs=output, filter_size=filter_size, name=name + '.Conv1', spectral_normed=spectral_normed, update_collection=update_collection, he_init=True, biases=biases) output = Normalize(name + '.N2', output, labels=labels) output = nonlinearity(output) # if resample == 'up': # output = nonlinearity(output) # else: # output = lrelu(output, leakiness=0.2) output = conv_2(inputs=output, filter_size=filter_size, name=name + '.Conv2', spectral_normed=spectral_normed, update_collection=update_collection, he_init=True, biases=biases) return shortcut + output def OptimizedResBlockDisc1(inputs, spectral_normed=False, update_collection=None, inputs_norm=False, biases=True): conv_1 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=IMG_DIM, output_dim=DIM_D) conv_2 = functools.partial(ConvMeanPool, output_dim=DIM_D) conv_shortcut = MeanPoolConv shortcut = conv_shortcut(inputs=inputs, output_dim=DIM_D, filter_size=1, name='D.Block.1.Shortcut', spectral_normed=spectral_normed, update_collection=update_collection, he_init=False, biases=biases) output = inputs output = conv_1(inputs=output, filter_size=3, name='D.Block.1.Conv1', spectral_normed=spectral_normed, update_collection=update_collection, he_init=True, biases=biases) output = nonlinearity(output) # output = lrelu(output, leakiness=0.2) output = conv_2(inputs=output, filter_size=3, name='D.Block.1.Conv2', spectral_normed=spectral_normed, update_collection=update_collection, he_init=True, biases=biases) return shortcut + output def Generator(n_samples_, labels, noise=None, reuse=False): with tf.variable_scope("Generator", reuse=reuse): if noise is None: noise = tf.random_normal([n_samples_, Z_DIM]) output = lib.ops.linear.Linear(noise, 128, 4 * 4 * DIM_G * 8, 'G.Input') output = tf.reshape(output, [-1, 4, 4, DIM_G * 8]) output = ResidualBlock(output, DIM_G * 8, DIM_G * 2, 3, 'G.Block.1', resample='up', labels=labels, biases=True) output = ResidualBlock(output, DIM_G * 2, DIM_G * 2, 3, 'G.Block.2', resample='up', labels=labels, biases=True) output = ResidualBlock(output, DIM_G * 2, DIM_G * 2, 3, 'G.Block.3', resample='up', labels=labels, biases=True) output = Normalize('G.OutputNorm', output, labels) output = nonlinearity(output) output = lib.ops.conv2d.Conv2D(output, DIM_G * 2, IMG_DIM, 3, 1, 'G.Output', he_init=False) output = tf.tanh(output) # return tf.reshape(tf.transpose(output, [0, 3, 1, 2], name='NHWC_to_NCHW'), [-1, OUTPUT_DIM]) return tf.reshape(output, [-1, OUTPUT_DIM]) def Discriminator(inputs, labels, update_collection=None, reuse=False): with tf.variable_scope("Discriminator", reuse=reuse): if ALGORITHM in ("unbiased", "rcgan-u"): labels_disc = None else: labels_disc = labels output = tf.reshape(inputs, [-1, IMG_SIZE, IMG_SIZE, IMG_DIM]) output = OptimizedResBlockDisc1(output, spectral_normed=True, update_collection=update_collection, biases=True) output = ResidualBlock(output, DIM_D, DIM_D, 3, 'D.Block.2', spectral_normed=True, update_collection=update_collection, resample='down', labels=labels_disc, biases=True) output = ResidualBlock(output, DIM_D, DIM_D, 3, 'D.Block.3', spectral_normed=True, update_collection=update_collection, resample=None, labels=labels_disc, biases=True) output = ResidualBlock(output, DIM_D, DIM_D, 3, 'D.Block.4', spectral_normed=True, update_collection=update_collection, resample=None, labels=labels_disc, biases=True) output = ResidualBlock(output, DIM_D, DIM_D, 3, 'D.Block.5', spectral_normed=True, update_collection=update_collection, resample=None, labels=labels_disc, biases=True) output = ResidualBlock(output, DIM_D, DIM_D, 3, 'D.Block.6', spectral_normed=True, update_collection=update_collection, resample=None, labels=labels_disc, biases=True) output = nonlinearity(output) # output = lrelu(output, leakiness=0.2) output = tf.reduce_mean(output, axis=[1, 2]) output_wgan = lib.ops.linear.Linear(output, DIM_D, 1, 'D.Output', spectral_normed=True, update_collection=update_collection) output_wgan = tf.reshape(output_wgan, [-1]) return output, output_wgan def Discriminator_projection(labels, update_collection=None, reuse=False): with tf.variable_scope("Discriminator", reuse=reuse): embedding_y = lib.ops.embedding.embed_y(labels, VOCAB_SIZE, EMBEDDING_DIM, word2vec_file=WORD2VEC_FILE) embedding_y = lib.ops.linear.Linear(embedding_y, EMBEDDING_DIM, DIM_D, 'D.Embedding_y', spectral_normed=True, update_collection=update_collection, biases=True) # (N, DIM_D) return embedding_y def generated_label_accuracy(samples, labels, confusion_matrix=None): with tf.gfile.GFile('./resnet-110/graph_optimized.pb', 'rb') as f: graph_def_optimized = tf.GraphDef() graph_def_optimized.ParseFromString(f.read()) if confusion_matrix is not None: _confusion_matrix = confusion_matrix confusion_matrix = np.zeros_like(confusion_matrix, dtype=int) confusion_matrix[ np.arange(confusion_matrix.shape[0]), np.argmax(_confusion_matrix, axis=-1)] = 1 _labels = labels labels = np.zeros([_labels.shape[0], VOCAB_SIZE], dtype=float) labels[np.arange(labels.shape[0]), _labels] = 1 labels[:] = labels.dot(confusion_matrix) labels = np.argmax(labels, axis=-1) G = tf.Graph() with G.as_default(): gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) config = tf.ConfigProto(device_count = {'GPU': 1}, gpu_options=gpu_options) num_test = 100 with tf.Session(config=config) as sess: pred_softmax = tf.import_graph_def(graph_def_optimized, return_elements=['infer_softmax:0']) x = G.get_tensor_by_name('import/resnet_test_batch:0') softmax = sess.run(pred_softmax, feed_dict={x: samples}) acc = (labels == np.argmax(softmax, axis=-1)).astype(float).mean() logging.info('generated label accuracy: {}'.format(acc)) return acc def perm_classifier(x, reuse=False): if FLAGS.perm_type == 'linear': with tf.variable_scope("Discriminator", reuse=reuse): # 1 layer NN hidden_layer = lib.ops.linear.Linear( tf.reshape(x, [-1, OUTPUT_DIM]), OUTPUT_DIM, VOCAB_SIZE, 'D.d_perm_classifier_h1', spectral_normed=True, biases=True, reuse=reuse) logits = hidden_layer elif FLAGS.perm_type == '2layer': with tf.variable_scope("Discriminator", reuse=reuse): # 1 layer NN hidden_layer = lib.ops.linear.Linear( tf.reshape(x, [-1, OUTPUT_DIM]), OUTPUT_DIM, 128, 'D.d_perm_classifier_h1', spectral_normed=True, biases=True, reuse=reuse) hidden_layer = lib.ops.linear.Linear( hidden_layer, 128, VOCAB_SIZE, 'D.d_perm_classifier_h2', spectral_normed=True, biases=True, reuse=reuse) logits = hidden_layer else: raise ValueError('Unknown perm_type {}'.format(FLAGS.perm_type)) return logits def sigmoid_cross_entropy_with_logits(x, y): try: return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y) except: return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y) def main(_): config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True with tf.Session(config=config) as session: # confusion matrix variable for estimation if ALGORITHM == 'rcgan-u': if not FLAGS.confuse_init: confusion_logits = tf.get_variable( 'confusion_logits', dtype=tf.float32, shape=[VOCAB_SIZE, VOCAB_SIZE], trainable=True) else: if FLAGS.confuse_init_diag > 0.99 and VOCAB_SIZE == 10.: aa = 7.0 else: aa = np.log(VOCAB_SIZE*FLAGS.confuse_init_diag/ (1.-FLAGS.confuse_init_diag)) aa = min(7.0, aa) mean = 0.0 # 0.2/VOCAB_SIZE confuse_init = (0 - aa/VOCAB_SIZE + mean)*np.ones( [VOCAB_SIZE, VOCAB_SIZE], dtype=np.float32) np.fill_diagonal(confuse_init, (aa - (aa/VOCAB_SIZE) + mean)) confusion_logits = tf.get_variable( 'confusion_logits', dtype=tf.float32, initializer=tf.constant_initializer(confuse_init), shape=[VOCAB_SIZE, VOCAB_SIZE], trainable=True) confusion_matrix = tf.nn.softmax(confusion_logits, dim=-1) else: confusion_matrix = tf.constant(C_ALPHA.astype(np.float32)) _iteration = tf.placeholder(tf.int32, shape=None) all_real_data_int = tf.placeholder(tf.int32, shape=[BATCH_SIZE, OUTPUT_DIM]) all_real_labels = tf.placeholder(tf.int32, shape=[BATCH_SIZE]) labels_splits = tf.split(all_real_labels, len(DEVICES), axis=0) all_random_labels = tf.placeholder(tf.int32, shape=[BATCH_SIZE], name='d1') labels_random_splits = tf.split(all_random_labels, len(DEVICES), axis=0) all_labels_biased = tf.placeholder(tf.int32, shape=[BATCH_SIZE], name='d2') labels_biased_splits = tf.split(all_labels_biased, len(DEVICES), axis=0) all_labels_inv_weights = tf.placeholder(tf.float32, shape=[BATCH_SIZE,VOCAB_SIZE]) labels_inv_weights_splits = tf.split(all_labels_inv_weights, len(DEVICES), axis=0) fake_data_splits = [] for i, device in enumerate(DEVICES): with tf.device(device): if i > 0: fake_data_splits.append(Generator(int(BATCH_SIZE / len(DEVICES)), labels_random_splits[i], reuse=True)) else: fake_data_splits.append(Generator(int(BATCH_SIZE / len(DEVICES)), labels_random_splits[i])) all_real_data = tf.reshape(2 * ((tf.cast(all_real_data_int, tf.float32) / 256.) - .5), [BATCH_SIZE, OUTPUT_DIM]) all_real_data += tf.random_uniform(shape=[BATCH_SIZE, OUTPUT_DIM], minval=0., maxval=1. / 128) # dequantize all_real_data = tf.reshape( tf.transpose(tf.reshape(all_real_data, [-1, IMG_DIM, IMG_SIZE, IMG_SIZE]), perm=[0, 2, 3, 1]), [-1, OUTPUT_DIM]) all_real_data_splits = tf.split(all_real_data, len(DEVICES), axis=0) #DEVICES_A = DEVICES[int(len(DEVICES) / 2):] # DEVICES_B = DEVICES[:int(len(DEVICES) / 2)] disc_costs = [] for i, device in enumerate(DEVICES): with tf.device(device): if ALGORITHM == 'rcgan-u': real_and_fake_data = all_real_data_splits[i] else: real_and_fake_data = tf.concat(values=[ all_real_data_splits[i], fake_data_splits[i], ], axis=0) if ALGORITHM in ["biased", "unbiased"]: real_and_fake_labels = tf.concat(values=[ labels_splits[i], labels_random_splits[i], ], axis=0) elif ALGORITHM == "rcgan-u": real_and_fake_labels = labels_splits[i] elif ALGORITHM == "rcgan": real_and_fake_labels = tf.concat(values=[ labels_splits[i], labels_biased_splits[i], ], axis=0) if i == 0: reuse = False else: reuse = True output, output_wgan = Discriminator(real_and_fake_data, real_and_fake_labels, update_collection=None, reuse=reuse) embedding_y = Discriminator_projection(real_and_fake_labels, update_collection=None, reuse=reuse) if ALGORITHM == "biased" or ALGORITHM == "rcgan": disc_all = output_wgan + tf.reshape(tf.reduce_sum(output*embedding_y, axis=1), [-1]) disc_real = disc_all[:int(BATCH_SIZE / len(DEVICES))] disc_fake = disc_all[int(BATCH_SIZE / len(DEVICES)):] if LOSS_TYPE == 'Goodfellow': if SOFT_PLUS: disc_real_l = -tf.reduce_mean(tf.nn.softplus(tf.log(tf.nn.sigmoid(disc_real)))) disc_fake_l = -tf.reduce_mean(tf.nn.softplus(tf.log(1 - tf.nn.sigmoid(disc_fake)))) else: disc_real_l = -tf.reduce_mean(tf.log(tf.nn.sigmoid(disc_real))) disc_fake_l = -tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(disc_fake))) disc_costs.append(disc_real_l + disc_fake_l) elif LOSS_TYPE == 'HINGE': if SOFT_PLUS: disc_real_l = tf.reduce_mean(tf.nn.softplus(-tf.minimum(0., -1 + disc_real))) disc_fake_l = tf.reduce_mean(tf.nn.softplus(-tf.minimum(0., -1 - disc_fake))) else: disc_real_l = tf.reduce_mean(tf.nn.relu(1. - disc_real)) disc_fake_l = tf.reduce_mean(tf.nn.relu(1. + disc_fake)) disc_costs.append(disc_real_l + disc_fake_l) elif LOSS_TYPE == 'WGAN': if SOFT_PLUS: disc_costs.append( tf.reduce_mean(tf.nn.softplus(disc_fake)) + tf.reduce_mean(tf.nn.softplus(-disc_real))) else: disc_costs.append(tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)) elif ALGORITHM == "unbiased": disc_real_l_y = [j for j in range(VOCAB_SIZE)] for j in range(VOCAB_SIZE): real_and_fake_labels = tf.concat(values=[ tf.convert_to_tensor(j*np.ones((int(BATCH_SIZE / len(DEVICES)),)),tf.int32), labels_random_splits[i], ], axis=0) embedding_y = Discriminator_projection(real_and_fake_labels, update_collection=None, reuse = True) disc_all = output_wgan + tf.reshape(tf.reduce_sum(output*embedding_y, axis=1), [-1]) disc_real = disc_all[:int(BATCH_SIZE / len(DEVICES))] disc_fake = disc_all[int(BATCH_SIZE / len(DEVICES)):] if LOSS_TYPE == 'Goodfellow': if SOFT_PLUS: disc_real_l_y[j] = -tf.reshape(tf.nn.softplus(tf.log(tf.nn.sigmoid(disc_real))) ,[int(BATCH_SIZE / len(DEVICES)),1]) disc_fake_l = -tf.reduce_mean(tf.nn.softplus(tf.log(1 - tf.nn.sigmoid(disc_fake)))) else: disc_real_l_y[j] = -tf.reshape(tf.log(tf.nn.sigmoid(disc_real)),[int(BATCH_SIZE / len(DEVICES)),1]) disc_fake_l = -tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(disc_fake))) elif LOSS_TYPE == 'HINGE': if SOFT_PLUS: disc_real_l_y[j] = tf.reshape(tf.nn.softplus(-tf.minimum(0., -1 + disc_real)) ,[int(BATCH_SIZE / len(DEVICES)),1]) disc_fake_l = tf.reduce_mean(tf.nn.softplus(-tf.minimum(0., -1 - disc_fake))) else: disc_real_l_y[j] = tf.reshape(tf.nn.relu(1. - disc_real),[int(BATCH_SIZE / len(DEVICES)),1]) disc_fake_l = tf.reduce_mean(tf.nn.relu(1. + disc_fake)) elif LOSS_TYPE == 'WGAN': if SOFT_PLUS: disc_real_l_y[j] = tf.nn.softplus(-disc_real) disc_fake_l = tf.reduce_mean(tf.nn.softplus(disc_fake)) else: disc_real_l_y[j] = -disc_real disc_fake_l = tf.reduce_mean(disc_fake) abc = tf.reduce_mean(tf.reduce_sum(tf.concat(disc_real_l_y,1)*labels_inv_weights_splits[i], axis = 1)) disc_costs.append(abc + disc_fake_l) elif ALGORITHM == "rcgan-u": disc_real = output_wgan + tf.reshape(tf.reduce_sum(output*embedding_y, axis=1), [-1]) output, output_wgan = Discriminator(fake_data_splits[i], labels_random_splits[i], update_collection=None ,reuse = True) fake_labels = tf.convert_to_tensor(np.arange(VOCAB_SIZE), tf.int32) embedding_y = Discriminator_projection(fake_labels, update_collection=None, reuse = True) disc_fake = ( tf.expand_dims(output_wgan, 1) + tf.reduce_sum(tf.expand_dims(output, 1)* tf.expand_dims(embedding_y, 0), axis=-1)) if LOSS_TYPE == 'Goodfellow': if SOFT_PLUS: disc_fake_y = -tf.nn.softplus(tf.log(1. - tf.nn.sigmoid(disc_fake))) disc_real_l = -tf.reduce_mean(tf.nn.softplus(tf.log(tf.nn.sigmoid(disc_real)))) else: disc_fake_y = -tf.log(1. - tf.nn.sigmoid(disc_fake)) disc_real_l = -tf.reduce_mean(tf.log(tf.nn.sigmoid(disc_real))) elif LOSS_TYPE == 'HINGE': if SOFT_PLUS: disc_fake_y = tf.nn.softplus(-tf.minimum(0., -1 - disc_fake)) disc_real_l = tf.reduce_mean(tf.nn.softplus(-tf.minimum(0., -1 + disc_real))) else: disc_fake_y = tf.nn.relu(1. + disc_fake) disc_real_l = tf.reduce_mean(tf.nn.relu(1. - disc_real)) elif LOSS_TYPE == 'WGAN': if SOFT_PLUS: disc_fake_y = tf.nn.softplus(disc_fake) disc_real_l = tf.reduce_mean(tf.nn.softplus(-disc_real)) else: disc_fake_y = disc_fake disc_real_l = tf.reduce_mean(-disc_real) y_fake_confuse = tf.tensordot( tf.one_hot(labels_random_splits[i], VOCAB_SIZE), confusion_matrix, axes=[[1], [0]]) abc = tf.reduce_mean(tf.reduce_sum(disc_fake_y*y_fake_confuse, axis = 1)) disc_costs.append(abc + disc_real_l) if FLAGS.perm_classifier: if i==0: reuse = False else: reuse = True perm_classifier_real_logits = perm_classifier(all_real_data_splits[i], reuse=reuse) perm_classifier_real_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits( perm_classifier_real_logits, tf.one_hot(labels_splits[i], VOCAB_SIZE))) disc_costs[-1] += 1.*perm_classifier_real_loss disc_wgan = tf.add_n(disc_costs) / len(DEVICES) tf.summary.scalar('D_wgan_cost', disc_wgan) disc_cost = disc_wgan if DECAY: decay = tf.where( tf.less(_iteration, 50000), tf.maximum(0., 1. - (tf.cast(_iteration, tf.float32) / 100000)), 0.5) else: decay = 1. tf.summary.scalar('lr', LR * decay) all_random_labels_G = tf.placeholder(tf.int32, shape=[BATCH_SIZE*GEN_BS_MULTIPLE], name='g1') labels_random_splits_G = tf.split(all_random_labels_G, len(DEVICES), axis=0) all_labels_biased_G = tf.placeholder(tf.int32, shape=[BATCH_SIZE*GEN_BS_MULTIPLE], name = 'g2') labels_biased_splits_G = tf.split(all_labels_biased_G, len(DEVICES), axis=0) gen_costs = [] for i, device in enumerate(DEVICES): with tf.device(device): n_samples = GEN_BS_MULTIPLE * int(BATCH_SIZE / len(DEVICES)) fake_data_split_G = Generator(n_samples, labels_random_splits_G[i], reuse=True) if ALGORITHM == "biased" or ALGORITHM == "unbiased": output, output_wgan = Discriminator(fake_data_split_G, labels_random_splits_G[i], update_collection="NO_OPS", reuse=True) embedding_y = Discriminator_projection(labels_random_splits_G[i], update_collection=None, reuse=True) if ALGORITHM in ["rcgan", "rcgan-u"]: output, output_wgan = Discriminator(fake_data_split_G, labels_biased_splits_G[i], update_collection="NO_OPS", reuse=True) embedding_y = Discriminator_projection(labels_biased_splits_G[i], update_collection=None, reuse=True) if ALGORITHM == "rcgan-u": fake_labels = tf.convert_to_tensor(np.arange(VOCAB_SIZE), tf.int32) embedding_y = Discriminator_projection(fake_labels, update_collection=None, reuse = True) disc_fake = ( tf.expand_dims(output_wgan, 1) + tf.reduce_sum(tf.expand_dims(output, 1)* tf.expand_dims(embedding_y, 0), axis=-1)) if LOSS_TYPE == 'Goodfellow': if SOFT_PLUS: disc_fake_y = tf.nn.softplus(-tf.log(tf.nn.sigmoid(disc_fake))) else: disc_fake_y = -tf.log(tf.nn.sigmoid(disc_fake)) elif LOSS_TYPE == 'HINGE': if SOFT_PLUS: disc_fake_y = tf.nn.softplus(-disc_fake) else: disc_fake_y = -disc_fake elif LOSS_TYPE == 'WGAN': if SOFT_PLUS: disc_fake_y = tf.nn.softplus(-disc_fake) else: disc_fake_y = -disc_fake y_fake_confuse = tf.tensordot( tf.one_hot(labels_random_splits_G[i], VOCAB_SIZE), confusion_matrix, axes=[[1], [0]]) abc = tf.reduce_mean(tf.reduce_sum(disc_fake_y*y_fake_confuse, axis = 1)) gen_costs.append(abc) else: disc_fake = output_wgan + tf.reshape(tf.reduce_sum(output*embedding_y, axis=1), [-1]) if LOSS_TYPE == 'Goodfellow': if SOFT_PLUS: gen_costs.append(tf.reduce_mean(tf.nn.softplus(-tf.log(tf.nn.sigmoid(disc_fake))))) else: gen_costs.append(-tf.reduce_mean(tf.log(tf.nn.sigmoid(disc_fake)))) elif LOSS_TYPE == 'HINGE': if SOFT_PLUS: gen_costs.append(tf.reduce_mean(tf.nn.softplus(-disc_fake))) else: gen_costs.append(-tf.reduce_mean(disc_fake)) elif LOSS_TYPE == 'WGAN': if SOFT_PLUS: gen_costs.append(tf.reduce_mean(tf.nn.softplus(-disc_fake))) else: gen_costs.append(-tf.reduce_mean(disc_fake)) if FLAGS.perm_classifier: perm_classifier_fake_logits = perm_classifier(fake_data_split_G, reuse=True) perm_classifier_fake_loss = tf.reduce_mean(sigmoid_cross_entropy_with_logits( perm_classifier_fake_logits, tf.one_hot(labels_random_splits_G[i], VOCAB_SIZE))) gen_costs[-1] += FLAGS.perm_multiplier*perm_classifier_fake_loss gen_cost = (tf.add_n(gen_costs) / len(DEVICES)) tf.summary.scalar('G_wgan_cost', gen_cost) gen_params = [var for var in tf.trainable_variables() if 'Generator' in var.name] logging.debug('\ngen_params:') for var in gen_params: logging.debug(var.name) disc_params = [var for var in tf.trainable_variables() if 'Discriminator' in var.name] logging.debug('\ndisc_params:') for var in disc_params: logging.debug(var.name) logging.debug('\ntrainable_variables.name:') for var in tf.trainable_variables(): logging.debug(var.name) disc_opt = tf.train.AdamOptimizer(learning_rate=LR * decay, beta1=0., beta2=0.9) disc_gv = disc_opt.compute_gradients(disc_cost, var_list=disc_params) disc_train_op = disc_opt.apply_gradients(disc_gv) gen_opt = tf.train.AdamOptimizer(learning_rate=LR * decay, beta1=0., beta2=0.9) gen_gv = gen_opt.compute_gradients(gen_cost, var_list=gen_params) gen_train_op = gen_opt.apply_gradients(gen_gv) confuse_train_op = tf.no_op() if ALGORITHM == 'rcgan-u': if FLAGS.confuse_lr_decay: confuse_lr = LR * decay * FLAGS.confuse_multiplier else: confuse_lr = LR * FLAGS.confuse_multiplier confuse_train_op = tf.train.AdamOptimizer(confuse_lr, beta1=0., beta2=0.9) \ .minimize(gen_cost, var_list=[confusion_logits]) # Function for generating samples frame_i = [0] fixed_noise = tf.constant(np.random.normal(size=(100, Z_DIM)).astype('float32')) # airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck sample_labels = [[k]*10 for k in range(10)] sample_labels = [k for item in sample_labels for k in item] fixed_labels = tf.constant(np.array(sample_labels, dtype='int32'))#[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 10 fixed_noise_samples = Generator(100, fixed_labels, noise=fixed_noise, reuse=True) def generate_image(frame, true_dist): samples = session.run(fixed_noise_samples) samples = ((samples + 1.) * (255. / 2)).astype('int32') common.misc.save_images(samples.reshape((100, IMG_SIZE, IMG_SIZE, IMG_DIM)), os.path.join(DIR, 'samples_{}.png'.format(frame))) # Function for calculating inception score fake_labels_100 = tf.cast(tf.random_uniform([100]) * 10, tf.int32) samples_100 = Generator(100, fake_labels_100, reuse=True) def get_inception_score(n): # For inception_score_new2 all_samples = [] for i in range(int(n / 100)): all_samples.append(session.run(samples_100)) all_samples = np.concatenate(all_samples, axis=0) all_samples = all_samples.reshape((-1, IMG_SIZE, IMG_SIZE, IMG_DIM)).transpose(0, 3, 1, 2) return common.inception.inception_score_.get_inception_score(all_samples) label_100_list = [label for label in range(10) for _ in range(10)] fake_deterministic_labels_100 = tf.cast(tf.constant(label_100_list), tf.int32) deterministic_samples_100 = Generator(100, fake_deterministic_labels_100, reuse=True) def save_samples(n): all_samples = [] all_labels = [] for i in range(int(n / 100)): all_samples.append(session.run(deterministic_samples_100)) all_labels.append(label_100_list) all_samples = np.concatenate(all_samples, axis=0) all_labels = np.concatenate(all_labels, axis=0) all_samples = ((all_samples + 1.) * (255.99 / 2)).astype('int32') # all_samples = all_samples.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1) all_samples = all_samples.reshape((-1, IMG_SIZE, IMG_SIZE, IMG_DIM)) return all_samples, all_labels # Function for reading data train_gen, dev_gen = dataset_.load(BATCH_SIZE, DATA_DIR, C_ALPHA) def inf_train_gen(): while True: for images_, labels_, labels_random_, labels_biased_, labels_inv_weights_ in train_gen(): yield images_, labels_, labels_random_, labels_biased_, labels_inv_weights_ def inf_train_gen_G(): _generator = train_gen() while True: labels_random_list = [] labels_biased_list = [] for _ in range(GEN_BS_MULTIPLE): try: _, _, labels_random_list_element, labels_biased_list_element, _ = _generator.__next__() except StopIteration: _generator = train_gen() _, _, labels_random_list_element, labels_biased_list_element, _ = _generator.__next__() labels_random_list.append(labels_random_list_element) labels_biased_list.append(labels_biased_list_element) yield (np.concatenate(labels_random_list, axis= 0), np.concatenate(labels_biased_list, axis = 0)) gen = inf_train_gen() gen_G = inf_train_gen_G() for name, grads_and_vars in [('G', gen_gv), ('D', disc_gv)]: logging.debug("{} Params:".format(name)) total_param_count = 0 for g, v in grads_and_vars: shape = v.get_shape() shape_str = ",".join([str(x) for x in v.get_shape()]) param_count = 1 for dim in shape: param_count *= int(dim) total_param_count += param_count if g is None: logging.debug("\t{} ({}) [no grad!]".format(v.name, shape_str)) else: logging.debug("\t{} ({})".format(v.name, shape_str)) logging.debug("Total param count: {}".format(locale.format("%d", total_param_count, grouping=True))) summaries_op = tf.summary.merge_all() saver = tf.train.Saver(max_to_keep=5) summary_writer = tf.summary.FileWriter(CHECKPOINT_DIR, graph=session.graph) session.run(tf.global_variables_initializer()) if RESTORE: ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR) if ckpt: logging.info('restore model from: {}...'.format(ckpt)) saver.restore(session, ckpt) _random_labels_G, _labels_biased_G = next(gen_G) inception_score_max = 0.0 gen_label_acc_max = 0.0 for iteration in range(ITERS): start_time = time.time() if ALGORITHM == 'rcgan-u' and (iteration%100==0 or iteration < 500): logging.debug('confusion_matrix: ') np.set_printoptions(precision=3, suppress=True) logging.debug('\n{}'.format(session.run(confusion_matrix))) np.set_printoptions() if 0 < iteration: _random_labels_G, _labels_biased_G = next(gen_G) #logging.debug('test1: {}'.format(_random_labels_G)) _ = session.run([gen_train_op, confuse_train_op], feed_dict={_iteration: iteration, all_random_labels_G: _random_labels_G, all_labels_biased_G: _labels_biased_G}) for i in range(N_CRITIC): _data, _labels, _random_labels, _labels_biased, _labels_inv_weights = next(gen) _disc_cost, _disc_wgan, _gen_cost, _, summaries = session.run( [disc_cost, disc_wgan, gen_cost, disc_train_op, summaries_op], feed_dict={all_real_data_int: _data, all_real_labels: _labels, all_random_labels: _random_labels, all_labels_biased: _labels_biased, all_labels_inv_weights: _labels_inv_weights, all_random_labels_G: _random_labels_G, all_labels_biased_G: _labels_biased_G, _iteration: iteration}) summary_writer.add_summary(summaries, global_step=iteration) # lib.plot.plot('cost', _disc_cost) lib.plot.plot('d_cost', _disc_wgan) lib.plot.plot('g_cost', _gen_cost) if CONDITIONAL and ACGAN: lib.plot.plot('disc_wgan', _disc_wgan) lib.plot.plot('acgan', _disc_acgan) lib.plot.plot('acc_real', _disc_acgan_acc) lib.plot.plot('acc_fake', _disc_acgan_fake_acc) if iteration % INCEPTION_FREQUENCY == INCEPTION_FREQUENCY - 1: logging.info('starting inception score computation.') inception_score = get_inception_score(50000) inception_score_max = max(inception_score_max, inception_score[0]) lib.plot.plot('inception_50k', inception_score[0]) lib.plot.plot('inception_50k_std', inception_score[1]) lib.plot.plot('inception_50k_max', inception_score_max) logging.info('finished inception score computation.') if SAMPLE_SAVE_FREQUENCY and iteration % SAMPLE_SAVE_FREQUENCY == SAMPLE_SAVE_FREQUENCY - 1: logging.info('starting saving samples.') samples_for_save, _ = save_samples(10000) np.save(os.path.join(DIR, '_samples_{}'.format(iteration)), samples_for_save) logging.info('finished saving samples.') # Calculate dev loss and generate samples every 100 iters if iteration % SAMPLE_FREQUENCY == SAMPLE_FREQUENCY - 1: logging.info('starting calculating dev cost.') dev_disc_costs = [] for images, _labels, _random_labels, _labels_biased, _labels_inv_weights in dev_gen(): _dev_disc_cost = session.run([disc_cost], feed_dict={all_real_data_int: images, all_real_labels: _labels, all_random_labels: _random_labels, all_labels_biased: _labels_biased, all_labels_inv_weights: _labels_inv_weights, }) dev_disc_costs.append(_dev_disc_cost) lib.plot.plot('dev_cost', np.mean(dev_disc_costs)) logging.info('finished calculating dev cost.') logging.info('starting generating samples.') generate_image(iteration, _data) logging.info('finished generating samples.') if iteration % GENERATED_LABEL_ACCURACY_FREQ == GENERATED_LABEL_ACCURACY_FREQ - 1: logging.info('starting calculating generated label accuracy.') generated_samples, generate_labels = save_samples(1000) accuracy = generated_label_accuracy(generated_samples, generate_labels) if gen_label_acc_max < accuracy: gen_label_acc_max = accuracy lib.plot.plot('gen_label_acc', accuracy) lib.plot.plot('gen_label_acc_max', gen_label_acc_max) logging.info('finished calculating generated label accuracy.') if (iteration < 500) or (iteration % 1000 == 999): logging.info('start flushing plots and checpoints.') lib.plot.dir_flush(DIR) if not os.path.exists(CHECKPOINT_DIR): os.mkdir(CHECKPOINT_DIR) saver.save(session, os.path.join(CHECKPOINT_DIR, 'model.ckpt'), global_step=iteration) logging.info('finished flushing plots and checpoints.') lib.plot.tick() if ITERS: summary_writer.flush() summary_writer.close() if FLAGS.perm_gen_label_acc: logging.info('starting calculating min. permuted generated label accuracy.') generated_samples, generate_labels = save_samples(1000) accuracy = generated_label_accuracy( generated_samples, generate_labels, confusion_matrix=session.run(confusion_matrix)) lib.plot.plot('gen_label_acc', accuracy) logging.info('finished calculating min. permuted generated label accuracy.') else: logging.info('starting calculating generated label accuracy.') generated_samples, generate_labels = save_samples(1000) accuracy = generated_label_accuracy(generated_samples, generate_labels) lib.plot.plot('gen_label_acc', accuracy) logging.info('finished calculating generated label accuracy.') if __name__ == '__main__': tf.app.run()