import tensorflow as tf
import numpy as np
import argparse
import contextlib
import logdir_helpers
import val_images
import constants
import autoencoder
import probclass
import glob
import bits
import imageio
import ms_ssim_np
from codec_distance import CodecDistance, CodecDistanceReadException
from fjcommon import tf_helpers, config_parser
from images_iterator import ImagesIterator
from saver import Saver
import os
from os import path
import skimage.measure
from collections import defaultdict
from collections import namedtuple
from val_files import ValidationDirs, MeasuresWriter
import bpp_helpers


_VALIDATION_INFO_STR = """ - VALIDATION -------------------------------------------------------------------"""

_CKPT_ITR_INFO_STR = """- Validating ckpt {} ----------"""


# Options controlling what validate() computes and writes:
#   save_ours: if truthy, reconstructed images of the last validated checkpoint are saved.
#   ckpt_step: stride used to select which checkpoints get validated.
#   real_bpp:  if truthy, the bpp obtained via bpp_helpers.BppFetcher is computed and printed as well.
OutputFlags = namedtuple('OutputFlags', ['save_ours', 'ckpt_step', 'real_bpp'])


def validate(val_dirs: ValidationDirs, images_iterator: ImagesIterator, flags: OutputFlags):
    """
    Validate all selected, not-yet-validated checkpoints of val_dirs.ckpt_dir on the images of `images_iterator`.

    For every such checkpoint, bpp, MS-SSIM and PSNR are fetched per image, appended to a MeasuresWriter,
    averaged, and the averages are logged as summaries tagged val/<dataset_name>/<name>. If codec distance data
    is available for the dataset, the distance to BPG is logged too. Each finished checkpoint is registered via
    val_dirs.add_validated_checkpoint.

    Saves in val_dirs.log_dir/val/dataset_name/measures.csv:
        - `img_name,bpp,psnr,ms-ssim forall img_name`

    :param val_dirs: provides ckpt_dir, log_dir, out_dir and the bookkeeping of validated checkpoints.
    :param images_iterator: provides dataset_name and iter_imgs(pad=...), yielding (img_name, img_content).
    :param flags: OutputFlags instance.
    :return: None. Returns early if there are no checkpoints or if all selected checkpoints are validated.
    """
    print(_VALIDATION_INFO_STR)

    validated_checkpoints = val_dirs.get_validated_checkpoints()  # :: [10000, 18000, ..., 256000], ie, [int]
    all_ckpts = Saver.all_ckpts_with_iterations(val_dirs.ckpt_dir)
    if len(all_ckpts) == 0:
        print('No checkpoints found in {}'.format(val_dirs.ckpt_dir))
        return

    # if ckpt_step is -1, then all_ckpt[:-1:flags.ckpt_step] === [] because of how strides work
    ckpt_to_check = all_ckpts[:-1:flags.ckpt_step] + [all_ckpts[-1]]  # every ckpt_step-th checkpoint plus the last one
    if flags.ckpt_step == -1:
        assert len(ckpt_to_check) == 1
    print('Validating {}/{} checkpoints (--ckpt_step {})...'.format(
        len(ckpt_to_check), len(all_ckpts), flags.ckpt_step))

    missing_checkpoints = [(ckpt_itr, ckpt_path)
                           for ckpt_itr, ckpt_path in ckpt_to_check
                           if ckpt_itr not in validated_checkpoints]
    if len(missing_checkpoints) == 0:
        print('All checkpoints validated, stopping...')
        return

    # ---

    # create networks
    autoencoder_config_path, probclass_config_path = logdir_helpers.config_paths_from_log_dir(
        val_dirs.log_dir, base_dirs=[constants.CONFIG_BASE_AE, constants.CONFIG_BASE_PC])
    ae_config, ae_config_rel_path = config_parser.parse(autoencoder_config_path)
    pc_config, pc_config_rel_path = config_parser.parse(probclass_config_path)

    ae_cls = autoencoder.get_network_cls(ae_config)
    pc_cls = probclass.get_network_cls(pc_config)

    # Instantiate autoencoder and probability classifier
    ae = ae_cls(ae_config)
    pc = pc_cls(pc_config, num_centers=ae_config.num_centers)

    # A single image of shape (3, H, W) is fed; a batch dimension is added to get NCHW.
    x_val_ph = tf.placeholder(tf.uint8, (3, None, None), name='x_val_ph')
    x_val_uint8 = tf.expand_dims(x_val_ph, 0, name='batch')
    x_val = tf.to_float(x_val_uint8, name='x_val')

    enc_out_val = ae.encode(x_val, is_training=False)
    x_out_val = ae.decode(enc_out_val.qhard, is_training=False)

    bc_val = pc.bitcost(enc_out_val.qbar, enc_out_val.symbols, is_training=False, pad_value=pc.auto_pad_value(ae))
    bpp_val = bits.bitcost_to_bpp(bc_val, x_val)

    x_out_val_uint8 = tf.cast(x_out_val, tf.uint8, name='x_out_val_uint8')
    # Using numpy implementation due to dynamic shapes
    msssim_val = ms_ssim_np.tf_msssim_np(x_val_uint8, x_out_val_uint8, data_format='NCHW')
    psnr_val = psnr_np(x_val_uint8, x_out_val_uint8)

    restorer = Saver(val_dirs.ckpt_dir, var_list=Saver.get_var_list_of_ckpt_dir(val_dirs.ckpt_dir))

    # create fetch_dict
    fetch_dict = {
        'bpp': bpp_val,
        'ms-ssim': msssim_val,
        'psnr': psnr_val,
    }

    if flags.real_bpp:
        fetch_dict['sym'] = enc_out_val.symbols  # NCHW

    if flags.save_ours:
        fetch_dict['img_out'] = x_out_val_uint8

    # ---
    fw = tf.summary.FileWriter(val_dirs.out_dir, graph=tf.get_default_graph())

    def full_summary_tag(summary_name):
        return '/'.join(['val', images_iterator.dataset_name, summary_name])

    # Distance
    try:
        codec_distance_ms_ssim = CodecDistance(images_iterator.dataset_name, codec='bpg', metric='ms-ssim')
        codec_distance_psnr = CodecDistance(images_iterator.dataset_name, codec='bpg', metric='psnr')
    except CodecDistanceReadException as e:  # no codec distance values stored for the current setup
        print('*** Distance to BPG not available for {}:\n{}'.format(images_iterator.dataset_name, str(e)))
        codec_distance_ms_ssim = None
        codec_distance_psnr = None

    # Note that for each checkpoint, the structure of the network will be the same. Thus the pad depending image
    # loading can be cached.

    # create session
    # closing(fw) makes sure the summary writer is flushed and closed when validation finishes or fails.
    with contextlib.closing(fw), tf_helpers.create_session() as sess:
        if flags.real_bpp:
            pred = probclass.PredictionNetwork(pc, pc_config, ae.get_centers_variable(), sess)
            checker = probclass.ProbclassNetworkTesting(pc, ae, sess)
            bpp_fetcher = bpp_helpers.BppFetcher(pred, checker)

        fetcher = sess.make_callable(fetch_dict, feed_list=[x_val_ph])

        last_ckpt_itr = missing_checkpoints[-1][0]
        for ckpt_itr, ckpt_path in missing_checkpoints:
            if not ckpt_still_exists(ckpt_path):
                # May happen if job is still training
                print('Checkpoint disappeared: {}'.format(ckpt_path))
                continue

            print(_CKPT_ITR_INFO_STR.format(ckpt_itr))

            restorer.restore_ckpt(sess, ckpt_path)

            values_aggregator = ValuesAggregator('bpp', 'ms-ssim', 'psnr')

            # truncates the previous measures.csv file! This way, only the last valid checkpoint is saved.
            # closing(...) closes the writer even if processing an image raises.
            with contextlib.closing(MeasuresWriter(val_dirs.out_dir)) as measures_writer:
                # ----------------------------------------
                # iterate over images
                # images are padded to work with current auto encoder
                for img_i, (img_name, img_content) in enumerate(
                        images_iterator.iter_imgs(pad=ae.get_subsampling_factor())):
                    otp = fetcher(img_content)
                    measures_writer.append(img_name, otp)

                    if flags.real_bpp:
                        # Calculate
                        bpp_real, bpp_theory = bpp_fetcher.get_bpp(
                            otp['sym'], bpp_helpers.num_pixels_in_image(img_content))

                        # Logging
                        bpp_loss = otp['bpp']
                        diff_percent_tr = (bpp_theory/bpp_real) * 100
                        # NOTE(review): this ratio is loss/theory, but the message below labels it
                        # "% of real" -- confirm which one is intended.
                        diff_percent_lt = (bpp_loss/bpp_theory) * 100
                        print('BPP: Real {:.5f}\n'
                              ' Theoretical: {:.5f} [{:5.1f}% of real]\n'
                              ' Loss: {:.5f} [{:5.1f}% of real]'.format(
                                  bpp_real, bpp_theory, diff_percent_tr, bpp_loss, diff_percent_lt))
                        assert abs(bpp_theory - bpp_loss) < 1e-3, \
                            'Expected bpp_theory to match loss! Got {} and {}'.format(bpp_theory, bpp_loss)

                    # Output images are only stored for the last checkpoint of this run.
                    if flags.save_ours and ckpt_itr == last_ckpt_itr:
                        save_img(img_name, otp['img_out'], val_dirs)

                    values_aggregator.update(otp)

                    # '\r' keeps the progress on a single line unless the real_bpp report is printed in between.
                    print('{: 10d} {img_name} | Mean: {avgs}'.format(
                        img_i, img_name=img_name, avgs=values_aggregator.averages_str()),
                        end=('\r' if not flags.real_bpp else '\n'), flush=True)

            print()  # add newline
            avgs = values_aggregator.averages()
            avg_bpp, avg_ms_ssim, avg_psnr = avgs['bpp'], avgs['ms-ssim'], avgs['psnr']

            tf_helpers.log_values(fw,
                                  [(full_summary_tag('avg_bpp'), avg_bpp),
                                   (full_summary_tag('avg_ms_ssim'), avg_ms_ssim),
                                   (full_summary_tag('avg_psnr'), avg_psnr)],
                                  iteration=ckpt_itr)

            if codec_distance_ms_ssim and codec_distance_psnr:
                try:
                    d_ms_ssim = codec_distance_ms_ssim.distance(avg_bpp, avg_ms_ssim)
                    d_psnr = codec_distance_psnr.distance(avg_bpp, avg_psnr)
                    print('Distance to BPG: {:.3f} ms-ssim // {:.3f} psnr'.format(d_ms_ssim, d_psnr))
                    tf_helpers.log_values(fw,
                                          [(full_summary_tag('distance_BPG_MS-SSIM'), d_ms_ssim),
                                           (full_summary_tag('distance_BPG_PSNR'), d_psnr)],
                                          iteration=ckpt_itr)
                except ValueError as e:  # out of range errors from distance calls
                    print(e)

            val_dirs.add_validated_checkpoint(ckpt_itr)

    print('Validation completed {}'.format(val_dirs))


