Python tensorflow.python.training.saver.latest_checkpoint() Examples
The following are 30 code examples of tensorflow.python.training.saver.latest_checkpoint().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.training.saver, or try the search function.
Example #1
Source File: monitors.py From keras-lambda with MIT License | 6 votes |
def end(self, session=None):
    """Exports the estimator's latest checkpoint when monitoring ends.

    The export is skipped (with an info log) when ``model_dir`` holds no
    checkpoint yet. A ``RuntimeError`` raised by ``export`` is logged and
    swallowed.

    Args:
      session: Forwarded unchanged to the base class ``end``.
    """
    super(ExportMonitor, self).end(session=session)

    # Nothing saved yet means there is nothing to export.
    if saver_lib.latest_checkpoint(self._estimator.model_dir) is None:
        logging.info("Skipping export at the end since model has not been saved yet.")
        return

    try:
        export_kwargs = dict(
            exports_to_keep=self.exports_to_keep,
            signature_fn=self.signature_fn,
            input_fn=self._input_fn,
            default_batch_size=self._default_batch_size,
            input_feature_key=self._input_feature_key,
            use_deprecated_input_fn=self._use_deprecated_input_fn)
        self._last_export_dir = self._estimator.export(self.export_dir,
                                                       **export_kwargs)
    except RuntimeError:
        logging.info("Skipping exporting for the same step.")
Example #2
Source File: plugin.py From keras-lambda with MIT License | 6 votes |
def _latest_checkpoints_changed(configs, run_path_pairs):
    """Returns true if the latest checkpoint has changed in any of the runs.

    For each ``(run_name, logdir)`` pair, the projector config comes from
    ``configs`` when the run is present there. Otherwise it is parsed from
    ``PROJECTOR_FILENAME`` inside ``logdir``, or left as an empty
    ``ProjectorConfig`` when that file does not exist.

    The newest checkpoint is searched in ``logdir`` and then in its parent
    directory. Runs without a checkpoint in either place are skipped.
    """
    for run_name, logdir in run_path_pairs:
        if run_name in configs:
            config = configs[run_name]
        else:
            config = ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                raw = file_io.read_file_to_string(config_fpath)
                text_format.Merge(raw.decode('utf-8'), config)

        # Prefer a checkpoint in logdir itself, then fall back to its parent.
        ckpt_path = (latest_checkpoint(logdir) or
                     latest_checkpoint(os.path.join(logdir, os.pardir)))
        if not ckpt_path:
            continue
        if config.model_checkpoint_path != ckpt_path:
            return True
    return False
Example #3
Source File: monitors.py From deep_image_model with Apache License 2.0 | 6 votes |
def end(self, session=None):
    """Exports the estimator's latest checkpoint when monitoring ends.

    The export is skipped (with an info log) when ``model_dir`` holds no
    checkpoint yet. A ``RuntimeError`` raised by ``export`` is logged and
    swallowed.

    Args:
      session: Forwarded unchanged to the base class ``end``.
    """
    super(ExportMonitor, self).end(session=session)

    # Nothing saved yet means there is nothing to export.
    if saver_lib.latest_checkpoint(self._estimator.model_dir) is None:
        logging.info("Skipping export at the end since model has not been saved yet.")
        return

    try:
        export_kwargs = dict(
            exports_to_keep=self.exports_to_keep,
            signature_fn=self.signature_fn,
            input_fn=self._input_fn,
            default_batch_size=self._default_batch_size,
            input_feature_key=self._input_feature_key,
            use_deprecated_input_fn=self._use_deprecated_input_fn)
        self._last_export_dir = self._estimator.export(self.export_dir,
                                                       **export_kwargs)
    except RuntimeError:
        logging.info("Skipping exporting for the same step.")
