Python tensorflow.python.training.monitored_session._HookedSession() Examples

The following are 30 code examples of tensorflow.python.training.monitored_session._HookedSession(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.training.monitored_session, or try the search function.
Example #1
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_save_steps_saves_at_end(self):
    """CheckpointSaverHook.end() writes a checkpoint when steps_per_run is set.

    The hook saves every 2 * steps_per_run steps and is told, through the
    private _set_steps_per_run, how many steps each run covers. After two
    runs hook.end() is called, and the checkpoint in model_dir is expected to
    hold global step 10.
    NOTE(review): 10 presumably equals 2 runs * self.steps_per_run (i.e.
    steps_per_run == 5, with train_op advancing by that much) — confirm in
    the fixture's setUp.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=2 * self.steps_per_run,
          scaffold=self.scaffold)
      # Private hook API; presumably sets how many steps a single run() spans.
      hook._set_steps_per_run(self.steps_per_run)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        # end() must produce the checkpoint checked below.
        hook.end(sess)
        self.assertEqual(
            10, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #2
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testRunPassesAllArguments(self):
    """_HookedSession.run forwards feed_dict/options/run_metadata unchanged.

    With no hooks installed, the wrapped FakeSession should receive exactly
    the keyword arguments given to run(), and the fetched value should be
    returned to the caller.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      recording_sess = FakeSession(sess)
      hooked = monitored_session._HookedSession(sess=recording_sess, hooks=[])
      zero = tf.constant([0], name='a_tensor')
      sess.run(tf.global_variables_initializer())
      # The same mapping is used both to call run() and as the expectation.
      passthrough = {
          'feed_dict': 'a_feed',
          'options': 'an_option',
          'run_metadata': 'a_metadata',
      }
      result = hooked.run(fetches=zero, **passthrough)
      self.assertEqual(result, [0])
      self.assertEqual(recording_sess.args_called, passthrough)
Example #3
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testCallsHooksBeginEnd(self):
    """Each hook gets exactly one before_run and one after_run per run().

    Also verifies the run context each hook saw: the caller's original fetch
    args, the underlying session, and empty results (no hook requested
    anything extra).
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      fake_hooks = [FakeHook(), FakeHook()]
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=list(fake_hooks))
      zero = tf.constant([0], name='a_tensor')
      sess.run(tf.global_variables_initializer())
      hooked.run(zero)

      expected_values = tf.train.SessionRunValues(results=None)
      expected_args = tf.train.SessionRunArgs(zero)
      for fake in fake_hooks:
        self.assertEqual(fake.last_run_values, expected_values)
        self.assertEqual(fake.last_run_context.original_args, expected_args)
        self.assertEqual(fake.last_run_context.session, sess)
        for phase in ('before_run', 'after_run'):
          self.assertEqual(fake.call_counter[phase], 1)
Example #4
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_stop_based_with_multiple_steps(self):
    """StopAtStepHook(num_steps=10) stops after a multi-step jump.

    The global step is assigned directly (5, then 15) rather than advanced by
    the run; run() only executes a no-op so the hook can observe the value.
    No stop is expected at 5, and a stop is expected at 15.
    """
    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with tf.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      no_op = tf.no_op()
      h.begin()
      with tf.compat.v1.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(tf.compat.v1.assign(global_step, 5))
        # Called explicitly after the step is set to 5; presumably this is
        # where the hook records its starting step — confirm against
        # StopAtStepHook.
        h.after_create_session(sess, None)
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        # Jump to 15 (= the initial 5 plus num_steps=10) in one assignment.
        sess.run(tf.compat.v1.assign(global_step, 15))
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
Example #5
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_print_at_end_only(self):
    """LoggingTensorHook(at_end=True) logs only from end(), never per run.

    During three runs the tensor name must not appear in the captured log
    message; after hook.end() it must. self.logged_message is presumably
    filled by a logging patch installed in setUp — confirm.
    """
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      t = tf.constant(42.0, name='foo')
      train_op = tf.constant(3)
      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[t.name], at_end=True)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      self.evaluate(tf.compat.v1.initializers.global_variables())
      self.logged_message = ''
      for _ in range(3):
        mon_sess.run(train_op)
        # Plain substring check: nothing about the tensor is logged per run.
        self.assertNotIn(t.name, str(self.logged_message))

      hook.end(sess)
      # assertRegex replaces the assertRegexpMatches alias, which was
      # deprecated in Python 3.2 and removed in Python 3.12.
      self.assertRegex(str(self.logged_message), t.name)
Example #6
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_stop_based_on_last_step(self):
    """StopAtStepHook(last_step=10) stops once the global step reaches 10.

    The global step is assigned directly (5, 9, 10, 11), and each no-op run
    lets the hook observe the current value. No stop is expected at 5 or 9.
    A stop is expected at 10, and again at 11 after the stop flag is reset.
    """
    h = basic_session_run_hooks.StopAtStepHook(last_step=10)
    with tf.Graph().as_default():
      global_step = tf.compat.v1.train.get_or_create_global_step()
      no_op = tf.no_op()
      h.begin()
      with tf.compat.v1.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(tf.compat.v1.assign(global_step, 5))
        # Called explicitly because the raw Session is wrapped directly in
        # _HookedSession; presumably nothing else invokes this callback here.
        h.after_create_session(sess, None)
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.compat.v1.assign(global_step, 9))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.compat.v1.assign(global_step, 10))
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
        sess.run(tf.compat.v1.assign(global_step, 11))
        # Clear the private flag (presumably what should_stop() reports) so
        # the next run must request the stop again: a step past last_step
        # must also stop.
        mon_sess._should_stop = False
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
Example #7
Source File: metric_hook_test.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def test_print_at_end_only(self):
    """LoggingMetricHook(at_end=True) logs a metric only from end().

    No metric may be logged during the three runs. After hook.end(), exactly
    one metric for tensor 'foo' (value 42.0, no unit, global step 0) is
    expected. self._logger is presumably a recording fake metric logger
    created in setUp — confirm.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      # Create the global step variable. Nothing here increments it, which is
      # consistent with the expected global_step of 0 below.
      tf.train.get_or_create_global_step()
      t = tf.constant(42.0, name='foo')
      # A constant stands in for a training op; running it changes nothing.
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      sess.run(tf.global_variables_initializer())

      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      self.assertEqual(len(self._logger.logged_metric), 1)
      metric = self._logger.logged_metric[0]
      # NOTE(review): assertRegexpMatches is a deprecated alias removed in
      # Python 3.12; switch to assertRegex once Python 2 support is dropped.
      self.assertRegexpMatches(metric["name"], "foo")
      self.assertEqual(metric["value"], 42.0)
      self.assertEqual(metric["unit"], None)
      self.assertEqual(metric["global_step"], 0)
Example #8
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testFetchesHookRequests(self):
    """Tensors requested by hooks are fetched and routed back to each hook.

    The caller receives only its own fetch. Each hook receives the results of
    the extra tensor it asked for in its SessionRunArgs.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      first_hook, second_hook = FakeHook(), FakeHook()
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=[first_hook, second_hook])
      user_fetch = tf.constant([0], name='a_tensor')
      first_extra = tf.constant([5], name='another_tensor')
      second_extra = tf.constant([10], name='third_tensor')
      first_hook.request = tf.train.SessionRunArgs([first_extra])
      second_hook.request = tf.train.SessionRunArgs([second_extra])
      sess.run(tf.global_variables_initializer())

      self.assertEqual(hooked.run(fetches=user_fetch), [0])
      for fake, expected in ((first_hook, [5]), (second_hook, [10])):
        self.assertEqual(fake.last_run_values.results, expected)
Example #9
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testHooksAndUserFeedConflicts(self):
    """A hook feeding the same tensor as the caller makes run() raise."""
    with tf.Graph().as_default(), tf.Session() as sess:
      first_hook, second_hook = FakeHook(), FakeHook()
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=[first_hook, second_hook])
      lhs = tf.constant([0], name='a_tensor')
      rhs = tf.constant([0], name='b_tensor')
      total = lhs + rhs
      # The first hook feeds lhs; the second feeds rhs, which the caller also
      # feeds below, producing the conflict.
      first_hook.request = tf.train.SessionRunArgs(None, feed_dict={lhs: [5]})
      second_hook.request = tf.train.SessionRunArgs(None, feed_dict={rhs: [10]})
      sess.run(tf.global_variables_initializer())

      with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
        hooked.run(fetches=total, feed_dict={rhs: [10]})
Example #10
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testOnlyHooksHaveFeeds(self):
    """Feeds supplied only by hooks are merged and applied to the run."""
    with tf.Graph().as_default(), tf.Session() as sess:
      fake_hooks = [FakeHook(), FakeHook()]
      hooked = monitored_session._HookedSession(
          sess=sess, hooks=list(fake_hooks))
      lhs = tf.constant([0], name='a_tensor')
      rhs = tf.constant([0], name='b_tensor')
      total = lhs + rhs
      # Each hook overrides one operand, so the fetched sum should be 5 + 10.
      for fake, fed_tensor, fed_value in zip(fake_hooks, (lhs, rhs),
                                             ([5], [10])):
        fake.request = tf.train.SessionRunArgs(
            None, feed_dict={fed_tensor: fed_value})
      sess.run(tf.global_variables_initializer())

      self.assertEqual(hooked.run(fetches=total), [15])
Example #11
Source File: monitored_session_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testHooksFeedConflicts(self):
    """Two hooks feeding the same tensor make run() raise RuntimeError.

    Both FakeHooks request a feed for a_tensor (with [5] and [10]). The run
    is expected to raise an error matching 'Same tensor is fed'.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = tf.constant([0], name='a_tensor')
      b_tensor = tf.constant([0], name='b_tensor')
      add_tensor = a_tensor + b_tensor
      # Both hooks target a_tensor; b_tensor is not fed by anyone.
      mock_hook.request = tf.train.SessionRunArgs(
          None, feed_dict={
              a_tensor: [5]
          })
      mock_hook2.request = tf.train.SessionRunArgs(
          None, feed_dict={
              a_tensor: [10]
          })
      sess.run(tf.global_variables_initializer())

      # No user feed_dict here, so the conflict is purely between hooks.
      with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
        mon_sess.run(fetches=add_tensor)
Example #12
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_stop_based_on_last_step(self):
    """StopAtStepHook(last_step=10) stops once the global step reaches 10.

    Global step values are assigned directly (5, 9, 10, 11), and a no-op run
    lets the hook observe each one. No stop is expected at 5 or 9. A stop is
    expected at 10 and, after the stop flag is cleared, again at 11.
    """
    h = tf.train.StopAtStepHook(last_step=10)
    with tf.Graph().as_default():
      global_step = tf.contrib.framework.get_or_create_global_step()
      no_op = tf.no_op()
      h.begin()
      with tf.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(tf.assign(global_step, 5))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 9))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 10))
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 11))
        # Reset the private flag (presumably what should_stop() reports) so
        # the next run has to request the stop again.
        mon_sess._should_stop = False
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
Example #13
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_stop_based_on_num_step(self):
    """StopAtStepHook(num_steps=10) stops relative to the starting step.

    The first global step the hook observes is 5. The test expects no stop
    at 5 or 13, a stop at 14, and, after the stop flag is cleared, a stop at
    15.
    NOTE(review): stopping at 14 rather than 15 suggests the hook counts the
    first observed step as one of the ten — confirm against StopAtStepHook.
    """
    h = tf.train.StopAtStepHook(num_steps=10)

    with tf.Graph().as_default():
      global_step = tf.contrib.framework.get_or_create_global_step()
      no_op = tf.no_op()
      h.begin()
      with tf.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(tf.assign(global_step, 5))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 13))
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 14))
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
        sess.run(tf.assign(global_step, 15))
        # Reset the private stop flag so the next run must set it again.
        mon_sess._should_stop = False
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
Example #14
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_save_secs_calls_listeners_at_begin_and_end(self):
    """A CheckpointSaverListener is notified around every save.

    With save_secs=2, the two quick runs are expected to cause one save (on
    the first run) and hook.end() another. The listener should therefore
    record one begin, one end, and two before_save/after_save calls.
    """
    with self.graph.as_default():
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_secs=2,
          scaffold=self.scaffold,
          listeners=[listener])
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)  # hook runs here
        mon_sess.run(self.train_op)  # hook won't run here, so it does at end
        hook.end(sess)  # hook runs here
      # Counts are checked after the session has been closed.
      self.assertEqual({
          'begin': 1,
          'before_save': 2,
          'after_save': 2,
          'end': 1
      }, listener.get_counts())