def save_img(img_name, img_out, val_dirs):
    """
    Save the first image of the batch `img_out` to <val_dirs.out_dir>/imgs/<img_name>.

    :param img_name: file name of the output image, must end in '.png'.
    :param img_out: 4-dimensional array with img_out.shape[1] == 3, i.e., NCHW with 3 channels.
    :param val_dirs: object whose out_dir attribute is the directory to create the imgs directory in.
    """
    assert img_name.endswith('.png')
    # Report the shape only; formatting the whole array makes the message unreadable.
    assert img_out.ndim == 4 and img_out.shape[1] == 3, 'Expected NCHW, got {}'.format(img_out.shape)

    img_dir = path.join(val_dirs.out_dir, 'imgs')
    os.makedirs(img_dir, exist_ok=True)

    img_out = np.transpose(img_out[0, :, :, :], (1, 2, 0))  # Make HWC
    img_out_p = path.join(img_dir, img_name)
    print('Saving {}...'.format(img_out_p))
    imageio.imsave(img_out_p, img_out)


def psnr_np(img1, img2):
    """
    Create a scalar float32 tensor holding the PSNR between the uint8 tensors `img1` and `img2`.

    The value is calculated with skimage.measure.compare_psnr, wrapped in a tf.py_func.
    """
    assert tf.uint8.is_compatible_with(img1.dtype), 'Expected uint8 intput'
    assert tf.uint8.is_compatible_with(img2.dtype), 'Expected uint8 intput'

    def _psnr(_img1, _img2):
        return np.float32(skimage.measure.compare_psnr(_img1, _img2))

    with tf.name_scope('psnr_np'):
        v = tf.py_func(_psnr, [img1, img2], tf.float32, stateful=False, name='PSNR')
        v.set_shape(())  # the py_func output is declared to be a scalar
        return v