Example #4
Source File: plugin.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _latest_checkpoints_changed(configs, run_path_pairs):
    """Returns true if the latest checkpoint has changed in any of the runs.

    For each ``(run_name, logdir)`` pair, the projector config comes from
    ``configs`` when the run is present there. Otherwise it is parsed from
    ``PROJECTOR_FILENAME`` inside ``logdir``, or left as an empty
    ``ProjectorConfig`` when that file does not exist.

    The newest checkpoint is searched in ``logdir`` and then in its parent
    directory. Runs without a checkpoint in either place are skipped.
    """
    for run_name, logdir in run_path_pairs:
        if run_name in configs:
            config = configs[run_name]
        else:
            config = ProjectorConfig()
            config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
            if file_io.file_exists(config_fpath):
                raw = file_io.read_file_to_string(config_fpath)
                text_format.Merge(raw.decode('utf-8'), config)

        # Prefer a checkpoint in logdir itself, then fall back to its parent.
        ckpt_path = (latest_checkpoint(logdir) or
                     latest_checkpoint(os.path.join(logdir, os.pardir)))
        if not ckpt_path:
            continue
        if config.model_checkpoint_path != ckpt_path:
            return True
    return False
Example #5
Source File: monitors.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def end(self, session=None):
    """Exports the estimator's latest checkpoint when monitoring ends.

    The export is skipped (with an info log) when ``model_dir`` holds no
    checkpoint yet. A ``RuntimeError`` raised by ``export`` is logged and
    swallowed.

    Args:
      session: Forwarded unchanged to the base class ``end``.
    """
    super(ExportMonitor, self).end(session=session)

    # Nothing saved yet means there is nothing to export.
    if saver_lib.latest_checkpoint(self._estimator.model_dir) is None:
        logging.info("Skipping export at the end since model has not been saved yet.")
        return

    try:
        export_kwargs = dict(
            exports_to_keep=self.exports_to_keep,
            signature_fn=self.signature_fn,
            input_fn=self._input_fn,
            default_batch_size=self._default_batch_size,
            input_feature_key=self._input_feature_key,
            use_deprecated_input_fn=self._use_deprecated_input_fn)
        self._last_export_dir = self._estimator.export(self.export_dir,
                                                       **export_kwargs)
    except RuntimeError:
        logging.info("Skipping exporting for the same step.")
Example #6
Source File: plugin.py From deep_image_model with Apache License 2.0 | 6 votes |
def _latest_checkpoints_changed(configs, run_path_pairs):
    """Returns true if the latest checkpoint has changed in any of the runs.

    Only runs that have an entry in ``configs`` with a non-empty
    ``model_checkpoint_path`` are considered. For each such run, the newest
    checkpoint is looked up in ``logdir`` and, failing that, in the parent
    directory of ``logdir``. Runs with no checkpoint in either place are
    skipped.

    Args:
      configs: Mapping from run name to its projector config.
      run_path_pairs: Iterable of ``(run_name, logdir)`` pairs.

    Returns:
      True as soon as one run's newest checkpoint differs from the
      ``model_checkpoint_path`` recorded in its config, otherwise False.
    """
    for run_name, logdir in run_path_pairs:
        if run_name not in configs:
            continue
        config = configs[run_name]
        if not config.model_checkpoint_path:
            continue

        # See if you can find a checkpoint file in the logdir.
        ckpt_path = latest_checkpoint(logdir)
        if not ckpt_path:
            # Fall back to the parent of logdir. Note the argument order:
            # os.path.join('../', logdir) would give '../<logdir>' (or just
            # logdir when it is absolute), never logdir's parent.
            ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
        if not ckpt_path:
            continue
        if config.model_checkpoint_path != ckpt_path:
            return True
    return False
Example #7
Source File: im_text_rnn_model.py From tumblr-emotions with Apache License 2.0 | 5 votes |
def correlation_matrix(nb_batches, checkpoint_dir):
    """Computes logits and labels of the input posts and saves them as numpy files.

    Parameters:
        nb_batches: Number of batches to run through the model.
        checkpoint_dir: Directory of the checkpoints saved during training;
            the latest checkpoint in it is restored.

    Returns:
        Tuple ``(posts_logits, posts_labels)``. The per-batch logits are
        stacked with ``np.vstack`` and the per-batch labels are joined with
        ``np.hstack``. The same arrays are also written to
        ``data/posts_logits.npy`` and ``data/posts_labels.npy``.

    Raises:
        ValueError: If ``checkpoint_dir`` contains no checkpoint.
    """
    with tf.Graph().as_default():
        config = _CONFIG.copy()
        config['mode'] = 'validation'
        model = DeepSentiment(config)

        # Load model
        checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
        if not checkpoint_path:
            # Without a checkpoint, ChiefSessionCreator would receive None and
            # restore nothing, so the results would not come from the
            # trained model.
            raise ValueError(
                'No checkpoint found in {}.'.format(checkpoint_dir))
        scaffold = monitored_session.Scaffold(
            init_op=None, init_feed_dict=None, init_fn=None, saver=None)
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_filename_with_path=checkpoint_path,
            master='',
            config=None)

        posts_logits = []
        posts_labels = []
        with monitored_session.MonitoredSession(  # Generate queue
                session_creator=session_creator, hooks=None) as session:
            for _ in range(nb_batches):
                np_logits, np_labels = session.run([model.logits, model.labels])
                posts_logits.append(np_logits)
                posts_labels.append(np_labels)

    posts_logits, posts_labels = np.vstack(posts_logits), np.hstack(posts_labels)
    np.save('data/posts_logits.npy', posts_logits)
    np.save('data/posts_labels.npy', posts_labels)
    return posts_logits, posts_labels
