# This is the code for experiments performed on the MNIST dataset for the DeLiGAN model. Minor adjustments in
# the code as suggested in the comments can be done to test GAN. Corresponding details about these experiments
# can be found in section 5.3 of the paper and the results showing the outputs can be seen in Fig 4.

# NOTE: the import order below is intentionally left as it was. The wildcard imports from `ops` and `utils`
# may define names that later imports (or the functions defined in this file) override, so reordering
# could silently change which definition wins.
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
import os
import time
from random import randint
import cv2
import matplotlib.pylab as Plot
import tsne
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib
import numpy as Math
import sys
from tensorflow.contrib.layers import batch_norm

data_dir = '../datasets/mnist/'
results_dir = '../results/mnist/'
# Checkpoint location. Built with os.path.join because os.getcwd() has no trailing separator, so plain
# string concatenation with "../results/..." yields a bogus path such as "/home/user/code../results/...".
# The trailing slash is kept so that checkpoints are written *inside* this directory.
checkpoint_dir = os.path.join(os.getcwd(), '../results/mnist/train/')

# Boolean switch fed at run time: True -> batch statistics are used (and moving averages updated),
# False -> the stored moving averages are used.
phase_train = tf.placeholder(tf.bool, name='phase_train')


def _ensure_dir(path):
    """Create directory `path` (and missing parents) if it does not exist yet."""
    if not os.path.isdir(path):
        os.makedirs(path)


def Minibatch_Discriminator(input, num_kernels=100, dim_per_kernel=5, init=False, name='MD'):
    """Append minibatch-similarity features to a batch of feature vectors.

    Reads the module-level `df_dim` and `batchsize` (assigned in the session block below).

    Args:
        input: tensor that can be reshaped to [batchsize, df_dim*4].
        num_kernels: number of kernels; also the number of features appended per sample.
        dim_per_kernel: dimensionality of each kernel's projection.
        init: unused; kept for interface compatibility.
        name: prefix for the variables created here.

    Returns:
        Tensor of shape [batchsize, df_dim*4 + num_kernels]: the flattened input concatenated
        with the similarity features.
    """
    num_inputs = df_dim*4
    theta = tf.get_variable(name+"/theta", [num_inputs, num_kernels, dim_per_kernel],
                            initializer=tf.random_normal_initializer(stddev=0.05))
    log_weight_scale = tf.get_variable(name+"/lws", [num_kernels, dim_per_kernel],
                                       initializer=tf.constant_initializer(0.0))
    # Weight normalisation: each (kernel, dim) column of theta is scaled to L2 norm exp(log_weight_scale).
    theta_norm = tf.sqrt(tf.reduce_sum(tf.square(theta), 0))
    W = tf.mul(theta, tf.expand_dims(tf.exp(log_weight_scale)/theta_norm, 0))
    W = tf.reshape(W, [-1, num_kernels*dim_per_kernel])

    x = tf.reshape(input, [batchsize, num_inputs])
    activation = tf.matmul(x, W)
    activation = tf.reshape(activation, [-1, num_kernels, dim_per_kernel])

    # Pairwise L1 distance between every two samples, per kernel: [batch, kernel, batch].
    pairwise_diff = tf.sub(tf.expand_dims(activation, 3),
                           tf.expand_dims(tf.transpose(activation, [1, 2, 0]), 0))
    l1_dist = tf.reduce_sum(tf.abs(pairwise_diff), 2)
    # Mask that is 0 on the (sample i, sample i) entries and 1 elsewhere.
    off_diagonal = 1-tf.expand_dims(tf.constant(np.eye(batchsize), dtype=np.float32), 1)
    abs_dif = tf.mul(l1_dist, off_diagonal)

    # Per-sample, per-kernel similarity, normalised by the sum over the whole tensor.
    f = tf.reduce_sum(tf.exp(-abs_dif), 2)/tf.reduce_sum(tf.exp(-abs_dif))
    print(f.get_shape())
    print(input.get_shape())
    return tf.concat(1, [x, f])


