Python tensorflow.python.training.training_util._get_or_create_global_step_read() Examples

The following are 21 code examples of tensorflow.python.training.training_util._get_or_create_global_step_read(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.training.training_util, or try the search function.
Example #1
Source File: hooks.py    From ctc-asr with MIT License 6 votes vote down vote up
def begin(self):
        """Prepare the hook before the default graph is finalized.

        Runs once, before the graph of the active TensorFlow session is finalized
        and training starts. The hook may add new operations to the graph here.
        After begin() returns, the graph is finalized and the other callbacks can
        no longer modify it. Calling begin() a second time on the same graph
        should not change the graph.

        Raises:
            RuntimeError: If no global step tensor can be obtained.
        """
        # Only fetch a cached summary writer when none was supplied and an
        # output directory is configured.
        need_writer = self._summary_writer is None and self._output_dir
        if need_writer:
            self._summary_writer = summary_io.SummaryWriterCache.get(self._output_dir)

        # Obtain read access to the global step tensor.
        # pylint: disable=protected-access
        step_read = training_util._get_or_create_global_step_read()
        # pylint: enable=protected-access
        self._global_step_tensor = step_read
        if step_read is None:
            raise RuntimeError("Global step should be created to use StepCounterHook.")
Example #2
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory evaluation graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph, it then builds the evaluation
    graph via the estimator. It maps every global variable name both to the
    variable and to a placeholder of matching dtype.

    Raises:
      ValueError: If the evaluation scaffold defines a custom saver or init_fn.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      built = self._estimator._evaluate_build_graph(
          self._input_fn, self._hooks, checkpoint_path=None)
      self._scaffold, self._update_op, self._eval_dict, self._all_hooks = built

      # Tell TPU infeed/outfeed hooks not to initialize the TPU themselves.
      for hook in self._all_hooks:
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
      # pylint: enable=protected-access

      scaffold = self._scaffold
      if scaffold.saver is not None:
        raise ValueError('InMemoryEval does not support custom saver')
      if scaffold.init_fn is not None:
        raise ValueError('InMemoryEval does not support custom init_fn')

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
Example #3
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def begin(self):
    """Reset the worker-started flag and fetch the global step tensor.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._worker_is_started = False
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use _GlobalStepWaiterHook.")
Example #4
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def begin(self):
    """Prepare summary writing and global-step access for this hook.

    Fetches a cached summary writer for ``self._output_dir`` when no writer
    was set and an output directory is configured. Resets ``self._next_step``
    to None and stores the global step tensor.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    writer_missing = self._summary_writer is None
    if writer_missing and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")
Example #5
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def begin(self):
    """Prepare the summary writer, global step tensor and summary tag.

    The summary tag is the global step op name suffixed with ``"/sec"``.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    if self._summary_writer is None:
      if self._output_dir:
        self._summary_writer = SummaryWriterCache.get(self._output_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    step_op_name = training_util.get_global_step().op.name
    self._summary_tag = step_op_name + "/sec"
Example #6
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #7
Source File: basic_session_run_hooks.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 5 votes vote down vote up
def begin(self):
    """Store the global step tensor for later step checks.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
Example #8
Source File: hooks.py    From estimator with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Store the global step tensor for later checkpoint-step checks.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          'Global step should be created to use StopAtCheckpointStepHook.')
Example #9
Source File: estimator.py    From estimator with Apache License 2.0 5 votes vote down vote up
def _train_model_default(self, input_fn, hooks, saving_listeners):
    """Initiate training with `input_fn`, without `DistributionStrategies`.

    Builds a fresh graph placed with `self._device_fn` and seeds it with
    `self._config.tf_random_seed`. Ensures a global step exists, calls the
    model function in TRAIN mode, and hands the resulting spec to
    `_train_with_estimator_spec`.

    Args:
      input_fn: A function that provides input data for training as minibatches.
      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
        callbacks inside the training loop.
      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      Loss from training, i.e. whatever `_train_with_estimator_spec` returns.
    """
    with tf.Graph().as_default() as graph, graph.device(self._device_fn):
      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)

      # _create_and_assert_global_step may return None (e.g.
      # tf.contrib.estimator.SavedModelEstimator); then no read variable is made.
      if self._create_and_assert_global_step(graph) is not None:
        training_util._get_or_create_global_step_read(graph)  # pylint: disable=protected-access

      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(input_fn, ModeKeys.TRAIN))
      worker_hooks = list(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, ModeKeys.TRAIN, self.config)
      # The global step is looked up again after the model fn has run.
      return self._train_with_estimator_spec(
          estimator_spec, worker_hooks, hooks,
          tf.compat.v1.train.get_global_step(graph), saving_listeners)
Example #10
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #11
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory PREDICT graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph it then:
      - gets or creates the global step under a ``use_resource=True`` scope;
      - builds the model in PREDICT mode;
      - records its hooks and extracted predictions;
      - maps every global variable name to the variable and to a placeholder
        of matching dtype;
      - tells TPU infeed/outfeed hooks not to initialize the TPU.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      with variable_scope.variable_scope('', use_resource=True):
        training_util.get_or_create_global_step()
      predict_mode = model_fn_lib.ModeKeys.PREDICT
      features, input_hooks = self._estimator._get_features_from_input_fn(
          self._input_fn, predict_mode)
      estimator_spec = self._estimator._call_model_fn(
          features, None, predict_mode, self._estimator.config)

      self._all_hooks = list(input_hooks) + list(estimator_spec.prediction_hooks)
      self._predictions = self._estimator._extract_keys(
          estimator_spec.predictions, predict_keys=None)

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
      logging.info('Placeholders: %s', self._var_name_to_placeholder)

      for hook in self._all_hooks:
        logging.info('Hook: %s', hook)
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
    # pylint: enable=protected-access