Example #8
Source File: hooks.py From NJUNMT-tf with Apache License 2.0 | 5 votes |
def after_create_session(self, session, coord):
    """Restores the newest checkpoint, or else applies pretrained variables.

    If ``self._checkpoint_dir`` holds a checkpoint, it is restored with
    ``self._saver``, and ``StepTimer``'s initial triggered step is reset to
    the restored global step. Otherwise, if ``self._reload_var_ops`` is
    truthy, those ops are run.

    Args:
        session: The newly created session.
        coord: Not used by this implementation.
    """
    checkpoint_path = saver_lib.latest_checkpoint(self._checkpoint_dir)
    if not checkpoint_path:
        if self._reload_var_ops:
            tf.logging.info("Assign all variables with pretrained variables.")
            session.run(self._reload_var_ops)
        return

    # Reload the model, then re-sync the step timer with the restored step.
    self._saver.restore(session, checkpoint_path)
    restored_step = session.run(self._global_step)
    tf.logging.info(
        "CheckpointSaverHook (after_create_session): reloading models and reset global_step={}".format(restored_step))
    StepTimer.reset_init_triggered_step(restored_step)
Example #9
Source File: estimator_v2.py From boxnet with GNU General Public License v3.0 | 5 votes |
def latest_checkpoint(self):
    """Returns the path of the newest checkpoint saved under ``model_dir``.

    Returns:
      The full path to the latest checkpoint, or ``None`` if ``model_dir``
      holds no checkpoint.
    """
    model_dir = self.model_dir
    return saver.latest_checkpoint(model_dir)
Example #10
Source File: estimator_v2.py From boxnet with GNU General Public License v3.0 | 5 votes |
def _check_checkpoint_available(model_dir):
    """Raises ``ValueError`` unless ``model_dir`` contains a checkpoint."""
    if saver.latest_checkpoint(model_dir):
        return
    raise ValueError(
        'Could not find trained model in model_dir: {}.'.format(model_dir))
Example #11
Source File: estimator_v2.py From boxnet with GNU General Public License v3.0 | 5 votes |
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    """Returns the global step stored in the newest checkpoint of a directory.

    Args:
      checkpoint_dir: Directory to search for checkpoints.

    Returns:
      The ``GLOBAL_STEP`` tensor value from the latest checkpoint, or 0 when
      there is no checkpoint or it cannot be read.
    """
    try:
        checkpoint_path = training.latest_checkpoint(checkpoint_dir)
        if not checkpoint_path:
            # Explicit check instead of relying on NewCheckpointReader(None)
            # raising.
            return 0
        checkpoint_reader = training.NewCheckpointReader(checkpoint_path)
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except Exception:  # pylint: disable=broad-except
        # Best effort: an unreadable checkpoint, or one without a global
        # step, counts as step 0. Catching Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        return 0
Example #12
Source File: checkpoint_utils.py From human-rl with MIT License | 5 votes |
def _get_checkpoint_filename(filepattern):
    """Resolves a directory to its newest checkpoint; other inputs pass through.

    Args:
      filepattern: A checkpoint directory or a specific checkpoint filepattern.

    Returns:
      ``saver.latest_checkpoint(filepattern)`` when ``filepattern`` is a
      directory, otherwise ``filepattern`` unchanged.
    """
    is_directory = gfile.IsDirectory(filepattern)
    return saver.latest_checkpoint(filepattern) if is_directory else filepattern
Example #13
Source File: checkpoint_utils.py From human-rl with MIT License | 5 votes |
def _get_checkpoint_filename(filepattern):
    """Resolves a directory to its newest checkpoint; other inputs pass through.

    Args:
      filepattern: A checkpoint directory or a specific checkpoint filepattern.

    Returns:
      ``saver.latest_checkpoint(filepattern)`` when ``filepattern`` is a
      directory, otherwise ``filepattern`` unchanged.
    """
    is_directory = gfile.IsDirectory(filepattern)
    return saver.latest_checkpoint(filepattern) if is_directory else filepattern
Example #14
Source File: tf_utils.py From sgnmt with Apache License 2.0 | 5 votes |
def create_session(checkpoint_path, n_cpu_threads=-1):
    """Creates a MonitoredSession.

    Args:
        checkpoint_path (string): Path either to checkpoint directory or
            directly to a checkpoint file. For a directory, the latest
            checkpoint inside it is used.
        n_cpu_threads (int): Number of CPU threads. If negative, we assume
            either GPU decoding or that all CPU cores can be used.

    Returns:
        A TensorFlow MonitoredSession.

    Raises:
        AttributeError: If session creation raises ``tf.errors.NotFoundError``.
            The TF error is chained as ``__cause__``.
    """
    try:
        if os.path.isdir(checkpoint_path):
            checkpoint_path = saver.latest_checkpoint(checkpoint_path)
        else:
            # Lazy %-args: the message is formatted only if it is emitted.
            logging.info("%s is not a directory. Interpreting as direct "
                         "path to checkpoint...", checkpoint_path)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=session_config(n_cpu_threads)))
    except tf.errors.NotFoundError as e:
        logging.fatal("Could not find all variables of the computation "
                      "graph in the T2T checkpoint file. This means that the "
                      "checkpoint does not correspond to the model specified in "
                      "SGNMT. Please double-check pred_src_vocab_size, "
                      "pred_trg_vocab_size, and all the t2t_* parameters. "
                      "Also make sure that the checkpoint exists and is readable")
        # Chain the TF error so its original message is not lost.
        raise AttributeError("Could not initialize TF session.") from e
Example #15
Source File: tf_nizza.py From sgnmt with Apache License 2.0 | 5 votes |
def create_session(self, checkpoint_dir):
    """Creates a MonitoredSession for this predictor.

    The newest checkpoint in ``checkpoint_dir`` is passed to a
    ``ChiefSessionCreator`` together with this predictor's session config.
    """
    creator = training.ChiefSessionCreator(
        checkpoint_filename_with_path=saver.latest_checkpoint(checkpoint_dir),
        config=self._session_config())
    return training.MonitoredSession(session_creator=creator)
Example #16
Source File: estimator.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
    """Save first checkpoint for the keras Estimator.

    Clones ``keras_model`` for training mode in a fresh graph. If
    ``estimator.model_dir`` does not yet contain a checkpoint, it then does
    three things:

    - loads ``keras_weights`` into the clone;
    - initializes the remaining variables;
    - saves a checkpoint under ``model_dir``.

    An existing checkpoint is left untouched.

    Args:
      keras_model: an instance of compiled keras model.
      estimator: keras estimator.
      custom_objects: Dictionary for custom objects.
      keras_weights: A flat list of Numpy arrays for weights of given
        keras_model.

    Returns:
      None. The function is called for its side effect of writing the
      checkpoint.
    """
    with ops.Graph().as_default() as g, g.device(estimator._device_fn):
        random_seed.set_random_seed(estimator.config.tf_random_seed)
        training_util.create_global_step()
        model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                       custom_objects)

        if isinstance(model, models.Sequential):
            model = model.model
        # Load weights and save to checkpoint if there is no checkpoint
        latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
        if not latest_path:
            with session.Session() as sess:
                model.set_weights(keras_weights)
                # Make update ops and initialize all variables.
                if not model.train_function:
                    # pylint: disable=protected-access
                    model._make_train_function()
                    K._initialize_variables(sess)
                    # pylint: enable=protected-access
                # Named so it does not shadow a module-level `saver` alias.
                checkpoint_saver = saver_lib.Saver()
                checkpoint_saver.save(sess, estimator.model_dir + '/')
