# Copyright 2018 The Defense-GAN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Contains the abstract class for models."""

import os

import tensorflow as tf
import yaml

from utils.misc import ensure_dir
from tensorflow.contrib import slim
from utils.dummy import DummySummaryWriter


class AbstractModel(object):
    def __init__(self, default_properties, test_mode=False, verbose=True,
                 cfg=None, **args):
        """The abstract model that the other models extend.

        Args:
            default_properties: The attributes of an experiment, read from a
                config file. The list is copied, never mutated.
            test_mode: If in the test mode, computation graph for loss will
                not be constructed, config will be saved in the output
                directory
            verbose: If true, prints debug information
            cfg: Config dictionary (must contain a 'cfg_path' entry, which
                `_set_checkpoint_dir` reads).
            args: The rest of the arguments which can become object
                attributes
        """
        # Set attributes either from FLAGS or **args.
        self.cfg = cfg

        # Active session parameter.
        self.active_sess = None

        # Object attributes. Build a new list instead of extending the
        # caller's one in place: a shared (e.g. module-level) list would
        # otherwise grow duplicate entries on every instantiation, which in
        # turn duplicates the checkpoint-directory postfix.
        default_properties = list(default_properties) + [
            'tensorboard_log', 'output_dir', 'num_gpus']
        self.default_properties = default_properties
        self.initialized = False
        self.verbose = verbose
        self.output_dir = 'output'

        # Constructor locals (test_mode, verbose, cfg, ...) are candidate
        # attribute values too; they override same-named keys in **args.
        local_vals = locals()
        args.update(local_vals)
        for attr in default_properties:
            # A None value makes _set_attr fall back to FLAGS / cfg.
            self._set_attr(attr, args.get(attr))

        # _set_attr stores None when output_dir is found nowhere, which would
        # make os.path.join fail in _set_checkpoint_dir; keep the default.
        if self.output_dir is None:
            self.output_dir = 'output'

        # Runtime attributes.
        self.saver = None
        self.global_step = tf.train.get_or_create_global_step()
        self.global_step_inc = \
            tf.assign(self.global_step, tf.add(self.global_step, 1))

        # Phase: 1 train 0 test.
        self.is_training = tf.placeholder(dtype=tf.bool)
        self.save_vars = {}
        self.save_var_prefixes = []
        self.dataset = None
        self.test_mode = test_mode

        self._set_checkpoint_dir()
        self._build()
        if not test_mode:
            self._save_cfg_in_ckpt()
            self._loss()

        self._initialize_summary_writer()

    def _load_dataset(self):
        """Hook for subclasses; no-op in the base class."""
        pass

    def _build(self):
        """Graph-construction hook called from __init__; no-op here."""
        pass

    def _loss(self):
        """Loss-construction hook called from __init__ outside test mode;
        no-op here."""
        pass

    def test(self, input):
        """Hook for subclasses; no-op in the base class."""
        pass

    def train(self):
        """Hook for subclasses; no-op in the base class."""
        pass

    def _verbose_print(self, message):
        """Handy verbose print function"""
        if self.verbose:
            print(message)

    def _save_cfg_in_ckpt(self):
        """Saves the configuration in the experiment's output directory.

        Every cfg key whose lower-cased name is a non-None attribute of this
        object is written, with the attribute's (possibly overridden) value,
        to `<checkpoint_dir>/cfg.yml`.
        """
        final_cfg = {}
        # `self.cfg` always exists (set in __init__) but may be None.
        if getattr(self, 'cfg', None) is not None:
            for k in self.cfg.keys():
                value = getattr(self, k.lower(), None)
                if value is not None:
                    final_cfg[k] = value

        if not self.test_mode:
            with open(os.path.join(self.checkpoint_dir, 'cfg.yml'), 'w') as f:
                yaml.dump(final_cfg, f)

    def _set_attr(self, attr_name, val):
        """Sets an object attribute from FLAGS if it exists, if not it
        prints out an error. Note that FLAGS is set from config and command
        line inputs.

        Lookup order when `val` is None: tf.app.flags.FLAGS, then the
        upper-case key of `self.cfg`, then the lower-case key.

        Args:
            attr_name: The name of the field.
            val: The value, if None it will set it from tf.app.flags.FLAGS
        """
        FLAGS = tf.app.flags.FLAGS

        if val is None:
            if hasattr(FLAGS, attr_name):
                val = getattr(FLAGS, attr_name)
            elif getattr(self, 'cfg', None) is not None:
                if attr_name.upper() in self.cfg.keys():
                    val = self.cfg[attr_name.upper()]
                elif attr_name.lower() in self.cfg.keys():
                    val = self.cfg[attr_name.lower()]
        if val is None and self.verbose:
            print(
                '[-] {}.{} is not set.'.format(type(self).__name__, attr_name))

        setattr(self, attr_name, val)
        if self.verbose:
            print('[#] {}.{} is set to {}.'.format(type(self).__name__,
                                                   attr_name, val))

    def imsave_transform(self, imgs):
        """Returns `imgs` unchanged; subclasses may override."""
        return imgs

    def get_learning_rate(self, init_lr=None, decay_epoch=None,
                          decay_mult=None, iters_per_epoch=None,
                          decay_iter=None,
                          global_step=None, decay_lr=True):
        """Prepares the learning rate.

        Args:
            init_lr: The initial learning rate (defaults to
                `self.learning_rate`).
            decay_epoch: The epoch of decay
            decay_mult: The decay factor
            iters_per_epoch: Number of iterations per epoch (defaults to
                `self.iters_per_epoch` when `decay_epoch` is used).
            decay_iter: The iteration of decay [either this or decay_epoch
                should be set]
            global_step: Step tensor driving the decay (defaults to
                `self.global_step`).
            decay_lr: If False, a constant `init_lr` tensor is returned and
                the decay arguments are ignored.

        Returns:
            `tf.Tensor` of the learning rate.

        Raises:
            ValueError: If decay is requested but neither a usable
                `decay_epoch`/`iters_per_epoch` pair nor `decay_iter` is set.
        """
        if init_lr is None:
            init_lr = self.learning_rate

        if not decay_lr:
            return tf.constant(init_lr)

        if global_step is None:
            global_step = self.global_step

        if decay_epoch:
            # Fall back to the attribute *before* validating, otherwise the
            # fallback could never be reached.
            if iters_per_epoch is None:
                iters_per_epoch = getattr(self, 'iters_per_epoch', None)
            if not iters_per_epoch:
                raise ValueError(
                    'iters_per_epoch must be set when decay_epoch is used.')
            decay_iter = decay_epoch * iters_per_epoch
        elif not decay_iter:
            raise ValueError('Either decay_epoch or decay_iter must be set.')

        return tf.train.exponential_decay(init_lr, global_step,
                                          decay_iter, decay_mult,
                                          staircase=True)

    def _set_checkpoint_dir(self):
        """Sets the directory containing snapshots of the model.

        The directory is derived from `self.cfg['cfg_path']`. Outside test
        mode, every default property that differs from its cfg value is
        appended as a `-name=value` postfix so that runs with overridden
        settings get their own directory. Also sets `self.debug_dir`.
        """
        self.cfg_file = self.cfg['cfg_path']
        if 'cfg.yml' in self.cfg_file:
            ckpt_dir = os.path.dirname(self.cfg_file)
        else:
            ckpt_dir = os.path.join(
                self.output_dir,
                self.cfg_file.replace('experiments/cfgs/', '').replace(
                    'cfg.yml', '').replace('.yml', ''))

        if not self.test_mode:
            postfix = ''
            ignore_list = ['dataset', 'cfg_file', 'batch_size']
            if self.cfg is not None:
                for prop in self.default_properties:
                    if prop in ignore_list or prop.upper() not in self.cfg:
                        continue
                    self_val = getattr(self, prop)
                    if self_val is not None and \
                            self_val != self.cfg[prop.upper()]:
                        postfix += '-{}={}'.format(
                            prop, self_val).replace('.', '_')

            ckpt_dir += postfix
            ensure_dir(ckpt_dir)

        self.checkpoint_dir = ckpt_dir
        # NOTE(review): replaces *every* 'output' substring in the path, not
        # only the output_dir prefix.
        self.debug_dir = self.checkpoint_dir.replace('output', 'debug')
        ensure_dir(self.debug_dir)

    def _initialize_summary_writer(self):
        """Creates a TensorBoard writer under `<checkpoint_dir>/tb_logs`, or
        a DummySummaryWriter when `tensorboard_log` is falsy."""
        # Setup the summary writer.
        if not self.tensorboard_log:
            self.summary_writer = DummySummaryWriter()
        else:
            sum_dir = os.path.join(self.checkpoint_dir, 'tb_logs')
            if not os.path.exists(sum_dir):
                os.makedirs(sum_dir)

            self.summary_writer = tf.summary.FileWriter(sum_dir)

    def _initialize_saver(self, prefixes=None, force=False, max_to_keep=5):
        """Initializes the saver object.

        Args:
            prefixes: The prefixes that the saver should take care of. A
                single string is treated as one prefix.
            force (optional): Even if saver is set, reconstruct the saver
                object.
            max_to_keep (optional): Maximum number of recent checkpoints the
                saver keeps.

        Raises:
            ValueError: If `prefixes` is None and a saver must be built.
        """
        if self.saver is not None and not force:
            return

        if prefixes is None:
            raise ValueError(
                'Prefix of variables that needs saving are not defined')
        if isinstance(prefixes, str):
            # Iterating a bare string would yield one prefix per character.
            prefixes = [prefixes]
        prefixes = list(prefixes)

        prefixes_str = ''.join(pref + ' ' for pref in prefixes)
        print('[#] Initializing it with variable prefixes: {}'.format(
            prefixes_str))

        saved_vars = []
        for pref in prefixes:
            saved_vars.extend(slim.get_variables(pref))

        self.saver = tf.train.Saver(saved_vars, max_to_keep=max_to_keep)

    def set_session(self, sess):
        """Uses `sess` as the active session.

        Raises:
            EnvironmentError: If a session is already active.
        """
        if self.active_sess is None:
            self.active_sess = sess
        else:
            raise EnvironmentError("Session is already set.")

    @property
    def sess(self):
        """The active session; lazily creates one with GPU memory growth."""
        if self.active_sess is None:
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            self.active_sess = tf.Session(config=config)

        return self.active_sess

    def close_session(self):
        """Closes the active session, if any, and forgets it so that `sess`
        creates a fresh one and `set_session` works again."""
        if self.active_sess:
            self.active_sess.close()
        self.active_sess = None

    def load(self, checkpoint_dir=None, prefixes=None, saver=None):
        """Loads the saved weights to the model from the checkpoint directory

        Args:
            checkpoint_dir: Either a directory holding a `checkpoint` state
                file or a checkpoint path prefix. Defaults to
                `self.checkpoint_dir`.
            prefixes: Variable prefixes used if `self.saver` must be built;
                defaults to `self.save_var_prefixes`.
            saver: Saver used for restoring; defaults to `self.saver`.

        Returns:
            True if a checkpoint was restored, False otherwise.
        """
        if prefixes is None:
            prefixes = self.save_var_prefixes
        if self.saver is None:
            print('[!] Saver is not initialized')
            self._initialize_saver(prefixes=prefixes)

        if saver is None:
            saver = self.saver

        if checkpoint_dir is None:
            checkpoint_dir = self.checkpoint_dir

        if not os.path.isdir(checkpoint_dir):
            # Treat the argument as a checkpoint path prefix.
            try:
                saver.restore(self.sess, checkpoint_dir)
            except Exception:
                # Best-effort restore: report failure instead of claiming
                # success as the previous bare `except:` fall-through did.
                print(" [!] Failed to find a checkpoint at {}".format(
                    checkpoint_dir))
                return False
        else:
            print(" [-] Reading checkpoints... {} ".format(checkpoint_dir))

            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
                saver.restore(self.sess,
                              os.path.join(checkpoint_dir, ckpt_name))
            else:
                print(
                    " [!] Failed to find a checkpoint "
                    "within directory {}".format(checkpoint_dir))
                return False

        print(" [*] Checkpoint is read successfully from {}".format(
            checkpoint_dir))

        return True

    def add_save_vars(self, prefixes):
        """Prepares the list of variables that should be saved based on their
        name prefix.

        Variables are stored in `self.save_vars` keyed by op name (the name
        tf.train.Saver uses by default), so repeated prefixes do not create
        duplicates.

        Args:
            prefixes: Variable name prefixes to find and save.
        """
        for pre in prefixes:
            for var in slim.get_variables(pre):
                self.save_vars[var.op.name] = var

        var_list = ''.join(var.name + ' ' for var in self.save_vars.values())
        print('Saving these variables: {}'.format(var_list))

    def input_transform(self, images):
        """Hook for subclasses; returns None in the base class."""
        pass

    def input_pl_transform(self):
        """Applies `input_transform` to the real-data placeholders.

        Assumes a subclass has defined `real_data_pl` and
        `real_data_test_pl`.
        """
        self.real_data = self.input_transform(self.real_data_pl)
        self.real_data_test = self.input_transform(self.real_data_test_pl)

    def initialize_uninitialized(self, ):
        """Only initializes the variables of a TensorFlow session that were
        not already initialized.
        """
        # List all global variables.
        sess = self.sess
        global_vars = tf.global_variables()

        # Find initialized status for all variables.
        is_var_init = [tf.is_variable_initialized(var) for var in global_vars]
        is_initialized = sess.run(is_var_init)

        # List all variables that were not previously initialized.
        not_initialized_vars = [var for (var, init) in
                                zip(global_vars, is_initialized) if not init]
        for v in not_initialized_vars:
            print('[!] not init: {}'.format(v.name))
        # Initialize all uninitialized variables found, if any.
        if len(not_initialized_vars):
            sess.run(tf.variables_initializer(not_initialized_vars))

    def save(self, prefixes=None, global_step=None, checkpoint_dir=None):
        """Saves a checkpoint of the variables under `prefixes`.

        Args:
            prefixes: Variable prefixes for the saver if it must be built;
                defaults to `self.save_var_prefixes` (as in `load`).
            global_step: Int or step tensor appended to the checkpoint name;
                defaults to `self.global_step`.
            checkpoint_dir: Target directory; defaults to
                `self.checkpoint_dir`.

        Assumes a subclass defines `self.model_save_name`.
        """
        if global_step is None:
            global_step = self.global_step
        if checkpoint_dir is None:
            # Previously the bound method `_set_checkpoint_dir` was used here
            # instead of the directory path.
            checkpoint_dir = self.checkpoint_dir
        if prefixes is None:
            prefixes = self.save_var_prefixes

        ensure_dir(checkpoint_dir)
        self._initialize_saver(prefixes)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, self.model_save_name),
                        global_step=global_step)

        # Only graph objects can be evaluated; a plain int step is printed
        # as-is.
        if isinstance(global_step, (tf.Tensor, tf.Variable)):
            step_value = self.sess.run(global_step)
        else:
            step_value = global_step
        print('Saved at iter {} to {}'.format(step_value, checkpoint_dir))

    def initialize(self, dir):
        """Loads weights from `dir` and marks the model as initialized."""
        self.load(dir)
        self.initialized = True