from __future__ import print_function, division import scipy from keras.datasets import mnist from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add from keras.layers.advanced_activations import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import Adam from keras.utils import to_categorical import datetime import matplotlib.pyplot as plt import sys from data_loader import DataLoader import numpy as np import os class PixelDA(): def __init__(self): # Input shape self.img_rows = 32 self.img_cols = 32 self.channels = 3 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.num_classes = 10 # Configure MNIST and MNIST-M data loader self.data_loader = DataLoader(img_res=(self.img_rows, self.img_cols)) # Loss weights lambda_adv = 10 lambda_clf = 1 # Calculate output shape of D (PatchGAN) patch = int(self.img_rows / 2**4) self.disc_patch = (patch, patch, 1) # Number of residual blocks in the generator self.residual_blocks = 6 optimizer = Adam(0.0002, 0.5) # Number of filters in first layer of discriminator and classifier self.df = 64 self.cf = 64 # Build and compile the discriminators self.discriminator = self.build_discriminator() self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy']) # Build the generator self.generator = self.build_generator() # Build the task (classification) network self.clf = self.build_classifier() # Input images from both domains img_A = Input(shape=self.img_shape) img_B = Input(shape=self.img_shape) # Translate images from domain A to domain B fake_B = self.generator(img_A) # Classify the translated image class_pred = self.clf(fake_B) # For the combined model we will only train the generator and classifier self.discriminator.trainable = False # Discriminator determines validity of translated images valid = self.discriminator(fake_B) self.combined = Model(img_A, [valid, class_pred]) self.combined.compile(loss=['mse', 'categorical_crossentropy'], loss_weights=[lambda_adv, lambda_clf], optimizer=optimizer, metrics=['accuracy']) def build_generator(self): """Resnet Generator""" def residual_block(layer_input): """Residual block described in paper""" d = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer_input) d = BatchNormalization(momentum=0.8)(d) d = Activation('relu')(d) d = Conv2D(64, kernel_size=3, strides=1, padding='same')(d) d = BatchNormalization(momentum=0.8)(d) d = Add()([d, layer_input]) return d # Image input img = Input(shape=self.img_shape) l1 = Conv2D(64, kernel_size=3, padding='same', activation='relu')(img) # Propogate signal through residual blocks r = residual_block(l1) for _ in range(self.residual_blocks - 1): r = residual_block(r) output_img = Conv2D(self.channels, kernel_size=3, padding='same', activation='tanh')(r) return Model(img, output_img) def build_discriminator(self): def d_layer(layer_input, filters, f_size=4, normalization=True): """Discriminator layer""" d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input) d = LeakyReLU(alpha=0.2)(d) if normalization: d = InstanceNormalization()(d) return d img = Input(shape=self.img_shape) d1 = d_layer(img, self.df, normalization=False) d2 = d_layer(d1, self.df*2) d3 = d_layer(d2, self.df*4) d4 = d_layer(d3, self.df*8) validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4) return Model(img, validity) def build_classifier(self): def clf_layer(layer_input, filters, f_size=4, normalization=True): """Classifier layer""" d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input) d = LeakyReLU(alpha=0.2)(d) if normalization: d = InstanceNormalization()(d) return d img = Input(shape=self.img_shape) c1 = clf_layer(img, self.cf, normalization=False) c2 = clf_layer(c1, self.cf*2) c3 = clf_layer(c2, self.cf*4) c4 = clf_layer(c3, self.cf*8) c5 = clf_layer(c4, self.cf*8) class_pred = Dense(self.num_classes, activation='softmax')(Flatten()(c5)) return Model(img, class_pred) def train(self, epochs, batch_size=128, sample_interval=50): half_batch = int(batch_size / 2) # Classification accuracy on 100 last batches of domain B test_accs = [] # Adversarial ground truths valid = np.ones((batch_size, *self.disc_patch)) fake = np.zeros((batch_size, *self.disc_patch)) for epoch in range(epochs): # --------------------- # Train Discriminator # --------------------- imgs_A, labels_A = self.data_loader.load_data(domain="A", batch_size=batch_size) imgs_B, labels_B = self.data_loader.load_data(domain="B", batch_size=batch_size) # Translate images from domain A to domain B fake_B = self.generator.predict(imgs_A) # Train the discriminators (original images = real / translated = Fake) d_loss_real = self.discriminator.train_on_batch(imgs_B, valid) d_loss_fake = self.discriminator.train_on_batch(fake_B, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # -------------------------------- # Train Generator and Classifier # -------------------------------- # One-hot encoding of labels labels_A = to_categorical(labels_A, num_classes=self.num_classes) # Train the generator and classifier g_loss = self.combined.train_on_batch(imgs_A, [valid, labels_A]) #----------------------- # Evaluation (domain B) #----------------------- pred_B = self.clf.predict(imgs_B) test_acc = np.mean(np.argmax(pred_B, axis=1) == labels_B) # Add accuracy to list of last 100 accuracy measurements test_accs.append(test_acc) if len(test_accs) > 100: test_accs.pop(0) # Plot the progress print ( "%d : [D - loss: %.5f, acc: %3d%%], [G - loss: %.5f], [clf - loss: %.5f, acc: %3d%%, test_acc: %3d%% (%3d%%)]" % \ (epoch, d_loss[0], 100*float(d_loss[1]), g_loss[1], g_loss[2], 100*float(g_loss[-1]), 100*float(test_acc), 100*float(np.mean(test_accs)))) # If at save interval => save generated image samples if epoch % sample_interval == 0: self.sample_images(epoch) def sample_images(self, epoch): r, c = 2, 5 imgs_A, _ = self.data_loader.load_data(domain="A", batch_size=5) # Translate images to the other domain fake_B = self.generator.predict(imgs_A) gen_imgs = np.concatenate([imgs_A, fake_B]) # Rescale images 0 - 1 gen_imgs = 0.5 * gen_imgs + 0.5 #titles = ['Original', 'Translated'] fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i,j].imshow(gen_imgs[cnt]) #axs[i, j].set_title(titles[i]) axs[i,j].axis('off') cnt += 1 fig.savefig("images/%d.png" % (epoch)) plt.close() if __name__ == '__main__': gan = PixelDA() gan.train(epochs=30000, batch_size=32, sample_interval=500)