import tensorflow as tf
from tensorflow.python.client import timeline

from core.Log import log
from core.Util import average_gradients, clip_gradients
from core import Measures, Extractions
from core.Measures import accumulate_measures
from core.Extractions import accumulate_extractions


class Trainer:
  def __init__(self, config, train_network, test_network, global_step, session):
    self.opt_str = config.string("optimizer", "adam").lower()
    self.train_network = train_network
    self.test_network = test_network
    self.session = session
    self.global_step = global_step
    self.validation_step_number = 0
    self.gradient_clipping = config.float("gradient_clipping", -1.0)
    # piecewise-constant learning rate schedule; a rate for key 1 must be given
    self.learning_rates = config.int_key_dict("learning_rates")
    self.learning_rate_keys_are_steps = config.bool("learning_rate_keys_are_steps", False)
    self.curr_learning_rate = self.learning_rates[1]
    self.lr_var = tf.placeholder(tf.float32, shape=[], name="learning_rate")
    self.loss_scale_var = tf.placeholder_with_default(1.0, shape=[], name="loss_scale")
    self.use_gradient_checkpointing = config.bool("use_gradient_checkpointing", False)
    if self.use_gradient_checkpointing:
      import memory_saving_gradients
      from tensorflow.python.ops import gradients
      # monkey-patch tf.gradients so that all later gradient computations use the
      # memory-saving (gradient checkpointing) implementation
      # new_grad_fun = memory_saving_gradients.gradients_speed
      new_grad_fun = memory_saving_gradients.gradients_collection
      tf.__dict__["gradients"] = new_grad_fun
      gradients.__dict__["gradients"] = new_grad_fun
    self.opt, self.reset_opt_op = self.create_optimizer(config)
    self.collect_run_metadata = config.bool("collect_run_metadata", False)
    grad_norm = None
    if train_network is not None:
      self._step_op, grad_norm = self.create_step_op_and_grad_norm()
      self._update_ops = self.train_network.update_ops
    else:
      self._step_op = None
      self._update_ops = None
    # if step_num % summary_interval == 0, extract and write summaries;
    # with summary_interval == 1, summaries are written at every step
    self.summary_interval = config.int("summary_interval", 1)
    self.summary_writer, self.summary_op_train, self.summary_op_test = self.init_summaries(config, grad_norm)

  def create_optimizer(self, config):
    momentum = config.float("momentum", 0.9)
    if self.opt_str == "sgd_nesterov":
      return tf.train.MomentumOptimizer(self.lr_var, momentum, use_nesterov=True), None
    elif self.opt_str == "sgd_momentum":
      return tf.train.MomentumOptimizer(self.lr_var, momentum), None
    elif self.opt_str == "sgd":
      return tf.train.GradientDescentOptimizer(self.lr_var), None
    elif self.opt_str == "adam":
      opt = tf.train.AdamOptimizer(self.lr_var)
      all_vars = tf.global_variables()
      # the Adam slot variables (m and v) contain "Adam" in their names
      opt_vars = [v for v in all_vars if "Adam" in v.name]
      reset_opt_op = tf.variables_initializer(opt_vars, "reset_optimizer")
      return opt, reset_opt_op
    elif self.opt_str == "none":
      return None, None
    else:
      assert False, ("unknown optimizer", self.opt_str)

  def reset_optimizer(self):
    assert self.opt_str == "adam", "reset not implemented for other optimizers yet"
    assert self.reset_opt_op is not None
    self.session.run(self.reset_opt_op)
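
  # Usage sketch (hypothetical config values, assuming the Config accessors used
  # in __init__): an Adam optimizer whose learning rate drops by 10x at epoch 10
  # and whose gradients are clipped to a global norm of 5:
  #
  #   "optimizer": "adam"
  #   "learning_rates": {1: 0.0001, 10: 0.00001}   # 1-based epoch -> learning rate
  #   "gradient_clipping": 5.0
  #
  # With these values, create_optimizer returns an AdamOptimizer together with a
  # reset_opt_op that re-initializes the Adam slot variables (see reset_optimizer).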

  def init_summaries(self, config, grad_norm=None):
    summdir = config.dir("summary_dir", "summaries")
    model = config.string("model")
    summdir += model + "/"
    tf.gfile.MakeDirs(summdir)
    summary_writer = None
    summary_op = None
    summary_op_test = None
    if config.bool("write_summaries", True):
      summary_writer = tf.summary.FileWriter(summdir, self.session.graph)
      if self.train_network is not None:
        train_summs = self.train_network.summaries
        if grad_norm is not None:
          grad_norm_summary = tf.summary.scalar("grad_norm", grad_norm)
          train_summs.append(grad_norm_summary)
        # better not to merge ALL summaries, since otherwise we would get summaries
        # from different networks and might execute (parts of) the test network
        # while training
        # self.summary_op = tf.merge_all_summaries()
        if len(train_summs) > 0:
          summary_op = tf.summary.merge(train_summs)
      if self.test_network is not None and len(self.test_network.summaries) > 0:
        summary_op_test = tf.summary.merge(self.test_network.summaries)
    return summary_writer, summary_op, summary_op_test

  def adjust_learning_rate(self, epoch, n_examples_processed_total, learning_rate=None):
    if learning_rate is None:
      # select the largest schedule key that is <= the current 1-based epoch, or,
      # if learning_rate_keys_are_steps, <= the number of processed examples
      if self.learning_rate_keys_are_steps:
        key = max([k for k in self.learning_rates.keys() if k <= n_examples_processed_total + 1])
      else:
        key = max([k for k in self.learning_rates.keys() if k <= epoch + 1])
      new_lr = self.learning_rates[key]
    else:
      new_lr = learning_rate
    if self.curr_learning_rate != new_lr:
      print("changing learning rate to", new_lr, file=log.v1)
      self.curr_learning_rate = new_lr

  def create_step_op_and_grad_norm(self):
    if self.opt is None:
      return tf.no_op("dummy_step_op"), None
    losses_with_regularizers = self.train_network.tower_total_losses_with_regularizers
    setups = self.train_network.tower_setups
    # compute the gradients of each tower's loss on that tower's GPU
    tower_grads = []
    for l, s in zip(losses_with_regularizers, setups):
      gpu_str = "/gpu:" + str(s.gpu_idx)
      with tf.device(gpu_str), tf.name_scope("tower_gpu_" + str(s.gpu_idx) + "_opt"):
        var_list = (tf.trainable_variables() +
                    tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        # TODO: check if gate_gradients=False is safe
        # TODO: check if colocate_gradients_with_ops is useful
        grads_raw = self.opt.compute_gradients(l, var_list=var_list, gate_gradients=False,
                                               colocate_gradients_with_ops=True)
        # filter out gradients w.r.t. disconnected variables
        grads_filtered = [g for g in grads_raw if g[0] is not None]
        tower_grads.append(grads_filtered)
    with tf.device(setups[0].variable_device):
      if len(tower_grads) == 1:
        grads = tower_grads[0]
      else:
        # average the gradients over the towers
        grads = average_gradients(tower_grads)
      # gradient clipping (disabled when gradient_clipping is -1)
      if self.gradient_clipping != -1:
        grads, norm = clip_gradients(grads, self.gradient_clipping)
      else:
        norm = None
      if len(grads) == 0:
        return tf.no_op("dummy_step_op"), None
      step_op = self.opt.apply_gradients(grads, global_step=self.global_step)
    return step_op, norm

  def validation_step(self, epoch=None, n_examples_processed_total=None, feed_dict=None, extraction_keys=()):
    ops = {Measures.MEASURES: self.test_network.tower_measures}
    res = self._step(self.test_network, feed_dict, ops, self.summary_op_test, extraction_keys,
                     self.validation_step_number)
    self.validation_step_number += 1
    return res

  def train_step(self, epoch, n_examples_processed_total=None, feed_dict=None, loss_scale=1.0, learning_rate=None,
                 extraction_keys=()):
    self.adjust_learning_rate(epoch, n_examples_processed_total, learning_rate)
    if feed_dict is None:
      feed_dict = {}
    else:
      feed_dict = feed_dict.copy()
    feed_dict[self.lr_var] = self.curr_learning_rate
    feed_dict[self.loss_scale_var] = loss_scale
    ops = {"_update_ops": self._update_ops, "_step": self._step_op, "global_step": self.global_step,
           Measures.MEASURES: self.train_network.tower_measures}
    res = self._step(self.train_network, feed_dict, ops, self.summary_op_train, extraction_keys, step_number=None)
    return res
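
  # Usage sketch (hypothetical variable names, not code from this repository):
  # a typical outer loop driving train_step/validation_step.
  #
  #   trainer = Trainer(config, train_net, test_net, global_step, session)
  #   for epoch in range(n_epochs):
  #     for batch_feed_dict in train_batches:
  #       res = trainer.train_step(epoch, feed_dict=batch_feed_dict)
  #       measures = res[Measures.MEASURES]   # accumulated over all towers
  #     for batch_feed_dict in val_batches:
  #       trainer.validation_step(feed_dict=batch_feed_dict)
  #
  # The learning rate is fed through the lr_var placeholder at every step, so
  # schedule changes do not require rebuilding the graph.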

  def _step(self, network, feed_dict, ops, summary_op, extraction_keys, step_number):
    if feed_dict is None:
      feed_dict = {}
    if summary_op is not None:
      ops["summaries"] = summary_op
    if self.collect_run_metadata:
      run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
    else:
      run_options = None
      run_metadata = None
    if len(extraction_keys) > 0:
      ops[Extractions.EXTRACTIONS] = [{k: [v] for k, v in extractions.items() if k in extraction_keys}
                                      for extractions in network.tower_extractions]
    res = self.session.run(ops, feed_dict=feed_dict, options=run_options, run_metadata=run_metadata)
    if "summaries" in res:
      summary_str = res["summaries"]
      del res["summaries"]
    else:
      summary_str = None
    if step_number is None:
      step_number = res["global_step"]
    if step_number % self.summary_interval == 0 and self.summary_writer is not None:
      if self.collect_run_metadata and step_number > 50:
        # this is experimental, TODO: make this cleaner
        # the 50 is to allow for some warmup
        self.summary_writer.add_run_metadata(run_metadata, tag="timing", global_step=step_number)
        fetched_timeline = timeline.Timeline(run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        with open('timing.json', 'w') as f:
          f.write(chrome_trace)
      if summary_str is not None:
        self.summary_writer.add_summary(summary_str, global_step=step_number)
    res[Measures.MEASURES] = accumulate_measures({}, *res[Measures.MEASURES])
    if len(extraction_keys) > 0:
      res[Extractions.EXTRACTIONS] = accumulate_extractions({}, *res[Extractions.EXTRACTIONS])
    return res
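

# Minimal sketch (not used by the class above; the helper name and the __main__
# check are only for illustration): the piecewise-constant schedule lookup that
# adjust_learning_rate performs, isolated as a pure function.
def _select_learning_rate(learning_rates, epoch):
  # pick the largest schedule key that is <= epoch + 1 (keys are 1-based epochs)
  key = max(k for k in learning_rates if k <= epoch + 1)
  return learning_rates[key]


if __name__ == "__main__":
  # with the schedule {1: 1e-4, 10: 1e-5}, epochs 0..8 use 1e-4, epoch 9 onwards 1e-5
  schedule = {1: 1e-4, 10: 1e-5}
  assert _select_learning_rate(schedule, 0) == 1e-4
  assert _select_learning_rate(schedule, 8) == 1e-4
  assert _select_learning_rate(schedule, 9) == 1e-5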