"""Train an autoencoder together with a probability classifier (used to estimate the bit cost H).

Run as a script, see main() for the command line flags.
"""
import fasteners
import tensorflow as tf
from fjcommon import tf_helpers
from fjcommon import config_parser
from fjcommon import functools_ext
import time
import os
import subprocess
import argparse
from constants import NUM_PREPROCESS_THREADS, NUM_CROPS_PER_IMG
from collections import namedtuple
from restore_manager import RestoreManager
from saver import Saver
import inputpipeline
import training_helpers
import probclass
import autoencoder
import logdir_helpers
import ms_ssim
import bits
from logger import Logger
import sheets_logger
import numpy as np
from codec_distance import CodecDistance, CodecDistanceReadException

# Enable TF logging output
tf.logging.set_verbosity(tf.logging.INFO)


_LOG_DIR_FORMAT = """
- LOG DIR ----------------------------------------------------------------------
{}
--------------------------------------------------------------------------------"""

_STARTING_TRAINING_INFO_STR = """
- STARTING TRAINING ------------------------------------------------------------"""


_MAX_METADATA_RUNS = 1  # if --log_run_metadata is given, how many times should run metadata be logged
_EPS = 1e-5


# Intervals are in training iterations (global steps).
TrainFlags = namedtuple(
    'TrainFlags',
    ['log_run_metadata', 'log_interval_train', 'log_interval_test', 'log_interval_save', 'summarize_grads'])


# Dataset names: `train` and `test` are passed to inputpipeline.get_dataset, `codec_distance` to CodecDistance.
Datasets = namedtuple('Datasets', ['train', 'test', 'codec_distance'])