Example #12
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #13
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #14
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #15
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory PREDICT graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph it then:
      - gets or creates the global step under a ``use_resource=True`` scope;
      - builds the model in PREDICT mode;
      - records its hooks and extracted predictions;
      - maps every global variable name to the variable and to a placeholder
        of matching dtype;
      - tells TPU infeed/outfeed hooks not to initialize the TPU.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      with variable_scope.variable_scope('', use_resource=True):
        training_util.get_or_create_global_step()
      predict_mode = model_fn_lib.ModeKeys.PREDICT
      features, input_hooks = self._estimator._get_features_from_input_fn(
          self._input_fn, predict_mode)
      estimator_spec = self._estimator._call_model_fn(
          features, None, predict_mode, self._estimator.config)

      self._all_hooks = list(input_hooks) + list(estimator_spec.prediction_hooks)
      self._predictions = self._estimator._extract_keys(
          estimator_spec.predictions, predict_keys=None)

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
      logging.info('Placeholders: %s', self._var_name_to_placeholder)

      for hook in self._all_hooks:
        logging.info('Hook: %s', hook)
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
    # pylint: enable=protected-access
Example #16
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory evaluation graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph, it then builds the evaluation
    graph via the estimator. It maps every global variable name both to the
    variable and to a placeholder of matching dtype.

    Raises:
      ValueError: If the evaluation scaffold defines a custom saver or init_fn.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      built = self._estimator._evaluate_build_graph(
          self._input_fn, self._hooks, checkpoint_path=None)
      self._scaffold, self._update_op, self._eval_dict, self._all_hooks = built

      # Tell TPU infeed/outfeed hooks not to initialize the TPU themselves.
      for hook in self._all_hooks:
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
      # pylint: enable=protected-access

      scaffold = self._scaffold
      if scaffold.saver is not None:
        raise ValueError('InMemoryEval does not support custom saver')
      if scaffold.init_fn is not None:
        raise ValueError('InMemoryEval does not support custom init_fn')

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
Example #17
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #18
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()
Example #19
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory PREDICT graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph it then:
      - gets or creates the global step under a ``use_resource=True`` scope;
      - builds the model in PREDICT mode;
      - records its hooks and extracted predictions;
      - maps every global variable name to the variable and to a placeholder
        of matching dtype;
      - tells TPU infeed/outfeed hooks not to initialize the TPU.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      with variable_scope.variable_scope('', use_resource=True):
        training_util.get_or_create_global_step()
      predict_mode = model_fn_lib.ModeKeys.PREDICT
      features, input_hooks = self._estimator._get_features_from_input_fn(
          self._input_fn, predict_mode)
      estimator_spec = self._estimator._call_model_fn(
          features, None, predict_mode, self._estimator.config)

      self._all_hooks = list(input_hooks) + list(estimator_spec.prediction_hooks)
      self._predictions = self._estimator._extract_keys(
          estimator_spec.predictions, predict_keys=None)

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
      logging.info('Placeholders: %s', self._var_name_to_placeholder)

      for hook in self._all_hooks:
        logging.info('Hook: %s', hook)
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
    # pylint: enable=protected-access
Example #20
Source File: in_memory_eval.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Build the in-memory evaluation graph and its restore placeholders.

    Resets the timer and fetches the global step read tensor. This happens
    before switching graphs, so presumably the read tensor comes from the
    caller's default graph. Inside a new graph, it then builds the evaluation
    graph via the estimator. It maps every global variable name both to the
    variable and to a placeholder of matching dtype.

    Raises:
      ValueError: If the evaluation scaffold defines a custom saver or init_fn.
    """
    self._timer.reset()
    self._graph = ops.Graph()
    # pylint: disable=protected-access
    self._global_step_tensor = training_util._get_or_create_global_step_read()
    with self._graph.as_default():
      built = self._estimator._evaluate_build_graph(
          self._input_fn, self._hooks, checkpoint_path=None)
      self._scaffold, self._update_op, self._eval_dict, self._all_hooks = built

      # Tell TPU infeed/outfeed hooks not to initialize the TPU themselves.
      for hook in self._all_hooks:
        if isinstance(hook, tpu_estimator.TPUInfeedOutfeedSessionHook):
          hook._should_initialize_tpu = False
      # pylint: enable=protected-access

      scaffold = self._scaffold
      if scaffold.saver is not None:
        raise ValueError('InMemoryEval does not support custom saver')
      if scaffold.init_fn is not None:
        raise ValueError('InMemoryEval does not support custom init_fn')

      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self._var_name_to_eval_var = {var.name: var for var in global_vars}
      self._var_name_to_placeholder = {
          var.name: array_ops.placeholder(var.dtype) for var in global_vars
      }
Example #21
Source File: async_checkpoint.py    From training_results_v0.5 with Apache License 2.0 5 votes vote down vote up
def begin(self):
    """Set up the summary writer and global step, then notify saving listeners.

    Fetches the cached summary writer for ``self._checkpoint_dir`` and stores
    the global step tensor. Then calls ``begin()`` on every listener in
    ``self._listeners``, in order.

    Raises:
      RuntimeError: If no global step tensor is available.
    """
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    # pylint: disable=protected-access
    step_read = training_util._get_or_create_global_step_read()
    # pylint: enable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
      listener.begin()