import matplotlib as mpl

# This line allows mpl to run with no DISPLAY defined
mpl.use('Agg')

import os

import keras.backend as K
import numpy as np
import pandas as pd
from keras.layers import BatchNormalization, Dense, Dropout, Flatten, Input, LeakyReLU, merge
from keras.models import Model
from keras.optimizers import Adam
from keras_adversarial import (AdversarialModel, AdversarialOptimizerSimultaneous, fix_names,
                               gan_targets, n_choice, normal_latent_sampling, simple_bigan)
from keras_adversarial.image_grid_callback import ImageGridCallback
from keras_adversarial.legacy import l1l2

from example_gan import model_generator
from mnist_utils import mnist_data


def model_encoder(latent_dim, input_shape, hidden_dim=1024, reg=lambda: l1l2(1e-5, 0),
                  batch_norm_mode=0):
    """Build the encoder model x -> z.

    Flattens x and passes it through three shrinking Dense/BatchNorm/LeakyReLU
    stages, then samples a latent code with the reparameterization trick:
    z = mu + N(0, 1) * exp(log_sigma_sq / 2).

    Args:
        latent_dim: dimensionality of the latent code z.
        input_shape: shape of input x (without the batch axis).
        hidden_dim: width of the first hidden layer; the following layers use
            hidden_dim // 2 and hidden_dim // 4.
        reg: zero-argument factory returning a fresh weight regularizer per layer.
        batch_norm_mode: legacy Keras 1.x ``BatchNormalization(mode=...)`` value.

    Returns:
        A keras ``Model`` mapping x to a sampled z.
    """
    x = Input(input_shape, name="x")
    h = Flatten()(x)
    h = Dense(hidden_dim, name="encoder_h1", W_regularizer=reg())(h)
    h = BatchNormalization(mode=batch_norm_mode)(h)
    h = LeakyReLU(0.2)(h)
    # BUG FIX: ``hidden_dim / 2`` is a float on Python 3; layer sizes must be ints.
    h = Dense(hidden_dim // 2, name="encoder_h2", W_regularizer=reg())(h)
    h = BatchNormalization(mode=batch_norm_mode)(h)
    h = LeakyReLU(0.2)(h)
    h = Dense(hidden_dim // 4, name="encoder_h3", W_regularizer=reg())(h)
    h = BatchNormalization(mode=batch_norm_mode)(h)
    h = LeakyReLU(0.2)(h)
    mu = Dense(latent_dim, name="encoder_mu", W_regularizer=reg())(h)
    log_sigma_sq = Dense(latent_dim, name="encoder_log_sigma_sq", W_regularizer=reg())(h)
    # Reparameterization trick: draw unit-normal noise and scale by the
    # predicted standard deviation exp(log_sigma_sq / 2).
    z = merge([mu, log_sigma_sq],
              mode=lambda p: p[0] + K.random_normal(K.shape(p[0])) * K.exp(p[1] / 2),
              output_shape=lambda x: x[0])
    return Model(x, z, name="encoder")


def model_discriminator(latent_dim, input_shape, output_dim=1, hidden_dim=2048,
                        reg=lambda: l1l2(1e-7, 1e-7), batch_norm_mode=1, dropout=0.5):
    """Build the BiGAN discriminator (z, x) -> y in two weight-sharing variants.

    The same Dense/BatchNorm layer instances are wired into two Models: a
    training model with dropout after each activation, and a testing model
    without dropout. Sharing the layers means the two variants share weights.

    Args:
        latent_dim: dimensionality of the latent code z.
        input_shape: shape of input x (without the batch axis).
        output_dim: size of the sigmoid output (1 for a real/fake score).
        hidden_dim: width of each hidden layer.
        reg: zero-argument factory returning a fresh weight regularizer per layer.
        batch_norm_mode: legacy Keras 1.x ``BatchNormalization(mode=...)`` value.
        dropout: dropout rate used only in the training variant.

    Returns:
        (mtrain, mtest): training and testing keras ``Model``s sharing weights.
    """
    z = Input((latent_dim,))
    x = Input(input_shape, name="x")
    h = merge([z, Flatten()(x)], mode='concat')
    h1 = Dense(hidden_dim, name="discriminator_h1", W_regularizer=reg())
    b1 = BatchNormalization(mode=batch_norm_mode)
    h2 = Dense(hidden_dim, name="discriminator_h2", W_regularizer=reg())
    b2 = BatchNormalization(mode=batch_norm_mode)
    h3 = Dense(hidden_dim, name="discriminator_h3", W_regularizer=reg())
    b3 = BatchNormalization(mode=batch_norm_mode)
    y = Dense(output_dim, name="discriminator_y", activation="sigmoid", W_regularizer=reg())
    # training model uses dropout
    _h = h
    _h = Dropout(dropout)(LeakyReLU(0.2)(b1(h1(_h))))
    _h = Dropout(dropout)(LeakyReLU(0.2)(b2(h2(_h))))
    _h = Dropout(dropout)(LeakyReLU(0.2)(b3(h3(_h))))
    ytrain = y(_h)
    mtrain = Model([z, x], ytrain, name="discriminator_train")
    # testing model does not use dropout
    _h = h
    _h = LeakyReLU(0.2)(b1(h1(_h)))
    _h = LeakyReLU(0.2)(b2(h2(_h)))
    _h = LeakyReLU(0.2)(b3(h3(_h)))
    ytest = y(_h)
    mtest = Model([z, x], ytest, name="discriminator_test")
    return mtrain, mtest


def example_bigan(path, adversarial_optimizer):
    """Train a BiGAN on MNIST and write samples, history, and weights under *path*.

    Args:
        path: output directory (created if missing) for image grids,
            ``history.csv``, and the saved ``.h5`` models.
        adversarial_optimizer: keras_adversarial optimizer strategy, e.g.
            ``AdversarialOptimizerSimultaneous()``.
    """
    # ROBUSTNESS: make sure the output directory exists before callbacks/saves write to it.
    if not os.path.exists(path):
        os.makedirs(path)
    # z \in R^25  (comment previously said R^100, stale vs. the value below)
    latent_dim = 25
    # x \in R^{28x28}
    input_shape = (28, 28)

    # generator (z -> x)
    generator = model_generator(latent_dim, input_shape)
    # encoder (x -> z)
    encoder = model_encoder(latent_dim, input_shape)
    # autoencoder (x -> x')
    autoencoder = Model(encoder.inputs, generator(encoder(encoder.inputs)))
    # discriminator (z, x -> y): train variant has dropout, test variant does not
    discriminator_train, discriminator_test = model_discriminator(latent_dim, input_shape)
    # bigan (z, x -> yfake, yreal); the generator player sees the no-dropout discriminator
    bigan_generator = simple_bigan(generator, encoder, discriminator_test)
    bigan_discriminator = simple_bigan(generator, encoder, discriminator_train)
    # z generated on GPU based on batch dimension of x
    x = bigan_generator.inputs[1]
    z = normal_latent_sampling((latent_dim,))(x)
    # eliminate z from inputs
    bigan_generator = Model([x], fix_names(bigan_generator([z, x]),
                                           bigan_generator.output_names))
    bigan_discriminator = Model([x], fix_names(bigan_discriminator([z, x]),
                                               bigan_discriminator.output_names))

    generative_params = generator.trainable_weights + encoder.trainable_weights

    # print summary of models
    generator.summary()
    encoder.summary()
    discriminator_train.summary()
    bigan_discriminator.summary()
    autoencoder.summary()

    # build adversarial model
    model = AdversarialModel(player_models=[bigan_generator, bigan_discriminator],
                             player_params=[generative_params,
                                            discriminator_train.trainable_weights],
                             player_names=["generator", "discriminator"])
    model.adversarial_compile(adversarial_optimizer=adversarial_optimizer,
                              player_optimizers=[Adam(1e-4, decay=1e-4),
                                                 Adam(1e-3, decay=1e-4)],
                              loss='binary_crossentropy')

    # load mnist data
    xtrain, xtest = mnist_data()

    # callback for image grid of generated samples
    def generator_sampler():
        zsamples = np.random.normal(size=(10 * 10, latent_dim))
        return generator.predict(zsamples).reshape((10, 10, 28, 28))

    generator_cb = ImageGridCallback(os.path.join(path, "generated-epoch-{:03d}.png"),
                                     generator_sampler)

    # callback for image grid of autoencoded samples: column 0 is the original
    # digit, columns 1-9 are stochastic reconstructions of it
    def autoencoder_sampler():
        xsamples = n_choice(xtest, 10)
        xrep = np.repeat(xsamples, 9, axis=0)
        xgen = autoencoder.predict(xrep).reshape((10, 9, 28, 28))
        xsamples = xsamples.reshape((10, 1, 28, 28))
        return np.concatenate((xsamples, xgen), axis=1)

    autoencoder_cb = ImageGridCallback(os.path.join(path, "autoencoded-epoch-{:03d}.png"),
                                       autoencoder_sampler)

    # train network
    y = gan_targets(xtrain.shape[0])
    ytest = gan_targets(xtest.shape[0])
    history = model.fit(x=xtrain, y=y, validation_data=(xtest, ytest),
                        callbacks=[generator_cb, autoencoder_cb],
                        nb_epoch=100, batch_size=32)

    # save history
    df = pd.DataFrame(history.history)
    df.to_csv(os.path.join(path, "history.csv"))

    # save models
    encoder.save(os.path.join(path, "encoder.h5"))
    generator.save(os.path.join(path, "generator.h5"))
    discriminator_train.save(os.path.join(path, "discriminator.h5"))


def main():
    example_bigan("output/bigan", AdversarialOptimizerSimultaneous())


if __name__ == "__main__":
    main()