#!/usr/bin/env python3 import os import sys import numpy as np import random from keras import models from keras import optimizers from keras.layers import Input from keras.optimizers import Adam, Adagrad, Adadelta, Adamax, SGD from keras.callbacks import CSVLogger import scipy import h5py from args import Args from data import denormalize4gan #from layers import bilinear2x from discrimination import MinibatchDiscrimination from nets import build_discriminator, build_gen, build_enc #import tensorflow as tf #import keras #keras.backend.get_session().run(tf.initialize_all_variables()) def sample_faces( faces ): reals = [] for i in range( Args.batch_sz ) : j = random.randrange( len(faces) ) face = faces[ j ] reals.append( face ) reals = np.array(reals) return reals def binary_noise(cnt): # Distribution of noise matters. # If you use single ranf that spans [0, 1], training will not work. # Well, for me at least. # Either normal or ranf works for me but be sure to use them with randrange(2) or something. #noise = np.random.normal( scale=Args.label_noise, size=((Args.batch_sz,) + Args.noise_shape) ) # Note about noise rangel. # 0, 1 noise vs -1, 1 noise. -1, 1 seems to be better and stable. noise = Args.label_noise * np.random.ranf((cnt,) + Args.noise_shape) # [0, 0.1] noise -= 0.05 # [-0.05, 0.05] noise += np.random.randint(0, 2, size=((cnt,) + Args.noise_shape)) noise -= 0.5 noise *= 2 return noise def sample_fake( gen ) : noise = binary_noise(Args.batch_sz) fakes = gen.predict(noise) return fakes, noise def dump_batch(imgs, cnt, ofname): ''' Merges cnt x cnt generated images into one big image. Use the command $ feh dump.png --reload 1 to refresh image peroidically during training! ''' assert Args.batch_sz >= cnt * cnt rows = [] for i in range( cnt ) : cols = [] for j in range(cnt*i, cnt*i+cnt): cols.append( imgs[j] ) rows.append( np.concatenate(cols, axis=1) ) alles = np.concatenate( rows, axis=0 ) alles = denormalize4gan( alles ) #alles = scipy.misc.imresize(alles, 200) # uncomment to scale scipy.misc.imsave( ofname, alles ) def build_networks(): shape = (Args.sz, Args.sz, 3) # Learning rate is important. # Optimizers are important too, try experimenting them yourself to fit your dataset. # I recommend you read DCGAN paper. # Unlike gan hacks, sgd doesn't seem to work well. # DCGAN paper states that they used Adam for both G and D. #opt = optimizers.SGD(lr=0.0001, decay=0.0, momentum=0.9, nesterov=True) #dopt = optimizers.SGD(lr=0.0001, decay=0.0, momentum=0.9, nesterov=True) # lr=0.010. Looks good, statistically (low d loss, higher g loss) # but too much for the G to create face. # If you see only one color 'flood fill' during training for about 10 batches or so, # training is failing. If you see only a few colors (instead of colorful noise) # then lr is too high for the opt and G will not have chance to form face. #dopt = Adam(lr=0.010, beta_1=0.5) #opt = Adam(lr=0.001, beta_1=0.5) # vague faces @ 500 # Still can't get higher frequency component. #dopt = Adam(lr=0.0010, beta_1=0.5) #opt = Adam(lr=0.0001, beta_1=0.5) # better faces @ 500 # but mode collapse after that, probably due to learning rate being too high. # opt.lr = dopt.lr / 10 works nicely. I found this with trial and error. # now same lr, as we are using history to train D multiple times. # I don't exactly understand how decay parameter in Adam works. Certainly not exponential. # Actually faster than exponential, when I look at the code and plot it in Excel. dopt = Adam(lr=0.0002, beta_1=Args.adam_beta) opt = Adam(lr=0.0001, beta_1=Args.adam_beta) # too slow # Another thing about LR. # If you make it small, it will only optimize slowly. # LR only has to be smaller than certain threshold that is data dependent. # (related to the largest gradient that prevents optimization) #dopt = Adam(lr=0.000010, beta_1=0.5) #opt = Adam(lr=0.000001, beta_1=0.5) # generator part gen = build_gen( shape ) # loss function doesn't seem to matter for this one, as it is not directly trained gen.compile(optimizer=opt, loss='binary_crossentropy') gen.summary() # discriminator part disc = build_discriminator( shape ) disc.compile(optimizer=dopt, loss='binary_crossentropy') disc.summary() # GAN stack # https://ctmakro.github.io/site/on_learning/fast_gan_in_keras.html is the faster way. # Here, for simplicity, I use slower way (slower due to duplicate computation). noise = Input( shape=Args.noise_shape ) gened = gen( noise ) result = disc( gened ) gan = models.Model( inputs=noise, outputs=result ) gan.compile(optimizer=opt, loss='binary_crossentropy') gan.summary() return gen, disc, gan def train_autoenc( dataf ): ''' Train an autoencoder first to see if your network is large enough. ''' f = h5py.File( dataf, 'r' ) faces = f.get( 'faces' ) opt = Adam(lr=0.001) shape = (Args.sz, Args.sz, 3) enc = build_enc( shape ) enc.compile(optimizer=opt, loss='mse') enc.summary() # generator part gen = build_gen( shape ) # generator is not directly trained. Optimizer and loss doesn't matter too much. gen.compile(optimizer=opt, loss='mse') gen.summary() face = Input( shape=shape ) vector = enc(face) recons = gen(vector) autoenc = models.Model( inputs=face, outputs=recons ) autoenc.compile(optimizer=opt, loss='mse') epoch = 0 while epoch < 200 : for i in range(10) : reals = sample_faces( faces ) fakes, noises = sample_fake( gen ) loss = autoenc.train_on_batch( reals, reals ) epoch += 1 print(epoch, loss) fakes = autoenc.predict(reals) dump_batch(fakes, 4, "fakes.png") dump_batch(reals, 4, "reals.png") gen.save_weights(Args.genw) enc.save_weights(Args.discw) print("Saved", Args.genw, Args.discw) def load_weights(model, wf): ''' I find error message in load_weights hard to understand sometimes. ''' try: model.load_weights(wf) except: print("failed to load weight, network changed or corrupt hdf5", wf, file=sys.stderr) sys.exit(1) def train_gan( dataf ) : gen, disc, gan = build_networks() # Uncomment these, if you want to continue training from some snapshot. # (or load pretrained generator weights) #load_weights(gen, Args.genw) #load_weights(disc, Args.discw) logger = CSVLogger('loss.csv') # yeah, you can use callbacks independently logger.on_train_begin() # initialize csv file with h5py.File( dataf, 'r' ) as f : faces = f.get( 'faces' ) run_batches(gen, disc, gan, faces, logger, range(5000)) logger.on_train_end() def run_batches(gen, disc, gan, faces, logger, itr_generator): history = [] # need this to prevent G from shifting from mode to mode to trick D. train_disc = True for batch in itr_generator: # Using soft labels here. lbl_fake = Args.label_noise * np.random.ranf(Args.batch_sz) lbl_real = 1 - Args.label_noise * np.random.ranf(Args.batch_sz) fakes, noises = sample_fake( gen ) reals = sample_faces( faces ) # Add noise... # My dataset works without this. #reals += 0.5 * np.exp(-batch/100) * np.random.normal( size=reals.shape ) if batch % 10 == 0 : if len(history) > Args.history_sz: history.pop(0) # evict oldest history.append( (reals, fakes) ) gen.trainable = False #for reals, fakes in history: d_loss1 = disc.train_on_batch( reals, lbl_real ) d_loss0 = disc.train_on_batch( fakes, lbl_fake ) gen.trainable = True #if d_loss1 > 15.0 or d_loss0 > 15.0 : # artificial training of one of G or D based on # statistics is not good at all. # pretrain train discriminator only if batch < 20 : print( batch, "d0:{} d1:{}".format( d_loss0, d_loss1 ) ) continue disc.trainable = False g_loss = gan.train_on_batch( noises, lbl_real ) # try to trick the classifier. disc.trainable = True # To escape this loop, both D and G should be trained so that # D begins to mark everything that's wrong that G has done. # Otherwise G will only change locally and fail to escape the minima. #train_disc = True if g_loss < 15 else False print( batch, "d0:{} d1:{} g:{}".format( d_loss0, d_loss1, g_loss ) ) # save weights every 10 batches if batch % 10 == 0 and batch != 0 : end_of_batch_task(batch, gen, disc, reals, fakes) row = {"d_loss0": d_loss0, "d_loss1": d_loss1, "g_loss": g_loss} logger.on_epoch_end(batch, row) _bits = binary_noise(Args.batch_sz) def end_of_batch_task(batch, gen, disc, reals, fakes): try : # Dump how the generator is doing. # Animation dump dump_batch(reals, 4, "reals.png") dump_batch(fakes, 4, "fakes.png") # to check how noisy the image is frame = gen.predict(_bits) animf = os.path.join(Args.anim_dir, "frame_{:08d}.png".format(int(batch/10))) dump_batch(frame, 4, animf) dump_batch(frame, 4, "frame.png") serial = int(batch / 10) % 10 prefix = os.path.join(Args.snapshot_dir, str(serial) + ".") print("Saving weights", serial) gen.save_weights(prefix + Args.genw) disc.save_weights(prefix + Args.discw) except KeyboardInterrupt : print("Saving, don't interrupt with Ctrl+C!", serial) # recursion to surely save everything haha end_of_batch_task(batch, gen, disc, reals, fakes) raise def generate( genw, cnt ): shape = (Args.sz, Args.sz, 3) gen = build_gen( shape ) gen.compile(optimizer='sgd', loss='mse') load_weights(gen, Args.genw) generated = gen.predict(binary_noise(Args.batch_sz)) # Unoffset, in batch. # Must convert back to unit8 to stop color distortion. generated = denormalize4gan(generated) for i in range(cnt): ofname = "{:04d}.png".format(i) scipy.misc.imsave( ofname, generated[i] ) def main( argv ) : if not os.path.exists(Args.snapshot_dir) : os.mkdir(Args.snapshot_dir) if not os.path.exists(Args.anim_dir) : os.mkdir(Args.anim_dir) # test the capability of generator network through autoencoder test. # The argument is that if the generator network can memorize the inputs then # it should be enough to GAN-generate stuff. # Pretraining gen isn't that useful in gan training as # the untrained discriminator will soon ruin everything. #train_autoenc( "data.hdf5" ) # train GAN with inputs in data.hdf5 train_gan( "data.hdf5" ) # Lets generate stuff #generate( "gen.hdf5", 256 ) if __name__ == '__main__': main(sys.argv)