Example #17
Source File: estimator.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def latest_checkpoint(self):
    """Returns the path of the newest checkpoint saved under ``model_dir``.

    Returns:
      The full path to the latest checkpoint, or ``None`` if ``model_dir``
      holds no checkpoint.
    """
    model_dir = self.model_dir
    return saver.latest_checkpoint(model_dir)
Example #18
Source File: estimator.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def _check_checkpoint_available(model_dir):
    """Raises ``ValueError`` unless ``model_dir`` contains a checkpoint."""
    if saver.latest_checkpoint(model_dir):
        return
    raise ValueError(
        'Could not find trained model in model_dir: {}.'.format(model_dir))
Example #19
Source File: estimator.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    """Returns the global step stored in the newest checkpoint of a directory.

    Args:
      checkpoint_dir: Directory to search for checkpoints.

    Returns:
      The ``GLOBAL_STEP`` tensor value from the latest checkpoint, or 0 when
      there is no checkpoint or it cannot be read.
    """
    try:
        checkpoint_path = training.latest_checkpoint(checkpoint_dir)
        if not checkpoint_path:
            # Explicit check instead of relying on NewCheckpointReader(None)
            # raising.
            return 0
        checkpoint_reader = training.NewCheckpointReader(checkpoint_path)
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except Exception:  # pylint: disable=broad-except
        # Best effort: an unreadable checkpoint, or one without a global
        # step, counts as step 0. Catching Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        return 0
Example #20
Source File: plugin.py From keras-lambda with MIT License | 5 votes |
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory.

    A config without ``model_checkpoint_path`` gets the newest checkpoint
    from ``logdir`` or its parent. The run is dropped if neither directory
    has a checkpoint and the config names no embedding tensor files. A run
    is also dropped, with a warning, if its checkpoint path does not exist.

    Returns:
      A ``(configs, config_fpaths)`` pair of dicts keyed by run name.
    """
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_path_pairs:
        config = ProjectorConfig()
        config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            raw = file_io.read_file_to_string(config_fpath)
            text_format.Merge(raw.decode('utf-8'), config)

        has_tensor_files = any(emb.tensor_path for emb in config.embeddings)

        if not config.model_checkpoint_path:
            # Look for a checkpoint in logdir, then in its parent.
            ckpt_path = (latest_checkpoint(logdir) or
                         latest_checkpoint(os.path.join(logdir, os.pardir)))
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path
            elif not has_tensor_files:
                continue

        # Sanity check for the checkpoint file.
        if (config.model_checkpoint_path and
                not checkpoint_exists(config.model_checkpoint_path)):
            logging.warning('Checkpoint file %s not found',
                            config.model_checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
Example #21
Source File: estimator.py From keras-lambda with MIT License | 5 votes |
def _infer_model(self, input_fn, feed_fn=None, outputs=None,
                 as_iterable=True, iterate_batches=False):
    """Runs inference using the latest checkpoint in ``self._model_dir``.

    Raises:
      NotFittedError: If no checkpoint exists in ``self._model_dir``.

    Returns:
      If ``as_iterable`` is true, the result of ``self._predict_generator``
      over a MonitoredSession. Otherwise, the predictions from a single
      ``run`` call, or ``None`` if the session should stop before running.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        # The model has not been trained (no checkpoint saved yet).
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        infer_ops = self._call_legacy_get_predict_ops(features)
        predictions = self._filter_predictions(infer_ops.predictions, outputs)
        creator = monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path)
        mon_sess = monitored_session.MonitoredSession(session_creator=creator)

        if as_iterable:
            return self._predict_generator(mon_sess, predictions, feed_fn,
                                           iterate_batches)
        with mon_sess:
            if mon_sess.should_stop():
                return None
            feed_dict = feed_fn() if feed_fn else None
            return mon_sess.run(predictions, feed_dict)
