Python tensorflow.python.training.saver.latest_checkpoint() Examples

The following are 30 code examples of tensorflow.python.training.saver.latest_checkpoint(), collected from open-source projects. The source file and project are noted above each example. You may also want to check out all available functions and classes of the module tensorflow.python.training.saver.
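Before the examples, a minimal usage sketch (assuming TensorFlow 1.x; the model directory below is a placeholder). Given a directory, latest_checkpoint() consults the checkpoint state file maintained by tf.train.Saver and returns the path prefix of the most recently saved checkpoint, or None if no checkpoint exists there.

from tensorflow.python.training import saver as saver_lib

# Placeholder directory; point this at a real model directory.
ckpt_prefix = saver_lib.latest_checkpoint("/tmp/my_model_dir")
if ckpt_prefix is None:
    print("No checkpoint found yet.")
else:
    print("Most recent checkpoint prefix:", ckpt_prefix)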
Example #1
Source File: monitors.py    From keras-lambda with MIT License
def end(self, session=None):
    super(ExportMonitor, self).end(session=session)
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
      logging.info("Skipping export at the end since model has not been saved "
                   "yet.")
      return
    try:
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      logging.info("Skipping exporting for the same step.") 
Example #2
Source File: plugin.py    From keras-lambda with MIT License
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs."""
  for run_name, logdir in run_path_pairs:
    if run_name not in configs:
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)
    else:
      config = configs[run_name]

    # See if you can find a checkpoint file in the logdir.
    ckpt_path = latest_checkpoint(logdir)
    if not ckpt_path:
      # See if you can find a checkpoint in the parent of logdir.
      ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
      if not ckpt_path:
        continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False 
Example #3
Source File: monitors.py    From deep_image_model with Apache License 2.0
def end(self, session=None):
    super(ExportMonitor, self).end(session=session)
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
      logging.info("Skipping export at the end since model has not been saved "
                   "yet.")
      return
    try:
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      logging.info("Skipping exporting for the same step.") 
Example #4
Source File: plugin.py    From auto-alt-text-lambda-api with MIT License
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs."""
  for run_name, logdir in run_path_pairs:
    if run_name not in configs:
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)
    else:
      config = configs[run_name]

    # See if you can find a checkpoint file in the logdir.
    ckpt_path = latest_checkpoint(logdir)
    if not ckpt_path:
      # See if you can find a checkpoint in the parent of logdir.
      ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
      if not ckpt_path:
        continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False 
Example #5
Source File: monitors.py    From auto-alt-text-lambda-api with MIT License
def end(self, session=None):
    super(ExportMonitor, self).end(session=session)
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
      logging.info("Skipping export at the end since model has not been saved "
                   "yet.")
      return
    try:
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      logging.info("Skipping exporting for the same step.") 
Example #6
Source File: plugin.py    From deep_image_model with Apache License 2.0
def _latest_checkpoints_changed(configs, run_path_pairs):
  """Returns true if the latest checkpoint has changed in any of the runs."""
  for run_name, logdir in run_path_pairs:
    if run_name not in configs:
      continue
    config = configs[run_name]
    if not config.model_checkpoint_path:
      continue

    # See if you can find a checkpoint file in the logdir.
    ckpt_path = latest_checkpoint(logdir)
    if not ckpt_path:
      # See if you can find a checkpoint in the parent of logdir.
      ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
      if not ckpt_path:
        continue
    if config.model_checkpoint_path != ckpt_path:
      return True
  return False 
Example #7
Source File: im_text_rnn_model.py    From tumblr-emotions with Apache License 2.0
def correlation_matrix(nb_batches, checkpoint_dir):
    """Computes logits and labels of the input posts and save them as numpy files.
    
    Parameters:
        checkpoint_dir: Checkpoint of the saved model during training.
    """
    with tf.Graph().as_default():
        config = _CONFIG.copy()
        config['mode'] = 'validation'
        model = DeepSentiment(config)

        # Load model
        checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
        scaffold = monitored_session.Scaffold(
            init_op=None, init_feed_dict=None,
            init_fn=None, saver=None)
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_filename_with_path=checkpoint_path,
            master='',
            config=None)

        posts_logits = []
        posts_labels = []
        # Generate queue
        with monitored_session.MonitoredSession(
            session_creator=session_creator, hooks=None) as session:
            for i in range(nb_batches):
                np_logits, np_labels = session.run([model.logits, model.labels])
                posts_logits.append(np_logits)
                posts_labels.append(np_labels)

    posts_logits, posts_labels = np.vstack(posts_logits), np.hstack(posts_labels)
    np.save('data/posts_logits.npy', posts_logits)
    np.save('data/posts_labels.npy', posts_labels)
    return posts_logits, posts_labels 
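A hedged call sketch for the function above (the batch count and checkpoint directory are placeholders; DeepSentiment and _CONFIG come from the surrounding tumblr-emotions project):

posts_logits, posts_labels = correlation_matrix(
    nb_batches=10, checkpoint_dir="/tmp/emotion_ckpts")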
Example #8
Source File: hooks.py    From NJUNMT-tf with Apache License 2.0
def after_create_session(self, session, coord):
        checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
        if checkpoint_path:
            # reloading model
            self._saver.restore(session, checkpoint_path)
            gs = session.run(self._global_step)
            tf.logging.info(
                "CheckpointSaverHook (after_create_session): reloading models and reset global_step={}".format(gs))
            StepTimer.reset_init_triggered_step(gs)
        elif self._reload_var_ops:
            tf.logging.info("Assign all variables with pretrained variables.")
            session.run(self._reload_var_ops) 
Example #9
Source File: estimator_v2.py    From boxnet with GNU General Public License v3.0
def latest_checkpoint(self):
    """Finds the filename of latest saved checkpoint file in `model_dir`.
    Returns:
      The full path to the latest checkpoint or `None` if no checkpoint was
      found.
    """
    return saver.latest_checkpoint(self.model_dir) 
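A hedged usage line for an accessor like this (constructing the estimator is elided; `est` is assumed to be an instance of the class defining this method):

# `est` is an assumed, already-constructed estimator instance.
path = est.latest_checkpoint()
if path is None:
    print("No checkpoint saved yet.")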
Example #10
Source File: estimator_v2.py    From boxnet with GNU General Public License v3.0
def _check_checkpoint_available(model_dir):
  latest_path = saver.latest_checkpoint(model_dir)
  if not latest_path:
    raise ValueError(
        'Could not find trained model in model_dir: {}.'.format(model_dir)) 
Example #11
Source File: estimator_v2.py    From boxnet with GNU General Public License v3.0
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0 
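A self-contained sketch of the same pattern through the public tf.train aliases (assumes TensorFlow 1.x; the directory is a placeholder): write a checkpoint containing a global step, then locate and read it back the way the helper above does.

import os
import tensorflow as tf  # assumes TensorFlow 1.x

ckpt_dir = "/tmp/gs_demo"  # placeholder directory
if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

with tf.Graph().as_default():
    gs = tf.train.get_or_create_global_step()
    inc = tf.assign_add(gs, 1)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(inc)
        saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"), global_step=gs)

# Mirror the helper above: find the newest checkpoint, then read the step.
reader = tf.train.NewCheckpointReader(tf.train.latest_checkpoint(ckpt_dir))
print(reader.get_tensor("global_step"))  # -> 1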
Example #12
Source File: checkpoint_utils.py    From human-rl with MIT License
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern 
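The helper accepts either a directory or a concrete checkpoint prefix; a hedged illustration (both paths are placeholders):

_get_checkpoint_filename("/tmp/train_dir")                 # newest prefix in the dir
_get_checkpoint_filename("/tmp/train_dir/model.ckpt-100")  # returned unchanged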
Example #13
Source File: checkpoint_utils.py    From human-rl with MIT License
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern 
Example #14
Source File: tf_utils.py    From sgnmt with Apache License 2.0
def create_session(checkpoint_path, n_cpu_threads=-1):
    """Creates a MonitoredSession.
    
    Args:
      checkpoint_path (string): Path either to checkpoint directory or
                                directly to a checkpoint file.
      n_cpu_threads (int): Number of CPU threads. If negative, we
                           assume either GPU decoding or that all
                           CPU cores can be used.
    Returns:
      A TensorFlow MonitoredSession.
    """
    try:
        if os.path.isdir(checkpoint_path):
            checkpoint_path = saver.latest_checkpoint(checkpoint_path)
        else:
            logging.info("%s is not a directory. Interpreting as direct "
                         "path to checkpoint..." % checkpoint_path)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=session_config(n_cpu_threads)))
    except tf.errors.NotFoundError:
        logging.fatal("Could not find all variables of the computation "
            "graph in the T2T checkpoint file. This means that the "
            "checkpoint does not correspond to the model specified in "
            "SGNMT. Please double-check pred_src_vocab_size, "
            "pred_trg_vocab_size, and all the t2t_* parameters. "
            "Also make sure that the checkpoint exists and is readable.")
        raise AttributeError("Could not initialize TF session.") 
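The restore pattern this helper wraps, as a minimal sketch with the public tf.train aliases (the directory and variable are placeholder assumptions; the checkpoint must have been written by a graph with matching variables):

import tensorflow as tf  # assumes TensorFlow 1.x

ckpt_dir = "/tmp/my_model_dir"  # placeholder: must already contain checkpoints
with tf.Graph().as_default():
    # Declare the same variables the checkpoint was saved with.
    v = tf.get_variable("v", shape=[])
    ckpt_path = tf.train.latest_checkpoint(ckpt_dir)
    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                checkpoint_filename_with_path=ckpt_path)) as sess:
        print(sess.run(v))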
Example #15
Source File: tf_nizza.py    From sgnmt with Apache License 2.0
def create_session(self, checkpoint_dir):
        """Creates a MonitoredSession for this predictor."""
        checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=self._session_config())) 
Example #16
Source File: estimator.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
  """Save first checkpoint for the keras Estimator.

  Args:
    keras_model: an instance of compiled keras model.
    estimator: keras estimator.
    custom_objects: Dictionary for custom objects.
    keras_weights: A flat list of Numpy arrays for weights of given keras_model.

  Returns:
    Nothing. The initial checkpoint is written to `estimator.model_dir` as a
    side effect.
  """
  with ops.Graph().as_default() as g, g.device(estimator._device_fn):
    random_seed.set_random_seed(estimator.config.tf_random_seed)
    training_util.create_global_step()
    model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                   custom_objects)

    if isinstance(model, models.Sequential):
      model = model.model
    # Load weights and save to checkpoint if there is no checkpoint
    latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
    if not latest_path:
      with session.Session() as sess:
        model.set_weights(keras_weights)
        # Make update ops and initialize all variables.
        if not model.train_function:
          # pylint: disable=protected-access
          model._make_train_function()
          K._initialize_variables(sess)
          # pylint: enable=protected-access
        saver = saver_lib.Saver()
        saver.save(sess, estimator.model_dir + '/') 
Example #17
Source File: estimator.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def latest_checkpoint(self):
    """Finds the filename of latest saved checkpoint file in `model_dir`.

    Returns:
      The full path to the latest checkpoint or `None` if no checkpoint was
      found.
    """
    return saver.latest_checkpoint(self.model_dir) 
Example #18
Source File: estimator.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _check_checkpoint_available(model_dir):
  latest_path = saver.latest_checkpoint(model_dir)
  if not latest_path:
    raise ValueError(
        'Could not find trained model in model_dir: {}.'.format(model_dir)) 
Example #19
Source File: estimator.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0 
Example #20
Source File: plugin.py    From keras-lambda with MIT License
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory."""
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_path_pairs:
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)

      has_tensor_files = False
      for embedding in config.embeddings:
        if embedding.tensor_path:
          has_tensor_files = True
          break

      if not config.model_checkpoint_path:
        # See if you can find a checkpoint file in the logdir.
        ckpt_path = latest_checkpoint(logdir)
        if not ckpt_path:
          # Or in the parent of logdir.
          ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
          if not ckpt_path and not has_tensor_files:
            continue
        if ckpt_path:
          config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if (config.model_checkpoint_path and
          not checkpoint_exists(config.model_checkpoint_path)):
        logging.warning('Checkpoint file %s not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths 
Example #21
Source File: estimator.py    From keras-lambda with MIT License
def _infer_model(self,
                   input_fn,
                   feed_fn=None,
                   outputs=None,
                   as_iterable=True,
                   iterate_batches=False):
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      infer_ops = self._call_legacy_get_predict_ops(features)
      predictions = self._filter_predictions(infer_ops.predictions, outputs)
      mon_sess = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path))
      if not as_iterable:
        with mon_sess:
          if not mon_sess.should_stop():
            return mon_sess.run(predictions, feed_fn() if feed_fn else None)
      else:
        return self._predict_generator(mon_sess, predictions, feed_fn,
                                       iterate_batches) 
Example #22
Source File: checkpoint_utils.py    From keras-lambda with MIT License
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern 
Example #23
Source File: estimator.py    From auto-alt-text-lambda-api with MIT License
def _infer_model(self,
                   input_fn,
                   feed_fn=None,
                   outputs=None,
                   as_iterable=True,
                   iterate_batches=False):
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      infer_ops = self._call_legacy_get_predict_ops(features)
      predictions = self._filter_predictions(infer_ops.predictions, outputs)
      mon_sess = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path))
      if not as_iterable:
        with mon_sess:
          if not mon_sess.should_stop():
            return mon_sess.run(predictions, feed_fn() if feed_fn else None)
      else:
        return self._predict_generator(mon_sess, predictions, feed_fn,
                                       iterate_batches) 
Example #24
Source File: estimator.py    From lambda-packs with MIT License
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
  try:
    checkpoint_reader = training.NewCheckpointReader(
        training.latest_checkpoint(checkpoint_dir))
    return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
  except:  # pylint: disable=bare-except
    return 0 
Example #25
Source File: checkpoint_utils.py    From lambda-packs with MIT License
def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return saver.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file 
Example #26
Source File: projector_plugin.py    From lambda-packs with MIT License
def _find_latest_checkpoint(dir_path):
  try:
    ckpt_path = latest_checkpoint(dir_path)
    if not ckpt_path:
      # Check the parent directory.
      ckpt_path = latest_checkpoint(os.path.join(dir_path, os.pardir))
    return ckpt_path
  except errors.NotFoundError:
    return None 
Example #27
Source File: estimator.py    From lambda-packs with MIT License
def _infer_model(self,
                   input_fn,
                   feed_fn=None,
                   outputs=None,
                   as_iterable=True,
                   iterate_batches=False):
    # Check that model has been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      infer_ops = self._get_predict_ops(features)
      predictions = self._filter_predictions(infer_ops.predictions, outputs)
      mon_sess = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path,
              scaffold=infer_ops.scaffold,
              config=self._session_config))
      if not as_iterable:
        with mon_sess:
          if not mon_sess.should_stop():
            return mon_sess.run(predictions, feed_fn() if feed_fn else None)
      else:
        return self._predict_generator(mon_sess, predictions, feed_fn,
                                       iterate_batches) 
Example #28
Source File: evaluation.py    From lambda-packs with MIT License
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Waits until a new checkpoint file is found.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used or `None` if we're expecting
      a checkpoint for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new checkpoint.
    timeout: The maximum amount of time to wait. If left as `None`, then the
      process will wait indefinitely.

  Returns:
    A new checkpoint path, or `None` if the timeout was reached.
  """
  logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
  stop_time = time.time() + timeout if timeout is not None else None
  while True:
    checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None or checkpoint_path == last_checkpoint:
      if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
        return None
      time.sleep(seconds_to_sleep)
    else:
      logging.info('Found new checkpoint at %s', checkpoint_path)
      return checkpoint_path 
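A hedged sketch of how a polling helper like this is typically driven in a continuous-evaluation loop (evaluate_checkpoint and the training directory are placeholders):

last = None
while True:
    last = wait_for_new_checkpoint("/tmp/train_dir",
                                   last_checkpoint=last,
                                   timeout=600)
    if last is None:
        break  # timed out: assume training has finished
    evaluate_checkpoint(last)  # hypothetical user-supplied evaluation function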
Example #29
Source File: checkpoint_utils.py    From lambda-packs with MIT License
def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern 
Example #30
Source File: plugin.py    From auto-alt-text-lambda-api with MIT License
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory."""
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_path_pairs:
      config = ProjectorConfig()
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if file_io.file_exists(config_fpath):
        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
        text_format.Merge(file_content, config)

      has_tensor_files = False
      for embedding in config.embeddings:
        if embedding.tensor_path:
          has_tensor_files = True
          break

      if not config.model_checkpoint_path:
        # See if you can find a checkpoint file in the logdir.
        ckpt_path = latest_checkpoint(logdir)
        if not ckpt_path:
          # Or in the parent of logdir.
          ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
          if not ckpt_path and not has_tensor_files:
            continue
        if ckpt_path:
          config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if (config.model_checkpoint_path and
          not checkpoint_exists(config.model_checkpoint_path)):
        logging.warning('Checkpoint file %s not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths