import os import time import tensorflow as tf import numpy as np from tqdm import trange from math import ceil, log from numpy import sin, cos from scipy.misc import imsave NO_REDUCTION = tf.losses.Reduction.NONE class Trainer(): def __init__(self, sess, config, real_images, g_builder, d_builder, cp_builder, zp_builder, coord_handler, patch_handler): self.sess = sess self.config = config self.real_images = real_images self.g_builder = g_builder self.d_builder = d_builder self.cp_builder = cp_builder self.zp_builder = zp_builder self.coord_handler = coord_handler self.patch_handler = patch_handler # Vars for graph building self.batch_size = self.config["train_params"]["batch_size"] self.z_dim = self.config["model_params"]["z_dim"] self.spatial_dim = self.config["model_params"]["spatial_dim"] self.micro_patch_size = self.config["data_params"]["micro_patch_size"] self.macro_patch_size = self.config["data_params"]["macro_patch_size"] self.ratio_macro_to_micro = self.config["data_params"]["ratio_macro_to_micro"] self.ratio_full_to_micro = self.config["data_params"]["ratio_full_to_micro"] self.num_micro_compose_macro = self.config["data_params"]["num_micro_compose_macro"] # Vars for training loop self.exp_name = config["log_params"]["exp_name"] self.epochs = float(self.config["train_params"]["epochs"]) self.num_batches = self.config["data_params"]["num_train_samples"] // self.batch_size self.coordinate_system = self.config["data_params"]["coordinate_system"] self.G_update_period = self.config["train_params"]["G_update_period"] self.D_update_period = self.config["train_params"]["D_update_period"] self.Q_update_period = self.config["train_params"]["Q_update_period"] # Loss weights self.code_loss_w = self.config["loss_params"]["code_loss_w"] self.coord_loss_w = self.config["loss_params"]["coord_loss_w"] self.gp_lambda = self.config["loss_params"]["gp_lambda"] # Extrapolation parameters handling self.train_extrap = self.config["train_params"]["train_extrap"] if self.train_extrap: assert self.config["train_params"]["num_extrap_steps"] is not None assert self.coordinate_system is not "euclidean", \ "I didn't handle extrapolation in {} coordinate system!".format(self.coordinate_system) self.num_extrap_steps = self.config["train_params"]["num_extrap_steps"] else: self.num_extrap_steps = 0 def _train_content_prediction_model(self): return (self.Q_update_period>0) and (self.config["train_params"]["qlr"]>0) def sample_prior(self): return np.random.uniform(-1., 1., [self.batch_size, self.z_dim]).astype(np.float32) def _dup_z_for_macro(self, z): # Duplicate with nearest neighbor, different to `tf.tile`. # E.g., # tensor: [[1, 2], [3, 4]] # repeat: 3 # output: [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]] ch = z.shape[-1] repeat = self.num_micro_compose_macro extend = tf.expand_dims(z, 1) extend_dup = tf.tile(extend, [1, repeat, 1]) return tf.reshape(extend_dup, [-1, ch]) def build_graph(self, test_mode=False): # Input nodes # Note: the input node name was wrong in the checkpoint self.micro_coord_fake = tf.placeholder(tf.float32, [None, self.spatial_dim], name='micro_coord_fake') self.macro_coord_fake = tf.placeholder(tf.float32, [None, self.spatial_dim], name='macro_coord_fake') self.micro_coord_real = tf.placeholder(tf.float32, [None, self.spatial_dim], name='micro_coord_real') self.macro_coord_real = tf.placeholder(tf.float32, [None, self.spatial_dim], name='macro_coord_real') # Reversing angle for cylindrical coordinate is complicated, directly pass values here self.y_angle_ratio = tf.placeholder(tf.float32, [None, 1], name='y_angle_ratio') self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z') # Crop real micro for visualization if self.coordinate_system == "euclidean": self.real_micro = self.patch_handler.crop_micro_from_full_gpu( self.real_images, self.micro_coord_real[:, 0:1], self.micro_coord_real[:, 1:2]) elif self.coordinate_system == "cylindrical": self.real_micro = self.patch_handler.crop_micro_from_full_gpu( self.real_images, self.micro_coord_real[:, 0:1], self.y_angle_ratio) # Real part self.real_macro = self.patch_handler.concat_micro_patches_gpu( self.real_micro, ratio_over_micro=self.ratio_macro_to_micro) (self.disc_real, disc_real_h) = self.d_builder(self.real_macro, self.macro_coord_real, is_training=True) self.c_real_pred = self.cp_builder(disc_real_h, is_training=True) self.z_real_pred = self.zp_builder(disc_real_h, is_training=True) # Fake part z_dup_macro = self._dup_z_for_macro(self.z) self.gen_micro = self.g_builder(z_dup_macro, self.micro_coord_fake, is_training=True) self.gen_macro = self.patch_handler.concat_micro_patches_gpu( self.gen_micro, ratio_over_micro=self.ratio_macro_to_micro) (self.disc_fake, disc_fake_h) = self.d_builder(self.gen_macro, self.macro_coord_fake, is_training=True) self.c_fake_pred = self.cp_builder(disc_fake_h, is_training=True) self.z_fake_pred = self.zp_builder(disc_fake_h, is_training=True) # Testing graph if self.config["log_params"]["merge_micro_patches_in_cpu"]: self.gen_micro_test = self.g_builder(self.z, self.micro_coord_fake, is_training=False) else: (self.gen_micro_test, self.gen_full_test) = self.generate_full_image_gpu(self.z) # Patch-Guided Image Generation graph if self._train_content_prediction_model(): (_, disc_real_h_rec) = self.d_builder(self.real_macro, None, is_training=False) estim_z = self.zp_builder(disc_real_h_rec, is_training=False) # I didn't especially handle this. # if self.config["log_params"]["merge_micro_patches_in_cpu"]: (_, self.rec_full) = self.generate_full_image_gpu(self.z) # Building these are time consuming if not test_mode: print(" [Build] Composing Loss Functions ") self._compose_losses() print(" [Build] Creating Optimizers ") self._create_optimizers() def _calc_gradient_penalty(self): """ Gradient Penalty for patches D """ # This is borrowed from https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb alpha = tf.random_uniform(shape=tf.shape(self.real_macro), minval=0.,maxval=1.) differences = self.gen_macro - self.real_macro # This is different from MAGAN interpolates = self.real_macro + (alpha * differences) disc_inter, _ = self.d_builder(interpolates, None, is_training=True) gradients = tf.gradients(disc_inter, [interpolates])[0] slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1])) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) return gradient_penalty, slopes def _compose_losses(self): # Content consistency loss self.code_loss = tf.reduce_mean(self.code_loss_w * tf.losses.absolute_difference(self.z, self.z_fake_pred)) # Spatial consistency loss (reduce later) self.coord_mse_real = self.coord_loss_w * tf.losses.mean_squared_error(self.macro_coord_real, self.c_real_pred, reduction=NO_REDUCTION) self.coord_mse_fake = self.coord_loss_w * tf.losses.mean_squared_error(self.macro_coord_fake, self.c_fake_pred, reduction=NO_REDUCTION) # (For extrapolation training) Mask-out out-of-bound (OOB) coordinate loss since the gradients are useless if self.train_extrap: upper_bound = tf.ones([self.batch_size, self.spatial_dim], tf.float32) + 1e-4 lower_bound = - upper_bound exceed_upper_bound = tf.greater(self.macro_coord_fake, upper_bound) exceed_lower_bound = tf.less(self.macro_coord_fake, lower_bound) oob_mask_sep = tf.math.logical_or(exceed_upper_bound, exceed_lower_bound) oob_mask_merge = tf.math.logical_or(oob_mask_sep[:, 0], oob_mask_sep[:, 1]) for i in range(2, self.spatial_dim): oob_mask_merge = tf.math.logical_or(oob_mask_merge, oob_mask_sep[:, i]) oob_mask = tf.tile(tf.expand_dims(oob_mask_merge, 1), [1, self.spatial_dim]) self.coord_mse_fake = tf.where(oob_mask, tf.stop_gradient(self.coord_mse_fake), self.coord_mse_fake) self.coord_mse_real = tf.reduce_mean(self.coord_mse_real) self.coord_mse_fake = tf.reduce_mean(self.coord_mse_fake) self.coord_loss = self.coord_mse_real + self.coord_mse_fake # WGAN loss self.adv_real = - tf.reduce_mean(self.disc_real) self.adv_fake = tf.reduce_mean(self.disc_fake) self.d_adv_loss = self.adv_real + self.adv_fake self.g_adv_loss = - self.adv_fake # Gradient penalty loss of WGAN-GP gradient_penalty, self.gp_slopes = self._calc_gradient_penalty() self.gp_loss = self.config["loss_params"]["gp_lambda"] * gradient_penalty # Total loss self.d_loss = self.d_adv_loss + self.gp_loss + self.coord_loss + self.code_loss self.g_loss = self.g_adv_loss + self.coord_loss + self.code_loss self.q_loss = self.g_adv_loss + self.code_loss # Wasserstein distance for visualization self.w_dist = - self.adv_real - self.adv_fake def _create_optimizers(self): t_vars = tf.trainable_variables() d_vars = [var for var in t_vars if 'D' in var.name] g_vars = [var for var in t_vars if 'G' in var.name] q_vars = [var for var in t_vars if 'Q' in var.name] # optimizers G_update_ops = tf.get_collection(self.g_builder.update_collection) D_update_ops = tf.get_collection(self.d_builder.update_collection) Q_update_ops = tf.get_collection(self.zp_builder.update_collection) GD_update_ops = tf.get_collection(self.cp_builder.update_collection) with tf.control_dependencies(G_update_ops + GD_update_ops): self.g_optim = tf.train.AdamOptimizer( self.config["train_params"]["glr"], beta1=self.config["train_params"]["beta1"], beta2=self.config["train_params"]["beta2"], ).minimize(self.g_loss, var_list=g_vars) with tf.control_dependencies(D_update_ops + GD_update_ops): self.d_optim = tf.train.AdamOptimizer( self.config["train_params"]["dlr"], beta1=self.config["train_params"]["beta1"], beta2=self.config["train_params"]["beta2"], ).minimize(self.d_loss, var_list=d_vars) if self._train_content_prediction_model(): with tf.control_dependencies(Q_update_ops): self.q_optim = tf.train.AdamOptimizer( self.config["train_params"]["qlr"], beta1=self.config["train_params"]["beta1"], beta2=self.config["train_params"]["beta2"], ).minimize(self.q_loss, var_list=q_vars) if self.train_extrap: with tf.variable_scope("extrap_optim"): g_vars_partial = [ var for var in g_vars if ("g_resblock_0" in var.name or "g_resblock_1" in var.name)] with tf.control_dependencies(G_update_ops + GD_update_ops): self.g_optim_extrap = tf.train.AdamOptimizer( self.config["train_params"]["glr"], beta1=self.config["train_params"]["beta1"], beta2=self.config["train_params"]["beta2"], ).minimize(self.g_loss, var_list=g_vars_partial) with tf.control_dependencies(D_update_ops + GD_update_ops): self.d_optim_extrap = tf.train.AdamOptimizer( self.config["train_params"]["dlr"], beta1=self.config["train_params"]["beta1"], beta2=self.config["train_params"]["beta2"], ).minimize(self.d_loss, var_list=d_vars) def rand_sample_full_test(self): if self.config["log_params"]["merge_micro_patches_in_cpu"]: z = self.sample_prior() _, full_images = self.generate_full_image_cpu(z) else: full_images = self.sess.run( self.gen_full_test, feed_dict={self.z: self.sample_prior()}) return full_images def generate_full_image_gpu(self, z): all_micro_patches = [] all_micro_coord = [] num_patches_x = self.ratio_full_to_micro[0] + self.num_extrap_steps*2 num_patches_y = self.ratio_full_to_micro[1] + self.num_extrap_steps*2 for yy in range(num_patches_y): for xx in range(num_patches_x): if self.coordinate_system == "euclidean": micro_coord_single = tf.constant([ self.coord_handler.euclidean_coord_int_full_to_float_micro(xx, num_patches_x, extrap_steps=self.num_extrap_steps), self.coord_handler.euclidean_coord_int_full_to_float_micro(yy, num_patches_y, extrap_steps=self.num_extrap_steps), ]) elif self.coordinate_system == "cylindrical": theta_ratio = self.coord_handler.hyperbolic_coord_int_full_to_float_micro(yy, num_patches_y) micro_coord_single = tf.constant([ self.coord_handler.euclidean_coord_int_full_to_float_micro(xx, num_patches_x), self.coord_handler.hyperbolic_theta_to_euclidean(theta_ratio, proj_func=cos), self.coord_handler.hyperbolic_theta_to_euclidean(theta_ratio, proj_func=sin), ]) micro_coord = tf.tile(tf.expand_dims(micro_coord_single, 0), [tf.shape(z)[0], 1]) generated_patch = self.g_builder(z, micro_coord, is_training=False) all_micro_patches.append(generated_patch) all_micro_coord.append(micro_coord) num_patches = num_patches_x * num_patches_y all_micro_patches = tf.concat(all_micro_patches, 0) all_micro_patches_reord = self.patch_handler.reord_patches_gpu(all_micro_patches, self.batch_size, num_patches) full_image = self.patch_handler.concat_micro_patches_gpu( all_micro_patches_reord, ratio_over_micro=[num_patches_x, num_patches_y]) return all_micro_patches, full_image def generate_full_image_cpu(self, z): all_micro_patches = [] all_micro_coord = [] num_patches_x = self.ratio_full_to_micro[0] + self.num_extrap_steps * 2 num_patches_y = self.ratio_full_to_micro[1] + self.num_extrap_steps * 2 for yy in range(num_patches_y): for xx in range(num_patches_x): if self.coordinate_system == "euclidean": micro_coord_single = np.array([ self.coord_handler.euclidean_coord_int_full_to_float_micro(xx, num_patches_x, extrap_steps=self.num_extrap_steps), self.coord_handler.euclidean_coord_int_full_to_float_micro(yy, num_patches_y, extrap_steps=self.num_extrap_steps), ]) elif self.coordinate_system == "cylindrical": theta_ratio = self.coord_handler.hyperbolic_coord_int_full_to_float_micro(yy, num_patches_y) micro_coord_single = np.array([ self.coord_handler.euclidean_coord_int_full_to_float_micro(xx, num_patches_x), self.coord_handler.hyperbolic_theta_to_euclidean(theta_ratio, proj_func=cos), self.coord_handler.hyperbolic_theta_to_euclidean(theta_ratio, proj_func=sin), ]) micro_coord = np.tile(np.expand_dims(micro_coord_single, 0), [z.shape[0], 1]) generated_patch = self.sess.run( self.gen_micro_test, feed_dict={self.z: z, self.micro_coord_fake: micro_coord}) # TODO all_micro_patches.append(generated_patch) all_micro_coord.append(micro_coord) num_patches = num_patches_x * num_patches_y all_micro_patches = np.concatenate(all_micro_patches, 0) all_micro_patches_reord = self.patch_handler.reord_patches_cpu(all_micro_patches, self.batch_size, num_patches) full_image = self.patch_handler.concat_micro_patches_cpu( all_micro_patches_reord, ratio_over_micro=[num_patches_x, num_patches_y]) return all_micro_patches, full_image def test(self, n_samples, output_dir): n_digits = ceil(log(n_samples, 10)) n_iters = n_samples // self.batch_size + 1 if not os.path.exists(output_dir): os.makedirs(output_dir) for i in trange(n_iters): images = self.rand_sample_full_test() for j in range(images.shape[0]): global_id = i*self.batch_size + j if global_id < n_samples: output_path = os.path.join(output_dir, "test_sample_{}.png".format(str(global_id).zfill(n_digits))) imsave(output_path, images[j]) def train(self, logger, evaluator, global_step): start_time = time.time() g_loss, d_loss, q_loss = 0, 0, 0 z_fixed = self.sample_prior() cur_epoch = int(global_step / self.num_batches) cur_iter = global_step - cur_epoch * self.num_batches while cur_epoch < self.epochs: while cur_iter < self.num_batches: # Create data z_iter = self.sample_prior() macro_coord, micro_coord, y_angle_ratio = self.coord_handler.sample_coord() feed_dict_iter = { self.micro_coord_real: micro_coord, self.macro_coord_real: macro_coord, self.micro_coord_fake: micro_coord, self.macro_coord_fake: macro_coord, self.y_angle_ratio: y_angle_ratio, self.z: z_iter, } feed_dict_fixed = { self.micro_coord_real: micro_coord, self.macro_coord_real: macro_coord, self.micro_coord_fake: micro_coord, self.macro_coord_fake: macro_coord, self.y_angle_ratio: y_angle_ratio, self.z: z_fixed, } # Optimize if (global_step % self.D_update_period) == 0: _, d_summary_str, d_loss = self.sess.run( [self.d_optim, logger.d_summaries, self.d_loss], feed_dict=feed_dict_iter) if (global_step % self.G_update_period) == 0: _, g_summary_str, g_loss = self.sess.run( [self.g_optim, logger.g_summaries, self.g_loss], feed_dict=feed_dict_iter) if self.train_extrap: macro_coord_extrap, micro_coord_extrap, _ = \ self.coord_handler.sample_coord(num_extrap_steps=self.num_extrap_steps) # Override logging inputs as well feed_dict_fixed[self.micro_coord_fake] = micro_coord_extrap feed_dict_fixed[self.macro_coord_fake] = macro_coord_extrap feed_dict_iter[self.micro_coord_fake] = micro_coord_extrap feed_dict_iter[self.macro_coord_fake] = macro_coord_extrap if (global_step % self.D_update_period) == 0: _, d_summary_str, d_loss = self.sess.run( [self.d_optim_extrap, logger.d_summaries, self.d_loss], feed_dict=feed_dict_iter) if (global_step % self.G_update_period) == 0: _, g_summary_str, g_loss = self.sess.run( [self.g_optim_extrap, logger.g_summaries, self.g_loss], feed_dict=feed_dict_iter) if self._train_content_prediction_model() and (global_step % self.Q_update_period) == 0: _, q_loss = self.sess.run( [self.q_optim, self.q_loss], feed_dict=feed_dict_iter) # Log time_elapsed = time.time() - start_time print("[{}] [Epoch: {}; {:4d}/{:4d}; global_step:{}] elapsed: {:.4f}, d: {:.4f}, g: {:.4f}, q: {:.4f}".format( self.exp_name, cur_epoch, cur_iter, self.num_batches, global_step, time_elapsed, d_loss, g_loss, q_loss)) logger.log_iter(self, evaluator, cur_epoch, cur_iter, global_step, g_summary_str, d_summary_str, z_iter, z_fixed, feed_dict_iter, feed_dict_fixed) cur_iter += 1 global_step += 1 cur_epoch += 1 cur_iter = 0