Example #22
Source File: checkpoint_utils.py From keras-lambda with MIT License | 5 votes |
def _get_checkpoint_filename(filepattern):
    """Resolves a directory to its newest checkpoint; other inputs pass through.

    Args:
      filepattern: A checkpoint directory or a specific checkpoint filepattern.

    Returns:
      ``saver.latest_checkpoint(filepattern)`` when ``filepattern`` is a
      directory, otherwise ``filepattern`` unchanged.
    """
    is_directory = gfile.IsDirectory(filepattern)
    return saver.latest_checkpoint(filepattern) if is_directory else filepattern
Example #23
Source File: estimator.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _infer_model(self, input_fn, feed_fn=None, outputs=None,
                 as_iterable=True, iterate_batches=False):
    """Runs inference using the latest checkpoint in ``self._model_dir``.

    Raises:
      NotFittedError: If no checkpoint exists in ``self._model_dir``.

    Returns:
      If ``as_iterable`` is true, the result of ``self._predict_generator``
      over a MonitoredSession. Otherwise, the predictions from a single
      ``run`` call, or ``None`` if the session should stop before running.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        # The model has not been trained (no checkpoint saved yet).
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        infer_ops = self._call_legacy_get_predict_ops(features)
        predictions = self._filter_predictions(infer_ops.predictions, outputs)
        creator = monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path)
        mon_sess = monitored_session.MonitoredSession(session_creator=creator)

        if as_iterable:
            return self._predict_generator(mon_sess, predictions, feed_fn,
                                           iterate_batches)
        with mon_sess:
            if mon_sess.should_stop():
                return None
            feed_dict = feed_fn() if feed_fn else None
            return mon_sess.run(predictions, feed_dict)
Example #24
Source File: estimator.py From lambda-packs with MIT License | 5 votes |
def _load_global_step_from_checkpoint_dir(checkpoint_dir):
    """Returns the global step stored in the newest checkpoint of a directory.

    Args:
      checkpoint_dir: Directory to search for checkpoints.

    Returns:
      The ``GLOBAL_STEP`` tensor value from the latest checkpoint, or 0 when
      there is no checkpoint or it cannot be read.
    """
    try:
        checkpoint_path = training.latest_checkpoint(checkpoint_dir)
        if not checkpoint_path:
            # Explicit check instead of relying on NewCheckpointReader(None)
            # raising.
            return 0
        checkpoint_reader = training.NewCheckpointReader(checkpoint_path)
        return checkpoint_reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)
    except Exception:  # pylint: disable=broad-except
        # Best effort: an unreadable checkpoint, or one without a global
        # step, counts as step 0. Catching Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit propagate.
        return 0
Example #25
Source File: checkpoint_utils.py From lambda-packs with MIT License | 5 votes |
def _get_checkpoint_filename(ckpt_dir_or_file):
    """Returns checkpoint filename given directory or specific checkpoint file.

    A directory is resolved with ``saver.latest_checkpoint``; any other value
    is returned unchanged.
    """
    if not gfile.IsDirectory(ckpt_dir_or_file):
        return ckpt_dir_or_file
    return saver.latest_checkpoint(ckpt_dir_or_file)
Example #26
Source File: projector_plugin.py From lambda-packs with MIT License | 5 votes |
def _find_latest_checkpoint(dir_path):
    """Returns the newest checkpoint in ``dir_path`` or else its parent.

    Returns the result of the parent lookup when ``dir_path`` yields no
    checkpoint. Returns ``None`` if either lookup raises
    ``errors.NotFoundError``.
    """
    try:
        return (latest_checkpoint(dir_path) or
                latest_checkpoint(os.path.join(dir_path, os.pardir)))
    except errors.NotFoundError:
        return None
Example #27
Source File: estimator.py From lambda-packs with MIT License | 5 votes |
def _infer_model(self, input_fn, feed_fn=None, outputs=None,
                 as_iterable=True, iterate_batches=False):
    """Runs inference using the latest checkpoint in ``self._model_dir``.

    The session is created with the predict ops' scaffold and
    ``self._session_config``.

    Raises:
      NotFittedError: If no checkpoint exists in ``self._model_dir``.

    Returns:
      If ``as_iterable`` is true, the result of ``self._predict_generator``
      over a MonitoredSession. Otherwise, the predictions from a single
      ``run`` call, or ``None`` if the session should stop before running.
    """
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
        # The model has not been trained (no checkpoint saved yet).
        raise NotFittedError("Couldn't find trained model at %s."
                             % self._model_dir)

    with ops.Graph().as_default() as g:
        random_seed.set_random_seed(self._config.tf_random_seed)
        contrib_framework.create_global_step(g)
        features = self._get_features_from_input_fn(input_fn)
        infer_ops = self._get_predict_ops(features)
        predictions = self._filter_predictions(infer_ops.predictions, outputs)
        creator = monitored_session.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path,
            scaffold=infer_ops.scaffold,
            config=self._session_config)
        mon_sess = monitored_session.MonitoredSession(session_creator=creator)

        if as_iterable:
            return self._predict_generator(mon_sess, predictions, feed_fn,
                                           iterate_batches)
        with mon_sess:
            if mon_sess.should_stop():
                return None
            feed_dict = feed_fn() if feed_fn else None
            return mon_sess.run(predictions, feed_dict)
Example #28
Source File: evaluation.py From lambda-packs with MIT License | 5 votes |
def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
    """Waits until a new checkpoint file is found.

    Polls ``checkpoint_dir`` until its latest checkpoint is not ``None`` and
    differs from ``last_checkpoint``, sleeping ``seconds_to_sleep`` seconds
    between polls.

    Args:
      checkpoint_dir: The directory in which checkpoints are saved.
      last_checkpoint: The last checkpoint path used or `None` if we're
        expecting a checkpoint for the first time.
      seconds_to_sleep: The number of seconds to sleep for before looking for
        a new checkpoint.
      timeout: The maximum amount of time to wait. If left as `None`, then the
        process will wait indefinitely.

    Returns:
      a new checkpoint path, or None if the timeout was reached.
    """
    logging.info('Waiting for new checkpoint at %s', checkpoint_dir)
    stop_time = None if timeout is None else time.time() + timeout
    while True:
        checkpoint_path = tf_saver.latest_checkpoint(checkpoint_dir)
        if checkpoint_path is not None and checkpoint_path != last_checkpoint:
            logging.info('Found new checkpoint at %s', checkpoint_path)
            return checkpoint_path
        # Give up rather than sleep past the deadline.
        if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
            return None
        time.sleep(seconds_to_sleep)
Example #29
Source File: checkpoint_utils.py From lambda-packs with MIT License | 5 votes |
def _get_checkpoint_filename(filepattern):
    """Resolves a directory to its newest checkpoint; other inputs pass through.

    Args:
      filepattern: A checkpoint directory or a specific checkpoint filepattern.

    Returns:
      ``saver.latest_checkpoint(filepattern)`` when ``filepattern`` is a
      directory, otherwise ``filepattern`` unchanged.
    """
    is_directory = gfile.IsDirectory(filepattern)
    return saver.latest_checkpoint(filepattern) if is_directory else filepattern
Example #30
Source File: plugin.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _read_latest_config_files(self, run_path_pairs):
    """Reads and returns the projector config files in every run directory.

    A config without ``model_checkpoint_path`` gets the newest checkpoint
    from ``logdir`` or its parent. The run is dropped if neither directory
    has a checkpoint and the config names no embedding tensor files. A run
    is also dropped, with a warning, if its checkpoint path does not exist.

    Returns:
      A ``(configs, config_fpaths)`` pair of dicts keyed by run name.
    """
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_path_pairs:
        config = ProjectorConfig()
        config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
        if file_io.file_exists(config_fpath):
            raw = file_io.read_file_to_string(config_fpath)
            text_format.Merge(raw.decode('utf-8'), config)

        has_tensor_files = any(emb.tensor_path for emb in config.embeddings)

        if not config.model_checkpoint_path:
            # Look for a checkpoint in logdir, then in its parent.
            ckpt_path = (latest_checkpoint(logdir) or
                         latest_checkpoint(os.path.join(logdir, os.pardir)))
            if ckpt_path:
                config.model_checkpoint_path = ckpt_path
            elif not has_tensor_files:
                continue

        # Sanity check for the checkpoint file.
        if (config.model_checkpoint_path and
                not checkpoint_exists(config.model_checkpoint_path)):
            logging.warning('Checkpoint file %s not found',
                            config.model_checkpoint_path)
            continue

        configs[run_name] = config
        config_fpaths[run_name] = config_fpath
    return configs, config_fpaths