def train(autoencoder_config_path, probclass_config_path,
          restore_manager: RestoreManager,
          log_dir_root,
          datasets: Datasets,
          train_flags: TrainFlags,
          ckpt_interval_hours: float,
          description: str):
    """
    Build the training and test graphs for the autoencoder and the probability classifier, start a session
    with input queues, and run train_loop.

    Note that (train_autoencoder=True, train_probclass=False) => probclass is still used to calculate H.

    :param autoencoder_config_path: path of the autoencoder config, parsed with config_parser.parse
    :param probclass_config_path: path of the probability classifier config, parsed with config_parser.parse
    :param restore_manager: if not None, variables are restored with restore_manager.restore(sess) instead of
        being initialized; if restore_manager.continue_in_ckpt_dir is set, restore_manager.log_dir is reused.
    :param log_dir_root: root under which a new unique log dir is created (unless continuing in a ckpt dir)
    :param datasets: names of the train / test datasets and of the codec distance dataset
    :param train_flags: logging / saving intervals, see TrainFlags
    :param ckpt_interval_hours: forwarded to Saver as keep_checkpoint_every_n_hours
    :param description: if truthy, a row describing this run is written to Google Sheets
    """
    ae_config, ae_config_rel_path = config_parser.parse(autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)
    print_configs(('ae_config', ae_config), ('pc_config', pc_config))

    continue_in_ckpt_dir = restore_manager and restore_manager.continue_in_ckpt_dir
    if continue_in_ckpt_dir:
        logdir = restore_manager.log_dir
    else:
        logdir = logdir_helpers.create_unique_log_dir(
            [ae_config_rel_path, pc_config_rel_path], log_dir_root,
            restore_dir=restore_manager.ckpt_dir if restore_manager else None)
    print(_LOG_DIR_FORMAT.format(logdir))

    if description:
        _write_to_sheets(logdir_helpers.log_date_from_log_dir(logdir),
                         ae_config_rel_path, pc_config_rel_path,
                         description,
                         git_ref=_get_git_ref(),
                         log_dir_root=log_dir_root,
                         is_continue=continue_in_ckpt_dir)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # train ---
    ip_train = inputpipeline.InputPipeline(
        inputpipeline.get_dataset(datasets.train),
        ae_config.crop_size, batch_size=ae_config.batch_size,
        shuffle=False,
        num_preprocess_threads=NUM_PREPROCESS_THREADS, num_crops_per_img=NUM_CROPS_PER_IMG)
    x_train = ip_train.get_batch()

    enc_out_train = ae.encode(x_train, is_training=True)  # qbar is masked by the heatmap
    x_out_train = ae.decode(enc_out_train.qbar, is_training=True)
    # stop_gradient is beneficial for training. it prevents multiple gradients flowing into the heatmap.
    pc_in = tf.stop_gradient(enc_out_train.qbar)
    bc_train = pc.bitcost(pc_in, enc_out_train.symbols, is_training=True, pad_value=pc.auto_pad_value(ae))
    bpp_train = bits.bitcost_to_bpp(bc_train, x_train)
    d_train = Distortions(ae_config, x_train, x_out_train, is_training=True)

    # loss ---
    total_loss, H_real, pc_comps, ae_comps = get_loss(
        ae_config, ae, pc, d_train.d_loss_scaled, bc_train, enc_out_train.heatmap)
    train_op = get_train_op(ae_config, pc_config, ip_train, pc.variables(), total_loss)

    # test ---
    with tf.name_scope('test'):
        ip_test = inputpipeline.InputPipeline(
            inputpipeline.get_dataset(datasets.test),
            ae_config.crop_size, batch_size=ae_config.batch_size,
            num_preprocess_threads=NUM_PREPROCESS_THREADS, num_crops_per_img=1,
            big_queues=False, shuffle=False)
        x_test = ip_test.get_batch()
        # At test time the hard-quantized representation (qhard) is used for decoding and bit cost.
        enc_out_test = ae.encode(x_test, is_training=False)
        x_out_test = ae.decode(enc_out_test.qhard, is_training=False)
        bc_test = pc.bitcost(enc_out_test.qhard, enc_out_test.symbols,
                             is_training=False, pad_value=pc.auto_pad_value(ae))
        bpp_test = bits.bitcost_to_bpp(bc_test, x_test)
        d_test = Distortions(ae_config, x_test, x_out_test, is_training=False)

        # summing over channel dimension gives 2D heatmap
        heatmap2D = (tf.reduce_sum(enc_out_test.heatmap, 1) if enc_out_test.heatmap is not None
                     else None)

    d_BPG_test = _get_d_BPG_tensor(datasets.codec_distance, bpp_test, d_test.ms_ssim)

    # ---
    train_logger = Logger()
    test_logger = Logger()

    train_logger.add_summaries(d_train.summaries_with_prefix('train'))
    # Visualize components of losses
    train_logger.add_summaries([
        tf.summary.scalar('train/PC_loss/{}'.format(name), comp) for name, comp in pc_comps])
    train_logger.add_summaries([
        tf.summary.scalar('train/AE_loss/{}'.format(name), comp) for name, comp in ae_comps])
    train_logger.add_summaries([tf.summary.scalar('train/bpp', bpp_train)])
    train_logger.add_console_tensor('loss={:.3f}', total_loss)
    # MS-SSIM is None unless it is the distortion being minimized (see Distortions); a None fetch would make
    # the logger's session run fail.
    if d_train.ms_ssim is not None:
        train_logger.add_console_tensor('ms_ssim={:.3f}', d_train.ms_ssim)
    train_logger.add_console_tensor('bpp={:.3f}', bpp_train)
    train_logger.add_console_tensor('H_real={:.3f}', H_real)

    test_logger.add_summaries(d_test.summaries_with_prefix('test'))
    test_logger.add_summaries([
        tf.summary.scalar('test/bpp', bpp_test),
        tf.summary.scalar('test/distance_BPG_MS-SSIM', d_BPG_test),
        tf.summary.image('test/x_in', prep_for_image_summary(x_test, n=3, name='x_in')),
        tf.summary.image('test/x_out', prep_for_image_summary(x_out_test, n=3, name='x_out'))])
    if heatmap2D is not None:
        test_logger.add_summaries([
            tf.summary.image('test/hm', prep_for_grayscale_image_summary(heatmap2D, n=3, autoscale=True,
                                                                          name='hm'))])

    if d_test.ms_ssim is not None:
        test_logger.add_console_tensor('ms_ssim={:.3f}', d_test.ms_ssim)
    test_logger.add_console_tensor('bpp={:.3f}', bpp_test)
    test_logger.add_summaries([
        tf.summary.histogram('centers', ae.get_centers_variable()),
        tf.summary.histogram('test/qbar', enc_out_test.qbar[:ae_config.batch_size // 2, ...])])
    test_logger.add_console_tensor('d_BPG={:.6f}', d_BPG_test)
    test_logger.add_console_tensor(Logger.Numpy1DFormatter('centers={}'), ae.get_centers_variable())

    print('Starting session and queues...')
    # Variables are only initialized here when nothing is restored; otherwise restore_manager restores them.
    with tf_helpers.start_queues_in_sess(init_vars=restore_manager is None) as (sess, coord):
        train_logger.finalize_with_sess(sess)
        test_logger.finalize_with_sess(sess)

        if restore_manager:
            restore_manager.restore(sess)

        saver = Saver(Saver.ckpt_dir_for_log_dir(logdir),
                      max_to_keep=1,
                      keep_checkpoint_every_n_hours=ckpt_interval_hours)

        train_loop(ae_config, sess, coord, train_op, train_logger, test_logger,
                   train_flags, logdir, saver, is_restored=restore_manager is not None)


def _get_d_BPG_tensor(codec_distance_dataset, bpp, ms_ssim_value):
    """
    Return a scalar float32 tensor with the BPG MS-SSIM codec distance of (bpp, ms_ssim_value), or a constant
    NaN tensor if it cannot be computed.

    :param codec_distance_dataset: dataset name passed to CodecDistance
    :param bpp: scalar bpp tensor
    :param ms_ssim_value: scalar MS-SSIM tensor, or None if MS-SSIM was not computed (see Distortions)
    """
    if ms_ssim_value is None:
        # tf.py_func cannot take None as an input, so fall back to NaN instead of failing graph construction.
        print('Cannot compute CodecDistance: MS-SSIM is not computed for this distortion_to_minimize')
        return tf.constant(np.nan, shape=(), name='ConstNaN')
    try:  # Try to get codec distance for current dataset
        codec_distance_ms_ssim = CodecDistance(codec_distance_dataset, codec='bpg', metric='ms-ssim')
        # ValueErrors raised by .distance are turned into NaN via functools_ext.catcher.
        get_distance = functools_ext.catcher(
            ValueError, handler=functools_ext.const(np.nan), f=codec_distance_ms_ssim.distance)
        get_distance = functools_ext.compose(np.float32, get_distance)  # cast to float32
        d_BPG = tf.py_func(get_distance, [bpp, ms_ssim_value],
                           tf.float32, stateful=False, name='d_BPG')
        d_BPG.set_shape(())  # py_func output has unknown shape otherwise
        return d_BPG
    except CodecDistanceReadException as e:
        print('Cannot compute CodecDistance: {}'.format(e))
        return tf.constant(np.nan, shape=(), name='ConstNaN')


def print_configs(*configs_with_names):
    """Print each (name, config) pair, separated by '---' lines."""
    print('\n---\n'.join('Using {}:\n{}'.format(name, config) for name, config in configs_with_names))


class _Timer(object):
    """
    Estimates training throughput. get_avg_ex_per_sec assumes that exactly `log_interval` steps of
    `batch_size` examples each have run since construction or the last reset().
    """

    def __init__(self, log_interval, batch_size):
        self.log_interval = log_interval
        self.batch_size = batch_size
        self.start_time = time.time()

    def get_avg_ex_per_sec(self):
        avg_time_per_step = (time.time() - self.start_time) / self.log_interval
        avg_ex_per_sec = self.batch_size / avg_time_per_step
        return avg_ex_per_sec

    def reset(self):
        self.start_time = time.time()


def train_loop(
        config,
        sess, coord,
        train_op,
        train_logger: Logger,
        test_logger: Logger,
        train_flags: TrainFlags,
        log_dir, saver: Saver,
        is_restored=False):
    """
    Run train_op until coord.should_stop(), logging train / test summaries and saving checkpoints at the
    intervals given by train_flags.

    :param config: autoencoder config; only config.batch_size is read here (for the img/s estimate)
    :param is_restored: if True, start at the restored global_step and log train and test once before training
    """
    global_step = tf.train.get_or_create_global_step()

    job_id = logdir_helpers.log_date_from_log_dir(log_dir)
    fw = tf.summary.FileWriter(log_dir, graph=sess.graph)

    training_timer = _Timer(train_flags.log_interval_train, config.batch_size)

    itr = 0
    num_metadata_runs = 0

    # Run metadata is logged on steps where itr % (log_interval_train - 1) == 0. `or 1` maps the divisor 0
    # (log_interval_train == 1) to 1, which would otherwise raise ZeroDivisionError; other values are unchanged.
    metadata_interval = (train_flags.log_interval_train - 1) or 1

    if is_restored:
        itr = sess.run(global_step)
        train_logger.log().to_tensorboard(fw, itr).to_console(itr, append='Restored')
        test_logger.log().to_tensorboard(fw, itr).to_console(itr)

    print(_STARTING_TRAINING_INFO_STR)
    while not coord.should_stop():
        if (train_flags.log_run_metadata and
                num_metadata_runs < _MAX_METADATA_RUNS and
                (itr % metadata_interval == 0)):
            print('Logging run metadata...', end=' ')
            num_metadata_runs += 1
            (_, itr), run_metadata = run_and_fetch_metadata([train_op, global_step], sess)
            fw.add_run_metadata(run_metadata, str(itr), itr)
            print('Done')
        else:
            _, itr = sess.run([train_op, global_step])

        # Train Logging --
        if itr % train_flags.log_interval_train == 0:
            info_str = '(img/s: {:.1f}) {}'.format(training_timer.get_avg_ex_per_sec(), job_id)
            train_logger.log().to_tensorboard(fw, itr).to_console(itr, append=info_str)

        # Save --
        if itr % train_flags.log_interval_save == 0:
            print('Saving...')
            saver.save(sess, global_step)

        # Test Logging -- (a non-positive log_interval_test disables testing)
        if train_flags.log_interval_test > 0 and itr % train_flags.log_interval_test == 0:
            test_logger.log().to_tensorboard(fw, itr).to_console(itr)

        if itr % train_flags.log_interval_train == 0:  # Reset after all above for accurate timings
            training_timer.reset()


def run_and_fetch_metadata(fetches, sess):
    """Run `fetches` with a FULL_TRACE and return (fetched values, tf.RunMetadata)."""
    print('*** Adding metadata...')
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    return sess.run(fetches, options=run_options, run_metadata=run_metadata), run_metadata


def prep_for_image_summary(t, n=3, autoscale=False, name='img'):
    """ given tensor t of shape NCHW, return t[:n, ...] transposed to NHWC, cast to uint8 """
    assert int(t.shape[1]) == 3, 'Expected N3HW, got {}'.format(t)
    with tf.name_scope('prep_' + name):
        t = tf_helpers.transpose_NCHW_to_NHWC(t[:n, ...])
        if autoscale:  # if t is float32, tf.summary.image will automatically rescale
            assert tf.float32.is_compatible_with(t.dtype)
            return t
        else:  # if t is uint8, tf.summary.image will NOT automatically rescale
            return tf.cast(t, tf.uint8, name='uint8')


def prep_for_grayscale_image_summary(t, n=3, autoscale=False, name='img'):
    """ given tensor t of shape NHW, return t[:n, ...] with a trailing channel axis (NHW1), see
    prep_for_image_summary for the meaning of autoscale """
    assert len(t.shape) == 3
    with tf.name_scope('prep_' + name):
        t = t[:n, ...]
        t = tf.expand_dims(t, -1)  # NHW1
        if autoscale:
            assert tf.float32.is_compatible_with(t.dtype)
            return t
        else:
            return tf.cast(t, tf.uint8, name='uint8')


def get_loss(config, ae, pc, d_loss_scaled, bc, heatmap):
    """
    Build the total training loss.

    total_loss = d_loss_scaled + beta * max(H_soft - H_target, 0) + regularization, where H_soft is the average
    of the mean bit cost with and without heatmap masking.

    :return: total_loss, H_real, pc_comps, ae_comps -- the *_comps are lists of (name, tensor) pairs that are
        only used for summaries
    """
    assert config.H_target

    heatmap_enabled = heatmap is not None

    with tf.name_scope('losses'):
        bc_mask = (bc * heatmap) if heatmap_enabled else bc
        H_real = tf.reduce_mean(bc, name='H_real')
        H_mask = tf.reduce_mean(bc_mask, name='H_mask')
        H_soft = 0.5 * (H_mask + H_real)

        H_target = tf.constant(config.H_target, tf.float32, name='H_target')
        beta = tf.constant(config.beta, tf.float32, name='beta')

        # Only penalize the bit cost once it exceeds the target.
        pc_loss = beta * tf.maximum(H_soft - H_target, 0)

    # Adding Regularizers
    with tf.name_scope('regularization_losses'):
        reg_probclass = pc.regularization_loss()
        if reg_probclass is None:
            reg_probclass = 0
        reg_enc = ae.encoder_regularization_loss()
        reg_dec = ae.decoder_regularization_loss()
        reg_loss = reg_probclass + reg_enc + reg_dec

    pc_comps = [('H_mask', H_mask),
                ('H_real', H_real),
                ('pc_loss', pc_loss),
                ('reg', reg_probclass)]
    ae_comps = [('d_loss_scaled', d_loss_scaled),
                ('reg_enc_dec', reg_enc + reg_dec)]

    total_loss = d_loss_scaled + pc_loss + reg_loss
    return total_loss, H_real, pc_comps, ae_comps


def get_train_op(ae_config, pc_config, input_pipeline_train, vars_probclass, total_loss):
    """
    Create a train op minimizing total_loss, using the optimizer / learning rate of pc_config for
    vars_probclass and the ones of ae_config for all other variables.
    """
    lr_ae = training_helpers.create_learning_rate_tensor(ae_config, input_pipeline_train, name='lr_ae')
    default_optimizer = training_helpers.create_optimizer(ae_config, lr_ae, name='Adam_AE')

    lr_pc = training_helpers.create_learning_rate_tensor(pc_config, input_pipeline_train, name='lr_pc')
    optimizer_pc = training_helpers.create_optimizer(pc_config, lr_pc, name='Adam_PC')

    special_optimizers_and_vars = [(optimizer_pc, vars_probclass)]

    return tf_helpers.create_train_op_with_different_lrs(
        total_loss, default_optimizer, special_optimizers_and_vars, summarize_gradients=False)


class Distortions(object):
    """
    Distortion metrics between x and x_out (both float32, NCHW).

    Attributes: mse, psnr (batch means), ms_ssim (None unless config.distortion_to_minimize == 'ms_ssim'),
    and d_loss_scaled, the value to minimize in training.
    """

    def __init__(self, config, x, x_out, is_training):
        assert tf.float32.is_compatible_with(x.dtype) and tf.float32.is_compatible_with(x_out.dtype)
        self.config = config

        with tf.name_scope('distortions_train' if is_training else 'distortions_test'):
            minimize_for = config.distortion_to_minimize
            assert minimize_for in ('mse', 'psnr', 'ms_ssim')
            # don't calculate MS-SSIM if not necessary to speed things up
            should_get_ms_ssim = minimize_for == 'ms_ssim'
            # if we don't minimize for PSNR, cast x and x_out to int before calculating the PSNR, because otherwise
            # PSNR is off. If not training, always cast to int, because we don't need the gradients.
            # equivalent for when we don't minimize for MSE
            cast_to_int_for_psnr = (not is_training) or minimize_for != 'psnr'
            cast_to_int_for_mse = (not is_training) or minimize_for != 'mse'
            self.mse = self.mean_over_batch(
                Distortions.get_mse_per_img(x, x_out, cast_to_int_for_mse), name='mse')
            self.psnr = self.mean_over_batch(
                Distortions.get_psnr_per_image(x, x_out, cast_to_int_for_psnr), name='psnr')
            self.ms_ssim = (
                Distortions.get_ms_ssim(x, x_out)
                if should_get_ms_ssim else None)

            with tf.name_scope('distortion_to_minimize'):
                self.d_loss_scaled = self._get_distortion_to_minimize(minimize_for)

    def summaries_with_prefix(self, prefix):
        """Scalar summaries '<prefix>/mse', '<prefix>/psnr' and, if MS-SSIM was computed, '<prefix>/ms_ssim'."""
        return tf_helpers.list_without_None(
            tf.summary.scalar(prefix + '/mse', self.mse),
            tf.summary.scalar(prefix + '/psnr', self.psnr),
            tf.summary.scalar(prefix + '/ms_ssim', self.ms_ssim) if self.ms_ssim is not None else None)

    def _get_distortion_to_minimize(self, minimize_for):
        """ Returns a float32 that should be minimized in training. MSE is returned as is. PSNR and MS-SSIM
        increase for a decrease in distortion, so they are turned into K_psnr - psnr and
        K_ms_ssim * (1 - ms_ssim), respectively. """
        if minimize_for == 'mse':
            return self.mse
        if minimize_for == 'psnr':
            return self.config.K_psnr - self.psnr
        if minimize_for == 'ms_ssim':
            return self.config.K_ms_ssim * (1 - self.ms_ssim)

        raise ValueError('Invalid: {}'.format(minimize_for))

    @staticmethod
    def mean_over_batch(d, name):
        """Mean of a per-image tensor of shape (N,)."""
        assert len(d.shape) == 1, 'Expected tensor of shape (N,), got {}'.format(d.shape)
        with tf.name_scope('mean_' + name):
            return tf.reduce_mean(d, name='mean')

    @staticmethod
    def get_mse_per_img(inp, otp, cast_to_int):
        """
        :param inp: NCHW
        :param otp: NCHW
        :param cast_to_int: if True, both inp and otp are casted to int32 before the error is calculated,
        to ensure real world errors (image pixels are always quantized). But the error is always casted back to
        float32 before a mean per image is calculated and returned
        :return: float32 tensor of shape (N,)
        """
        with tf.name_scope('mse_{}'.format('int' if cast_to_int else 'float')):
            if cast_to_int:
                # Values are expected to be in 0...255, i.e., uint8, but tf.square does not support uint8's
                inp, otp = tf.cast(inp, tf.int32), tf.cast(otp, tf.int32)
            squared_error = tf.square(otp - inp)
            squared_error_float = tf.to_float(squared_error)
            mse_per_image = tf.reduce_mean(squared_error_float, axis=[1, 2, 3])
            return mse_per_image

    @staticmethod
    def get_psnr_per_image(inp, otp, cast_to_int):
        """PSNR per image for a peak value of 255, i.e., 10 * log10(255^2 / mse). Shape (N,)."""
        with tf.name_scope('psnr_{}'.format('int' if cast_to_int else 'float')):
            mse_per_image = Distortions.get_mse_per_img(inp, otp, cast_to_int)
            psnr_per_image = 10 * tf_helpers.log10(255.0 * 255.0 / mse_per_image)
            return psnr_per_image

    @staticmethod
    def get_ms_ssim(inp, otp):
        with tf.name_scope('mean_MS_SSIM'):
            return ms_ssim.MultiScaleSSIM(inp, otp, data_format='NCHW', name='MS-SSIM')


def _print_trainable_variables():
    """Debug helper: print trainable variables and trainable resource variables."""
    print('*** tf.trainable_variables:')
    for v in tf.trainable_variables():
        print(v)
    print('*** TRAINABLE_RESOURCE_VARIABLES:')
    for v in tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES):
        print(v)


def _write_to_sheets(log_date, ae_config_rel_path, pc_config_rel_path,
                     description, git_ref,
                     log_dir_root, is_continue):
    """
    Insert a row describing this run into Google Sheets, holding an inter-process lock while doing so.
    Access failures are printed, not raised, so they do not stop training.
    """
    try:
        with fasteners.InterProcessLock(sheets_logger.get_lock_file_p()):
            sheets_logger.insert_row(
                log_date + ('c' if is_continue else ''),  # 'c' marks a run continued in an existing ckpt dir
                os.environ.get('JOB_ID', 'N/A'),
                ae_config_rel_path, pc_config_rel_path, description, '',
                git_ref,
                log_dir_root)
    except sheets_logger.GoogleSheetsAccessFailedException as e:
        print(e)


def _get_git_ref():
    """
    :return: '' if $QSUBA_GIT_REF is not set; $QSUBA_GIT_REF itself if it contains 'tags' or if `git rev-parse
        HEAD` fails; otherwise '$QSUBA_GIT_REF (<first 16 chars of the HEAD commit>)'
    """
    qsuba_git_ref = os.environ.get('QSUBA_GIT_REF')
    if qsuba_git_ref is None:
        return ''
    if 'tags' in qsuba_git_ref:
        return qsuba_git_ref
    try:
        git_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode()
    except (subprocess.CalledProcessError, OSError) as e:
        # Not in a git repo / git not installed: this is only metadata, so do not abort training.
        print('Cannot get git commit: {}'.format(e))
        return qsuba_git_ref
    return '{} ({})'.format(qsuba_git_ref, git_commit[:16])


def main():
    """Parse command line flags and start training."""
    p = argparse.ArgumentParser()
    p.add_argument('autoencoder_config_path')
    p.add_argument('probclass_config_path')
    p.add_argument('--dataset_train', '-dtrain', default='imgnet_train', help=inputpipeline.get_dataset.__doc__)
    p.add_argument('--dataset_test', '-dtest', default='imgnet_test', help=inputpipeline.get_dataset.__doc__)
    p.add_argument('--dataset_codec_distance', '-dcodec', default='testset', help='See codec_distance.py')
    p.add_argument('--log_dir_root', '-o', default='logs', metavar='LOG_DIR_ROOT')
    p.add_argument('--log_interval_train', '-ltrain', type=int, default=100)
    p.add_argument('--log_interval_save', '-lsave', type=int, default=1000)
    p.add_argument('--log_interval_test', '-ltest', type=int, default=1000,
                   help='Set to -1 to skip testing, which saves memory.')
    p.add_argument('--log_run_metadata', '-lmeta', action='store_const', const=True)
    # TODO: rm
    # NOTE(review): summarize_gradients ends up in TrainFlags.summarize_grads, but get_train_op hard-codes
    # summarize_gradients=False, so this flag currently has no effect.
    p.add_argument('--summarize_gradients', '-lgrads', action='store_const', const=True)
    p.add_argument('--temporary', '-t', action='store_const', const=True,
                   help='Append _TMP to LOG_DIR_ROOT')
    p.add_argument('--from_identity', metavar='IDENTITY_CKPT_DIR',
                   help='Like --restore IDENTITY_CKPT_DIR, but global_step and any variables matching *Adam* are not '
                        'restored and centers get sampled from the bottleneck with KMeans.')
    p.add_argument('--restore', '-r', metavar='RESTORE_DIR',
                   help='Path to ckpt dir to restore from.')
    p.add_argument('--restore_itr', '-i', type=int, default=-1,
                   help='Iteration to restore from. Use -1 for latest. Otherwise, restores the latest checkpoint '
                        'with iteration <= restore_itr')
    p.add_argument('--restore_continue', action='store_const', const=True,
                   help='If given, the log dir corresponding to the path given by RESTORE_DIR will be used to save '
                        'future logs and checkpoints.')
    p.add_argument('--restore_skip_vars', type=str,
                   help='Var names to skip, use comma to separate, e.g. "Adam, global_var".')
    p.add_argument('--ckpt_interval', type=float, default=1,
                   help='How often to keep checkpoints, in hours.')
    p.add_argument('--description', '-d', type=str,
                   help='Description, if given, is appended to Google Sheets')
    flags = p.parse_args()

    if flags.temporary:
        print('*** WARN: --temporary')
        time.sleep(1.5)
        flags.log_dir_root = flags.log_dir_root.rstrip(os.path.sep) + '_TMP'

    train_flags = TrainFlags(
        log_run_metadata=flags.log_run_metadata,
        log_interval_train=flags.log_interval_train,
        log_interval_test=flags.log_interval_test,
        log_interval_save=flags.log_interval_save,
        summarize_grads=flags.summarize_gradients)

    tf.set_random_seed(1234)
    train(autoencoder_config_path=flags.autoencoder_config_path,
          probclass_config_path=flags.probclass_config_path,
          restore_manager=RestoreManager.from_flags(flags),
          datasets=Datasets(flags.dataset_train, flags.dataset_test, flags.dataset_codec_distance),
          log_dir_root=flags.log_dir_root,
          train_flags=train_flags,
          ckpt_interval_hours=flags.ckpt_interval,
          # temporary runs are not recorded in Google Sheets
          description=flags.description if not flags.temporary else None)


if __name__ == '__main__':
    main()