Example #15
Source File: metric_hook_test.py    From training with Apache License 2.0 6 votes vote down vote up
def test_print_at_end_only(self):
    """LoggingMetricHook(at_end=True) logs a metric only from end().

    No metric may be logged during the three runs. After hook.end(), exactly
    one metric for tensor 'foo' (value 42.0, no unit, global step 0) is
    expected. self._logger is presumably a recording fake metric logger
    created in setUp — confirm.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      # Create the global step variable. Nothing here increments it, which is
      # consistent with the expected global_step of 0 below.
      tf.train.get_or_create_global_step()
      t = tf.constant(42.0, name='foo')
      # A constant stands in for a training op; running it changes nothing.
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      sess.run(tf.global_variables_initializer())

      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      self.assertEqual(len(self._logger.logged_metric), 1)
      metric = self._logger.logged_metric[0]
      # NOTE(review): assertRegexpMatches is a deprecated alias removed in
      # Python 3.12; switch to assertRegex once Python 2 support is dropped.
      self.assertRegexpMatches(metric["name"], "foo")
      self.assertEqual(metric["value"], 42.0)
      self.assertEqual(metric["unit"], None)
      self.assertEqual(metric["global_step"], 0)
Example #16
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_save_steps(self):
    """SummarySaverHook(save_steps=8) writes a summary every 8th step.

    Over 30 runs, summaries are expected at steps 1, 9, 17 and 25, with
    my_summary taking the values 1.0 through 4.0. self.summary_writer is
    presumably a recording fake writer set up in setUp — confirm.
    """
    hook = tf.train.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(tf.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(30):
        mon_sess.run(self.train_op)
      hook.end(sess)

    # Checked after the session closes; assert_summaries is a helper on the
    # writer object, which receives this TestCase to report failures.
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {'my_summary': 1.0},
            9: {'my_summary': 2.0},
            17: {'my_summary': 3.0},
            25: {'my_summary': 4.0},
        })
Example #17
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_save_secs_saving_once_every_step(self):
    """SummarySaverHook(save_secs=0.5) saves every step at 0.5 s per run.

    Each of the 4 runs is followed by a real 0.5 s sleep, so every run is
    expected to write a summary (steps 1-4). The test depends on wall-clock
    time and may be sensitive to a slow or heavily loaded machine.
    """
    hook = tf.train.SummarySaverHook(
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(tf.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(4):
        mon_sess.run(self.train_op)
        # Sleep exactly one save interval so the next run is due to save.
        time.sleep(0.5)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {'my_summary': 1.0},
            2: {'my_summary': 2.0},
            3: {'my_summary': 3.0},
            4: {'my_summary': 4.0},
        })
Example #18
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_save_secs_saving_once_every_three_steps(self):
    """SummarySaverHook(save_secs=0.9) saves every third step at 0.3 s/run.

    Eight runs separated by real 0.3 s sleeps are expected to produce
    summaries at steps 1, 4 and 7. The test depends on wall-clock sleeps and
    may be timing-sensitive.
    """
    hook = tf.train.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(tf.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(8):
        mon_sess.run(self.train_op)
        # Three sleeps of 0.3 s add up to one 0.9 s save interval.
        time.sleep(0.3)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {'my_summary': 1.0},
            4: {'my_summary': 2.0},
            7: {'my_summary': 3.0},
        })
Example #19
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_log_warning_if_global_step_not_increased(self):
    """StepCounterHook warns when the global step does not advance.

    train_op increments the global step by 0. After one run that lets the
    hook record the step, tf_logging.log_first_n is patched and 30 more runs
    are made. The last captured call is expected to mention that the global
    step has not been increased.
    """
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      tf.compat.v1.train.get_or_create_global_step()
      train_op = training_util._increment_global_step(0)  # keep same.
      self.evaluate(tf.compat.v1.initializers.global_variables())
      hook = basic_session_run_hooks.StepCounterHook(
          every_n_steps=1, every_n_secs=None)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)  # Run one step to record global step.
      with tf.compat.v1.test.mock.patch.object(tf_logging,
                                               'log_first_n') as mock_log:
        for _ in range(30):
          mon_sess.run(train_op)
        # call_args holds the most recent log_first_n call. assertRegex
        # replaces the assertRegexpMatches alias removed in Python 3.12.
        self.assertRegex(
            str(mock_log.call_args), 'global step.*has not been increased')
      hook.end(sess)
Example #20
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 6 votes vote down vote up
def test_summary_writer_defs(self):
    """CheckpointSaverHook adds the graph's MetaGraphDef to its summary writer.

    FakeSummaryWriter.install() presumably patches the summary FileWriter so
    that FileWriterCache.get returns a recording fake — confirm in
    fake_summary_writer. After one run, the writer is expected to have
    received the MetaGraphDef built from this graph and the scaffold's saver.
    The fake is always uninstalled, even if an assertion fails.
    """
    fake_summary_writer.FakeSummaryWriter.install()
    # Uninstall in `finally` so a failing assertion (or an exception from the
    # hook) cannot leave the fake writer installed for later tests.
    try:
      tf.compat.v1.summary.FileWriterCache.clear()
      summary_writer = tf.compat.v1.summary.FileWriterCache.get(self.model_dir)

      with self.graph.as_default():
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()
        with tf.compat.v1.Session() as sess:
          sess.run(self.scaffold.init_op)
          mon_sess = monitored_session._HookedSession(sess, [hook])
          hook.after_create_session(sess, None)
          mon_sess.run(self.train_op)
        summary_writer.assert_summaries(
            test_case=self,
            expected_logdir=self.model_dir,
            expected_added_meta_graphs=[
                meta_graph.create_meta_graph_def(
                    graph_def=self.graph.as_graph_def(add_shapes=True),
                    saver_def=self.scaffold.saver.saver_def)
            ])
    finally:
      fake_summary_writer.FakeSummaryWriter.uninstall()
Example #21
Source File: metric_hook_test.py    From models with Apache License 2.0 6 votes vote down vote up
def test_print_at_end_only(self):
    """LoggingMetricHook(at_end=True) logs a metric only from end().

    No metric may be logged during the three runs. After hook.end(), exactly
    one metric for tensor 'foo' (value 42.0, no unit, global step 0) is
    expected. self._logger is presumably a recording fake metric logger
    created in setUp — confirm.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      # Create the global step variable. Nothing here increments it, which is
      # consistent with the expected global_step of 0 below.
      tf.train.get_or_create_global_step()
      t = tf.constant(42.0, name='foo')
      # A constant stands in for a training op; running it changes nothing.
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t.name], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      sess.run(tf.global_variables_initializer())

      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      self.assertEqual(len(self._logger.logged_metric), 1)
      metric = self._logger.logged_metric[0]
      # NOTE(review): assertRegexpMatches is a deprecated alias removed in
      # Python 3.12; switch to assertRegex once Python 2 support is dropped.
      self.assertRegexpMatches(metric["name"], "foo")
      self.assertEqual(metric["value"], 42.0)
      self.assertEqual(metric["unit"], None)
      self.assertEqual(metric["global_step"], 0)