class ValuesAggregator(object):
    """ Collects the values of selected tags over multiple update() calls and calculates their means. """

    def __init__(self, *tags_to_agregate):
        """ :param tags_to_agregate: the tags to collect; all other tags passed to update() are ignored. """
        self._tags_to_values = defaultdict(list)  # log tag -> [log value]
        self.tags_to_agregate = tags_to_agregate

    def update(self, fetch_dict_out):
        """ Store the values of all aggregated tags found in the dict `fetch_dict_out`. Asserts they are not NaN. """
        for tag, value in fetch_dict_out.items():
            if tag in self.tags_to_agregate:
                assert not np.isnan(value), 'nan encountered in {}'.format(fetch_dict_out)
                self._tags_to_values[tag].append(value)

    def averages(self):
        """ :return: dict tag -> mean of the values seen so far. Only contains tags seen at least once. """
        return {tag: np.mean(values) for tag, values in self._tags_to_values.items()}

    def averages_str(self, joiner=', '):
        """
        :return: str of `tag: mean` entries, formatted with 3 decimals, joined by `joiner`, in the order the tags
        were given to the constructor. Raises KeyError if one of the tags was never seen by update().
        """
        mean_values = self.averages()
        averages_sorted = tuple((tag, mean_values[tag]) for tag in self.tags_to_agregate)  # sort by tags_to_agregate
        return joiner.join('{}: {:.3f}'.format(tag, value) for tag, value in averages_sorted)


