import os
import logging

import tensorflow as tf
from tensorflow.python.client import timeline
from tensorflow.python.training import training_util
from tensorflow.python.training import session_run_hook


class ProfileAtStepHook(session_run_hook.SessionRunHook):
    """Hook that captures an execution trace of a single training step.

    After the run in which the observed global step equals ``at_step``, the
    *next* ``session.run`` call is executed with tracing enabled. Its
    ``RunMetadata`` is added to a TensorBoard event file (when
    ``checkpoint_dir`` is given) and exported as a Chrome trace JSON file
    named ``timeline_<step>.json``.
    """

    def __init__(self, at_step=None, checkpoint_dir=None,
                 trace_level=tf.RunOptions.FULL_TRACE, timeline_dir=None):
        """Create the hook.

        Args:
          at_step: Global step after which the next run is profiled. If None,
            no run is ever profiled.
          checkpoint_dir: Directory for the event file that receives the run
            metadata. If None, no event file is written (only the Chrome
            trace JSON is produced).
          trace_level: ``tf.RunOptions`` trace level used for the profiled run.
          timeline_dir: Directory for the Chrome trace JSON file. If None
            (default), the file is written to the current working directory.
        """
        self._at_step = at_step
        self._do_profile = False
        # FileWriter needs a real logdir; skip it when none was supplied
        # instead of handing None to TensorFlow.
        self._writer = (tf.summary.FileWriter(checkpoint_dir)
                        if checkpoint_dir is not None else None)
        self._trace_level = trace_level
        self._timeline_dir = timeline_dir

    def begin(self):
        """Look up the global step tensor and (re)open the event writer.

        Raises:
          RuntimeError: if the graph has no global step.
        """
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            raise RuntimeError(
                "Global step should be created to use ProfileAtStepHook.")
        if self._writer is not None:
            # end() closes the writer; reopen so a reused hook keeps working.
            # reopen() is a no-op when the writer is still open.
            self._writer.reopen()

    def before_run(self, run_context):  # pylint: disable=unused-argument
        """Fetch the global step, requesting a trace when profiling is armed."""
        options = None
        if self._do_profile:
            options = tf.RunOptions(trace_level=self._trace_level)
        return tf.train.SessionRunArgs(self._global_step_tensor,
                                       options=options)

    def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
        """Save the trace of a profiled run; arm profiling at ``at_step``."""
        # NOTE(review): assumes the fetched global step already includes this
        # run's increment, so subtracting 1 yields the step that just ran.
        global_step = run_values.results - 1
        if self._do_profile:
            self._do_profile = False
            self._save_trace(run_values.run_metadata, global_step)
        # Profiling is armed here and takes effect on the following run.
        if global_step == self._at_step:
            self._do_profile = True

    def end(self, session):  # pylint: disable=unused-argument
        """Close the event writer so any pending metadata reaches disk."""
        if self._writer is not None:
            self._writer.close()

    def _save_trace(self, run_metadata, global_step):
        """Write run metadata to the event file and a Chrome trace JSON file."""
        if self._writer is not None:
            self._writer.add_run_metadata(
                run_metadata, 'trace_{}'.format(global_step), global_step)
            # Flush now; otherwise the metadata waits for the periodic flush
            # and can be lost if training stops shortly afterwards.
            self._writer.flush()
        timeline_object = timeline.Timeline(run_metadata.step_stats)
        chrome_trace = timeline_object.generate_chrome_trace_format()
        chrome_trace_save_path = 'timeline_{}.json'.format(global_step)
        if self._timeline_dir is not None:
            chrome_trace_save_path = os.path.join(self._timeline_dir,
                                                  chrome_trace_save_path)
        with open(chrome_trace_save_path, 'w') as f:
            f.write(chrome_trace)
        logging.info('Profile trace saved to %s', chrome_trace_save_path)