Example #22
Source File: basic_session_run_hooks_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def test_print_every_n_steps(self):
    """LoggingTensorHook(every_n_iter=10) logs on the first run and every 10th.

    After the first run (which must log), each of three rounds clears the
    captured message, does nine runs that must not log the tensor, and then
    one run that must log it.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      t = tf.constant(42.0, name='foo')
      train_op = tf.constant(3)
      hook = tf.train.LoggingTensorHook(tensors=[t.name], every_n_iter=10)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      sess.run(tf.global_variables_initializer())
      mon_sess.run(train_op)
      self.assertRegexpMatches(str(self.logged_message), t.name)
      for _ in range(3):
        self.logged_message = ''
        for _ in range(9):
          mon_sess.run(train_op)
          # Substring absence check (same condition as find(...) == -1);
          # assertNotIn is available on Python 2.7 and all Python 3 versions.
          self.assertNotIn(t.name, str(self.logged_message))
        mon_sess.run(train_op)
        self.assertRegexpMatches(str(self.logged_message), t.name)
Example #23
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_save_saves_at_end(self):
    """CheckpointSaverHook.end() writes a final checkpoint (time-based saving).

    Two runs happen well within the 2 s save_secs interval. After hook.end(),
    the checkpoint is expected to hold global step 2, the value after the
    last run.
    NOTE(review): assumes train_op increments the global step by 1 per run —
    confirm in setUp.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        mon_sess.run(self.train_op)
        hook.end(sess)
        self.assertEqual(
            2, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #24
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_save_steps_saves_periodically(self):
    """With save_steps=2, checkpoints are written on alternate runs.

    After an initial run, each further run is followed by a check of the
    global step stored in the latest checkpoint. A run that does not save
    leaves the previous value; a run that saves records the current step.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        hooked = monitored_session._HookedSession(sess, [hook])
        hooked.run(self.train_op)
        # Checkpointed global step expected after runs 2, 3, 4 and 5:
        # run 2 does not save, run 3 saves, run 4 does not, run 5 saves.
        for checkpointed_step in (1, 3, 3, 5):
          hooked.run(self.train_op)
          self.assertEqual(
              checkpointed_step,
              tf.train.load_variable(self.model_dir, self.global_step.name))
Example #25
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_saves_when_saver_and_scaffold_both_missing(self):
    """CheckpointSaverHook can save without an explicit saver or scaffold.

    The hook receives neither argument. self.scaffold is still finalized and
    its init_op run, apparently only to initialize the graph's variables.
    With save_steps=1, the checkpoint after one run is expected to hold step
    1.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=1)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(
            1, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #26
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_save_secs_saves_in_first_step(self):
    """With save_secs=2, a checkpoint is written on the very first run.

    The checkpoint is expected to hold global step 1 right after that run,
    with no call to hook.end().
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(
            1, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #27
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_print_formatter(self):
    """LoggingTensorHook applies a custom formatter to the logged values."""
    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
      t = tf.constant(42.0, name='foo')
      train_op = tf.constant(3)

      def qqq_formatter(items):
        # `items` is presumably a mapping from tensor name to fetched value.
        return 'qqq=%s' % items[t.name]

      hook = basic_session_run_hooks.LoggingTensorHook(
          tensors=[t.name], every_n_iter=10, formatter=qqq_formatter)
      hook.begin()
      hooked = monitored_session._HookedSession(sess, [hook])
      self.evaluate(tf.compat.v1.initializers.global_variables())
      hooked.run(train_op)
      self.assertEqual(self.logged_message[0], 'qqq=42.0')
Example #28
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_save_steps_saves_in_first_step(self):
    """With save_steps=2, a checkpoint is written on the very first run.

    The checkpoint is expected to hold global step 1 right after that run.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_steps=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(
            1, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #29
Source File: basic_session_run_hooks_test.py    From estimator with Apache License 2.0 5 votes vote down vote up
def test_save_steps_saves_in_first_step(self):
    """With steps_per_run set, the first run still produces a checkpoint.

    save_steps is 2 * steps_per_run, and the hook is told through the private
    _set_steps_per_run how many steps each run covers. The checkpoint after
    the first run is expected to hold global step 5.
    NOTE(review): 5 presumably equals self.steps_per_run — confirm in setUp.
    """
    with self.graph.as_default():
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=2 * self.steps_per_run,
          scaffold=self.scaffold)
      # Private hook API; presumably sets how many steps a single run() spans.
      hook._set_steps_per_run(self.steps_per_run)
      hook.begin()
      self.scaffold.finalize()
      with tf.compat.v1.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])
        mon_sess.run(self.train_op)
        self.assertEqual(
            5, tf.train.load_variable(self.model_dir, self.global_step.name))
Example #30
Source File: inference_runner.py    From tensorpack with Apache License 2.0 5 votes vote down vote up
def _before_train(self):
        """Run the base-class setup, then wrap the session for parallel hooks.

        Stores in self._parallel_hooked_sess a HookedSession over
        self.trainer.sess that drives the hooks in self._hooks_parallel.
        """
        super(DataParallelInferenceRunner, self)._before_train()
        trainer_sess = self.trainer.sess
        parallel_hooks = self._hooks_parallel
        self._parallel_hooked_sess = HookedSession(trainer_sess, parallel_hooks)