def ckpt_still_exists(ckpt_path):
    """ :return: True iff at least one file matches the glob pattern `ckpt_path*`. """
    ckpt_files = glob.glob(ckpt_path + '*')
    return len(ckpt_files) > 0


def main():
    """ Parse the command line flags and run validate() for every checkpoint dir of the given job_ids. """
    p = argparse.ArgumentParser()
    p.add_argument('log_dir_root', help='Path to dir containing log_dirs.')
    p.add_argument('job_ids', help='Comma separated list of job_ids.')
    p.add_argument('images')
    p.add_argument('--save_ours', '-o', action='store_const', const=True,
                   help='If given, store output images in VAL_OUT/imgs.')
    p.add_argument('--how_many', type=int, help='Number of images to output')
    p.add_argument('--image_cache_max', '-cache', type=int, default=500, help='Cache max in [MB]. Set to 0 to disable.')
    p.add_argument('--restore_itr', '-i', type=int)
    p.add_argument('--ckpt_step', '-s', type=int, default=2,
                   help='Every CKPT_STEP-th checkpoint will be validated. Set to 1 to validate all of them. '
                        'Last checkpoint will always be validated. Set to -1 to only validate last.')
    p.add_argument('--reset', action='store_const', const=True, help='Remove previous output')
    p.add_argument('--real_bpp', action='store_const', const=True,
                   help='If given, calculate real bpp using arithmetic encoding. Note: in our experiments, '
                        'this matches the theoretical bpp up to 1% precision. Note: this is very slow.')
    flags, unknown_flags = p.parse_known_args()
    if unknown_flags:
        print('Unknown flags: {}'.format(unknown_flags))

    image_paths, dataset_name = val_images.get_image_paths(flags.images)
    # If --how_many is not given, it is None and the slice keeps all images.
    images_iterator = ImagesIterator(image_paths[:flags.how_many], dataset_name, flags.image_cache_max)
    val_flags = OutputFlags(flags.save_ours, flags.ckpt_step, flags.real_bpp)

    for ckpt_dir in logdir_helpers.iter_ckpt_dirs(flags.log_dir_root, flags.job_ids):
        try:
            validate(ValidationDirs(ckpt_dir, flags.log_dir_root, dataset_name, flags.reset),
                     images_iterator, val_flags)
        except tf.errors.NotFoundError as e:  # happens if ckpt was deleted while validation
            print('*** Caught {}'.format(e))
        finally:
            # Always start the next job with an empty default graph, also if this job failed. Otherwise, the next
            # validate() call would add its network to the graph left behind by the failed one.
            tf.reset_default_graph()

    print('*** All given job_ids validated.')


if __name__ == '__main__':
    main()