#!/usr/bin/env python # -*- coding: utf-8 -*- """ After training a GAN, use gradient descent to map images back to their original position in the latent vector. """ from __future__ import print_function import argparse import keras from keras.datasets import mnist import keras.backend as K import gandlf import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation # For repeatability. np.random.seed(1337) # To make the images work correctly. keras.backend.set_image_dim_ordering('tf') def plot_as_gif(x, x_ref, title, interval=50): """Plots data as a gif. Args: x: numpy array with shape (batch_size, 28, 28, 1). interval: int, the time between frames in milliseconds. """ filename = title.lower().replace(' ', '_') save_path = '/tmp/%s.gif' % filename # Plots the first sample. fig, (ax, ax2) = plt.subplots(1, 2) im = ax.imshow(x[0].reshape((28, 28)), interpolation='none', aspect='auto', cmap='gray', animated=True) # Fixes the dimensions. ax.axis('off') def updatefig(i, *args): im.set_array(x[i].reshape((28, 28))) return im, anim = FuncAnimation(fig, updatefig, frames=np.arange(0, x.shape[0]), interval=interval) ax2.imshow(x_ref.reshape((28, 28)), interpolation='none', aspect='auto', cmap='gray') ax2.axis('off') plt.suptitle(title) anim.save(save_path, dpi=80, writer='imagemagick') print('Saved gif to "%s".' % save_path) plt.show() def build_generator(): """Builds the big generator model.""" latent = keras.layers.Input((100,), name='latent') image_class = keras.layers.Input((10,), dtype='float32', name='image_class') d = keras.layers.Dense(100)(image_class) merged = keras.layers.merge([latent, d], mode='sum') hidden = keras.layers.Dense(512)(merged) hidden = keras.layers.LeakyReLU()(hidden) hidden = keras.layers.Dense(512)(hidden) hidden = keras.layers.LeakyReLU()(hidden) output_layer = keras.layers.Dense(28 * 28, activation='tanh')(hidden) fake_image = keras.layers.Reshape((28, 28, 1))(output_layer) return keras.models.Model(input=[latent, image_class], output=fake_image) def build_discriminator(): """Builds the big discriminator model.""" image = keras.layers.Input((28, 28, 1), name='real_data') hidden = keras.layers.Flatten()(image) # First hidden layer. hidden = keras.layers.Dense(512)(hidden) hidden = keras.layers.Dropout(0.3)(hidden) hidden = keras.layers.LeakyReLU()(hidden) # Second hidden layer. hidden = keras.layers.Dense(512)(hidden) hidden = keras.layers.Dropout(0.3)(hidden) hidden = keras.layers.LeakyReLU()(hidden) # Output layer. fake = keras.layers.Dense(1, activation='sigmoid', name='s')(hidden) aux = keras.layers.Dense(10, activation='softmax', name='c')(hidden) return keras.models.Model(input=image, output=[fake, aux]) def reverse_generator(generator, X_sample, y_sample, title): """Gradient descent to map images back to their latent vectors.""" latent_vec = np.random.normal(size=(1, 100)) # Function for figuring out how to bump the input. target = K.placeholder() loss = K.sum(K.square(generator.outputs[0] - target)) grad = K.gradients(loss, generator.inputs[0])[0] update_fn = K.function(generator.inputs + [target], [grad]) # Repeatedly apply the update rule. xs = [] for i in range(60): print('%d: latent_vec mean=%f, std=%f' % (i, np.mean(latent_vec), np.std(latent_vec))) xs.append(generator.predict_on_batch([latent_vec, y_sample])) for _ in range(10): update_vec = update_fn([latent_vec, y_sample, X_sample])[0] latent_vec -= update_vec * update_rate # Plots the samples. xs = np.concatenate(xs, axis=0) plot_as_gif(xs, X_sample, title) def get_mnist_data(binarize=False): """Puts the MNIST data in the right format.""" (X_train, y_train), (X_test, y_test) = mnist.load_data() if binarize: X_test = np.where(X_test >= 10, 1, -1) X_train = np.where(X_train >= 10, 1, -1) else: X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_test = (X_test.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=-1) X_test = np.expand_dims(X_test, axis=-1) y_train = np.eye(10)[y_train] y_test = np.eye(10)[y_test] return (X_train, y_train), (X_test, y_test) if __name__ == '__main__': update_rate = 0.1 generator = build_generator() (X_train, y_train), (X_test, y_test) = get_mnist_data(binarize=False) X_sample = X_test[5:6] y_sample = y_test[5:6] # Plot samples before training. reverse_generator( generator, X_sample, y_sample, 'Before training generator') # Trains GAN. discriminator = build_discriminator() model = gandlf.Model(generator, discriminator) loss_weights = {'s': 1., 'c': 1., 'c_fake': 0.} model.compile(optimizer=['adam', 'sgd'], loss=['binary_crossentropy', 'categorical_crossentropy'], loss_weights=loss_weights) model.fit(['normal', y_train, X_train], {'s': 'ones', 's_fake': 'zeros', 'c': y_train}, nb_epoch=1, batch_size=32) # Plot samples after training. reverse_generator( generator, X_sample, y_sample, 'After training generator')