from __future__ import print_function, division

from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dropout, Concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
from data_loader import DataLoader
import numpy as np
import os


class DiscoGAN():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'edges2shoes'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Calculate output shape of D (PatchGAN): four stride-2 convolutions
        # downsample the 128x128 input to an 8x8 grid of patch predictions
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 64
        self.df = 64

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse',
                         optimizer=optimizer,
                         metrics=['accuracy'])
        self.d_B.compile(loss='mse',
                         optimizer=optimizer,
                         metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        # Graph of Generators
        #-------------------------

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to the original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Objectives
        #   + Adversarial: Fool domain discriminators
        #   + Translation: Minimize MAE between e.g. fake B and true B
        #   + Cycle-consistency: Minimize MAE between reconstructed images and original
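        # The six model outputs below line up one-to-one with the six losses
        # passed to compile():
        #   (valid_A, valid_B)       -> 'mse'  adversarial targets
        #   (fake_B, fake_A)         -> 'mae'  translation targets (paired data)
        #   (reconstr_A, reconstr_B) -> 'mae'  cycle-consistency targets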
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       fake_B, fake_A,
                                       reconstr_A, reconstr_B])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              optimizer=optimizer)

    def build_generator(self):
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, normalize=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalize:
                d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            # Skip connection from the mirrored downsampling layer
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, normalize=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, self.gf*8)
        u2 = deconv2d(u1, d5, self.gf*8)
        u3 = deconv2d(u2, d4, self.gf*8)
        u4 = deconv2d(u3, d3, self.gf*4)
        u5 = deconv2d(u4, d2, self.gf*2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1,
                            padding='same', activation='tanh')(u7)

        return Model(d0, output_img)

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        # One validity score per patch (PatchGAN)
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to the opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Train the generators
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_B, imgs_A,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                print("[%d] [%d/%d] time: %s, [d_loss: %f, g_loss: %f]"
                      % (epoch, batch_i, self.data_loader.n_batches,
                         elapsed_time, d_loss[0], g_loss[0]))
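                # Note: each d_loss is [loss, accuracy] because the
                # discriminators compile with an accuracy metric, while
                # g_loss holds the combined model's total loss followed by
                # its six per-output losses, hence the [0] indexing above.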
                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A, imgs_B = self.data_loader.load_data(batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to the original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A,
                                   imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()


if __name__ == '__main__':
    gan = DiscoGAN()
    gan.train(epochs=20, batch_size=1, sample_interval=200)
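# A minimal inference sketch (not part of the original script): after training,
# either generator can be reused on its own for one-off translation. The
# DataLoader call below mirrors how sample_images() already uses it.
#
#   gan = DiscoGAN()
#   gan.train(epochs=20, batch_size=1, sample_interval=200)
#   imgs_A, _ = gan.data_loader.load_data(batch_size=1, is_testing=True)
#   fake_B = gan.g_AB.predict(imgs_A)     # translate domain A -> domain B
#   img = 0.5 * fake_B[0] + 0.5           # rescale tanh output to [0, 1]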