def linear(x, output_dim, name="linear"):
    """Fully connected layer: returns `x @ w + b` with `w` of shape [x.shape[1], output_dim]."""
    w = tf.get_variable(name+"/w", [x.get_shape()[1], output_dim])
    b = tf.get_variable(name+"/b", [output_dim], initializer=tf.constant_initializer(0.0))
    return tf.matmul(x, w)+b


def fc_batch_norm(x, n_out, phase_train, name='bn'):
    """Batch normalisation for 2-D activations (moments taken over axis 0).

    Args:
        x: input tensor whose last dimension has size `n_out`.
        n_out: number of features; shape of the learned offset and scale.
        phase_train: boolean tensor. When True the batch moments are used and the moving
            averages (decay 0.9) are updated; when False the moving averages are used.
        name: prefix for the variables/ops created here.

    Returns:
        The normalised tensor.
    """
    beta = tf.get_variable(name + '/fc_beta', shape=[n_out], initializer=tf.constant_initializer())
    gamma = tf.get_variable(name + '/fc_gamma', shape=[n_out],
                            initializer=tf.random_normal_initializer(1., 0.02))
    batch_mean, batch_var = tf.nn.moments(x, [0], name=name + '/fc_moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.9)

    def mean_var_with_update():
        # Force the moving-average update to run whenever the batch statistics are consumed.
        ema_apply_op = ema.apply([batch_mean, batch_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(phase_train,
                        mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return normed


def global_batch_norm(x, n_out, phase_train, name='bn'):
    """Batch normalisation for 4-D activations (moments taken over axes 0, 1 and 2).

    Args:
        x: input tensor whose last dimension has size `n_out`.
        n_out: number of channels; shape of the learned offset and scale.
        phase_train: boolean tensor. When True the batch moments are used and the moving
            averages (decay 0.9) are updated; when False the moving averages are used.
        name: prefix for the variables/ops created here.

    Returns:
        The normalised tensor.
    """
    beta = tf.get_variable(name + '/beta', shape=[n_out], initializer=tf.constant_initializer(0.))
    gamma = tf.get_variable(name + '/gamma', shape=[n_out],
                            initializer=tf.random_normal_initializer(1., 0.02))
    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name=name + '/moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.9)

    def mean_var_with_update():
        # Force the moving-average update to run whenever the batch statistics are consumed.
        ema_apply_op = ema.apply([batch_mean, batch_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(phase_train,
                        mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return normed


def conv(x, Wx, Wy, inputFeatures, outputFeatures, stridex=1, stridey=1, padding='SAME',
         transpose=False, name='conv'):
    """2-D convolution plus bias.

    Args:
        x: 4-D input tensor.
        Wx, Wy: filter size along the first and second spatial filter dimensions.
        inputFeatures, outputFeatures: number of input / output channels.
        stridex, stridey: strides along the two spatial dimensions.
        padding: padding mode passed to tf.nn.conv2d.
        transpose: unused; kept for interface compatibility.
        name: prefix for the variables created here.

    Returns:
        The convolved tensor with the bias added.
    """
    w = tf.get_variable(name+"/w", [Wx, Wy, inputFeatures, outputFeatures],
                        initializer=tf.truncated_normal_initializer(stddev=0.02))
    b = tf.get_variable(name+"/b", [outputFeatures], initializer=tf.constant_initializer(0.0))
    out = tf.nn.conv2d(x, w, strides=[1, stridex, stridey, 1], padding=padding) + b
    return out


def convt(x, outputShape, Wx=3, Wy=3, stridex=1, stridey=1, padding='SAME', transpose=False,
          name='convt'):
    """2-D transposed convolution plus bias.

    Args:
        x: 4-D input tensor.
        outputShape: full output shape; its last entry is the number of output channels.
        Wx, Wy: filter size along the first and second spatial filter dimensions.
        stridex, stridey: strides along the two spatial dimensions.
        padding: padding mode passed to tf.nn.conv2d_transpose.
        transpose: unused; kept for interface compatibility.
        name: prefix for the variables created here.

    Returns:
        The transposed-convolution output with the bias added.
    """
    # conv2d_transpose expects the filter as [height, width, out_channels, in_channels].
    w = tf.get_variable(name+"/w", [Wx, Wy, outputShape[-1], x.get_shape()[-1]],
                        initializer=tf.truncated_normal_initializer(stddev=0.02))
    b = tf.get_variable(name+"/b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
    out = tf.nn.conv2d_transpose(x, w, output_shape=outputShape,
                                 strides=[1, stridex, stridey, 1], padding=padding) + b
    return out


def discriminator(image, Reuse=False):
    """Build the discriminator graph.

    Reads the module-level `df_dim`, `batchsize` and `phase_train`.

    Args:
        image: batch of images, reshaped internally to [-1, 28, 28, 1].
        Reuse: whether to reuse the variables of the 'disc' scope (True for the second call).

    Returns:
        (probability, logit): the sigmoid of the final layer and the raw final-layer output.
    """
    with tf.variable_scope('disc', reuse=Reuse):
        image = tf.reshape(image, [-1, 28, 28, 1])
        h0 = lrelu(conv(image, 5, 5, 1, df_dim, stridex=2, stridey=2, name='d_h0_conv'))
        h1 = lrelu(batch_norm(conv(h0, 5, 5, df_dim, df_dim*2, stridex=2, stridey=2, name='d_h1_conv'),
                              decay=0.9, scale=True, updates_collections=None,
                              is_training=phase_train, reuse=Reuse, scope='d_bn1'))
        h2 = lrelu(batch_norm(conv(h1, 3, 3, df_dim*2, df_dim*4, stridex=2, stridey=2, name='d_h2_conv'),
                              decay=0.9, scale=True, updates_collections=None,
                              is_training=phase_train, reuse=Reuse, scope='d_bn2'))
        # Pool each feature map down to a single value; the result is flattened to
        # [batchsize, df_dim*4] inside Minibatch_Discriminator.
        h3 = tf.nn.max_pool(h2, ksize=[1, 4, 4, 1], strides=[1, 1, 1, 1], padding='VALID')
        h7 = Minibatch_Discriminator(h3, num_kernels=df_dim*4, name='d_MD')
        # h7 carries df_dim*4 pooled features plus df_dim*4 minibatch features.
        h8 = dense(tf.reshape(h7, [batchsize, -1]), df_dim*4*2, 1, scope='d_h8_lin')
        return tf.nn.sigmoid(h8), h8


def generator(z):
    """Build the generator graph.

    Reads the module-level `gf_dim`, `batchsize` and `phase_train`.

    Args:
        z: latent input of shape [batchsize, latent_dim].

    Returns:
        Tensor of shape [batchsize, 28, 28, 1] with values in [-1, 1] (tanh output).
    """
    with tf.variable_scope('gen'):
        h0 = linear(z, gf_dim*4*4*4, name='g_h0')
        h0 = tf.nn.relu(fc_batch_norm(h0, gf_dim*4*4*4, phase_train, 'g_bn0'))
        h0 = tf.reshape(h0, [-1, 4, 4, gf_dim*4])
        h1 = tf.nn.relu(global_batch_norm(convt(h0, [batchsize, 7, 7, gf_dim*2], 3, 3, 2, 2, name='g_h1'),
                                          gf_dim*2, phase_train, 'g_bn1'))
        h3 = tf.nn.relu(global_batch_norm(convt(h1, [batchsize, 14, 14, gf_dim], 5, 5, 2, 2, name='g_h3'),
                                          gf_dim, phase_train, 'g_bn3'))
        h4 = tf.tanh(convt(h3, [batchsize, 28, 28, 1], 5, 5, 2, 2, name='g_h4'))
        return h4


gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    # Hyper-parameters. These are module-level names (a `with` block opens no new scope) and are
    # read by the network-building functions above.
    batchsize = 50
    imageshape = [28*28]
    z_dim = 30
    gf_dim = 16
    df_dim = 16
    learningrate = 0.0005
    beta1 = 0.5

    images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")
    z = tf.placeholder(tf.float32, [None, z_dim], name="z")
    lr1 = tf.placeholder(tf.float32, name="lr")

    # Our Mixture Model modifications
    zin = tf.get_variable("g_z", [batchsize, z_dim], initializer=tf.random_uniform_initializer(-1, 1))
    zsig = tf.get_variable("g_sig", [batchsize, z_dim], initializer=tf.constant_initializer(0.2))
    inp = tf.add(zin, tf.mul(z, zsig))
    # inp = z     # Uncomment this line when training/testing baseline GAN

    G = generator(inp)
    D_prob, D_logit = discriminator(images)
    D_fake_prob, D_fake_logit = discriminator(G, Reuse=True)

    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_logit, tf.ones_like(D_logit)))
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.zeros_like(D_fake_logit)))
    sigma_loss = tf.reduce_mean(tf.square(zsig-1))/3  # sigma regularizer
    gloss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(D_fake_logit, tf.ones_like(D_fake_logit)))
    dloss = d_loss_real + d_loss_fake

    # Variables are assigned to an optimizer by the 'd_' / 'g_' substring in their names.
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'd_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    data = np.load(data_dir + 'mnist.npz')
    trainx = np.concatenate([data['trainInps']], axis=0)
    trainy = np.concatenate([data['trainTargs']], axis=0)
    trainx = 2*trainx/255.-1  # rescale pixel values from [0, 255] to [-1, 1]

    # Uniformly sampling 50 images per category from the dataset
    data = []
    for i in range(10):
        digit_images = trainx[np.argmax(trainy, 1) == i]
        data.append(digit_images[-50:])
    data = np.array(data)
    data = np.reshape(data, [-1, 28*28])

    d_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(dloss, var_list=d_vars)
    g_optim = tf.train.AdamOptimizer(lr1, beta1=beta1).minimize(gloss + sigma_loss, var_list=g_vars)
    tf.initialize_all_variables().run()

    saver = tf.train.Saver(max_to_keep=10)

    counter = 1
    start_time = time.time()
    data_size = data.shape[0]
    display_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32)

    seed = 1
    rng = np.random.RandomState(seed)
    train = True
    thres = 1.0  # used to balance gan training
    count1 = 0   # consecutive generator updates
    count2 = 0   # consecutive discriminator updates (counted downwards)
    t1 = 0.70    # lower bound for thres

    # cv2.imwrite fails silently and np.save / saver.save raise when the target directory is missing.
    _ensure_dir(results_dir)

    if train:
        _ensure_dir(checkpoint_dir)
        # saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        # training a model
        for epoch in range(4000):
            # Floor division keeps this an integer on Python 3 as well.
            batch_idx = data_size // batchsize
            batch = data[rng.permutation(data_size)]
            lr = learningrate * (np.minimum((4 - epoch/1000.), 3.)/3)
            for idx in range(batch_idx):
                batch_images = batch[idx*batchsize:(idx+1)*batchsize]
                batch_z = np.random.uniform(-1.0, 1.0, [batchsize, z_dim]).astype(np.float32)

                # Adapt the threshold: raise it after a run of generator updates,
                # lower it (down to t1) after a run of discriminator updates.
                if count1 > 3:
                    thres = min(thres+0.003, 1.0)
                    count1 = 0
                    print('gen', thres)
                if count2 < -1:
                    thres = max(thres-0.003, t1)
                    count2 = 0
                    print('disc', thres)

                for k in range(5):
                    batch_z = np.random.normal(0, 1.0, [batchsize, z_dim]).astype(np.float32)
                    # Update the generator while its loss is above the threshold,
                    # otherwise update the discriminator.
                    if gloss.eval({z: batch_z, phase_train.name: False}) > thres:
                        sess.run([g_optim], feed_dict={z: batch_z, lr1: lr, phase_train.name: True})
                        count1 += 1
                        count2 = 0
                    else:
                        sess.run([d_optim], feed_dict={images: batch_images, z: batch_z, lr1: lr,
                                                       phase_train.name: True})
                        count2 -= 1
                        count1 = 0

                counter += 1
                if counter % 300 == 0:
                    # Saving 49 randomly generated samples
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idx, time.time() - start_time,))
                    sdata = sess.run(G, feed_dict={z: batch_z, phase_train.name: False})
                    sdata = sdata.reshape(sdata.shape[0], 28, 28, 1)/2.+0.5
                    sdata = merge(sdata[:49], [7, 7])
                    # Builtin int is what the removed `np.int` alias pointed to.
                    sdata = np.array(sdata*255., dtype=int)
                    cv2.imwrite(results_dir + "/" + str(counter) + ".png", sdata)
                    errD_fake = d_loss_fake.eval({z: display_z, phase_train.name: False})
                    errD_real = d_loss_real.eval({images: batch_images, phase_train.name: False})
                    errG = gloss.eval({z: display_z, phase_train.name: False})
                    sigloss = sigma_loss.eval()
                    print('D_real: ', errD_real)
                    print('D_fake: ', errD_fake)
                    print('G_err: ', errG)
                    print('sigloss: ', sigloss)

                if counter % 2000 == 0:
                    # Calculating the Nearest Neighbours corresponding to the generated samples
                    sdata = sess.run(G, feed_dict={z: display_z, phase_train.name: False})
                    sdata = sdata.reshape(sdata.shape[0], 28*28)
                    NNdiff = np.sum(np.square(np.expand_dims(sdata, axis=1) - np.expand_dims(data, axis=0)), axis=2)
                    NN = data[np.argmin(NNdiff, axis=1)]
                    sdata = sdata.reshape(sdata.shape[0], 28, 28, 1)/2.+0.5
                    NN = np.reshape(NN, [batchsize, 28, 28, 1])/2.+0.5
                    sdata = merge(sdata[:49], [7, 7])
                    NN = merge(NN[:49], [7, 7])
                    # Generated samples on the left, their nearest training images on the right.
                    sdata = np.concatenate([sdata, NN], axis=1)
                    sdata = np.array(sdata*255., dtype=int)
                    cv2.imwrite(results_dir + "/NN" + str(counter) + ".png", sdata)  # gan_1nin_8gfdim_floss_alpha1_z15

                    # Plotting the latent space using tsne
                    z_Mog = zin.eval()  # display_z
                    gen = G.eval({z: display_z, phase_train.name: False})
                    Y = tsne.tsne(z_Mog, 2, z_dim, 10.0)
                    Plot.scatter(Y[:, 0], Y[:, 1])
                    xtrain = gen.copy()
                    fig, ax = Plot.subplots()
                    # Draw each generated image at the 2-D embedding of its latent mean.
                    for i, (x0, y0) in enumerate(zip(Y[:, 0], Y[:, 1])):
                        image = xtrain[i % xtrain.shape[0]]
                        image = image.reshape(28, 28)
                        im = OffsetImage(image, zoom=1.0)
                        ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
                        ax.add_artist(ab)
                    ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
                    ax.autoscale()
                    Plot.scatter(Y[:, 0], Y[:, 1], 20)
                    fig.savefig(results_dir + "/plot" + str(counter) + ".png")
                    # Release both figures opened above; otherwise they pile up over the run.
                    Plot.close('all')
                    saver.save(sess, checkpoint_dir, global_step=counter)
    else:
        # Generating samples from a saved model
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        samples = []
        for i in range(100):
            batch_z = np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32)
            sdata = sess.run(G, feed_dict={z: batch_z, phase_train.name: False})
            sdata = sdata.reshape(sdata.shape[0], 28, 28, 1)/2.+0.5
            sdata = sdata*255.
            samples.append(sdata)
        samples1 = np.concatenate(samples, 0)
        np.save(results_dir + '/MNIST_samples5k.npy', samples1)
        print("samples saved")
        sys.exit()