Python tensorflow.python.training.checkpoint_state_pb2.CheckpointState() Examples
The following are 27 code examples of tensorflow.python.training.checkpoint_state_pb2.CheckpointState().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.training.checkpoint_state_pb2, or try the search function.
Example #1
Source File: utils.py From deepchem with MIT License | 6 votes |
def ParseCheckpoint(checkpoint):
    """Resolve `checkpoint` to the path of an actual checkpoint file.

    Args:
      checkpoint: Path to checkpoint. The file is either a text-format
        serialized CheckpointState proto or an actual checkpoint file.

    Returns:
      The `model_checkpoint_path` stored in the proto when the file parses as
      a CheckpointState; otherwise `checkpoint` itself, unchanged.
    """
    warnings.warn(
        "ParseCheckpoint is deprecated. "
        "Will be removed in DeepChem 1.4.", DeprecationWarning)
    with open(checkpoint) as f:
        try:
            cp = checkpoint_state_pb2.CheckpointState()
            text_format.Merge(f.read(), cp)
            return cp.model_checkpoint_path
        except (text_format.ParseError, UnicodeDecodeError):
            # An actual checkpoint file may be binary: reading it in text mode
            # can fail to decode before the proto parser ever sees the data.
            # Either failure means "not a CheckpointState", so fall back to
            # the path that was passed in.
            return checkpoint
Example #2
Source File: saver.py From keras-lambda with MIT License | 5 votes |
def _GetCheckpointFilename(save_dir, latest_filename):
    """Builds the path of the file that stores the CheckpointState proto.

    Args:
      save_dir: Directory used for saving and restoring checkpoints.
      latest_filename: Name of the state file inside `save_dir`, or None to
        use the default name "checkpoint".

    Returns:
      `save_dir` joined with the state file name.
    """
    state_name = "checkpoint" if latest_filename is None else latest_filename
    return os.path.join(save_dir, state_name)
Example #3
Source File: saver.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def _GetCheckpointFilename(save_dir, latest_filename):
    """Returns where the CheckpointState proto for `save_dir` is stored.

    Args:
      save_dir: The directory for saving and restoring checkpoints.
      latest_filename: File name, relative to `save_dir`, that holds the
        CheckpointState. None selects the default name "checkpoint".

    Returns:
      The joined path of `save_dir` and the chosen file name.
    """
    if latest_filename is not None:
        return os.path.join(save_dir, latest_filename)
    return os.path.join(save_dir, "checkpoint")
Example #4
Source File: estimator_test.py From estimator with Apache License 2.0 | 5 votes |
def test_checkpoint_contains_relative_paths(self):
    """Checks that training records relative paths in the 'checkpoint' file."""
    import shutil  # Local import: only needed for the temp-dir cleanup below.

    tmpdir = tempfile.mkdtemp()
    # mkdtemp() leaves removal to the caller; register cleanup so the
    # checkpoints written by this test do not accumulate on disk.
    # NOTE(review): assumes the enclosing class derives from
    # unittest.TestCase (addCleanup) - confirm.
    self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

    est = estimator.EstimatorV2(
        model_dir=tmpdir, model_fn=model_fn_global_step_incrementer)
    est.train(dummy_input_fn, steps=5)

    checkpoint_file_content = file_io.read_file_to_string(
        os.path.join(tmpdir, 'checkpoint'))
    ckpt = checkpoint_state_pb2.CheckpointState()
    text_format.Merge(checkpoint_file_content, ckpt)
    self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
    # TODO(b/78461127): Please modify tests to not directly rely on names of
    # checkpoints.
    self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
                        ckpt.all_model_checkpoint_paths)
Example #5
Source File: s3_boto_data_store.py From aws-builders-fair-projects with Apache License 2.0 | 5 votes |
def _get_current_checkpoint(self):
    """Load the checkpoint metadata stored under `self.params.checkpoint_dir`.

    Returns:
      A CheckpointState populated from the metadata file, or None when the
      metadata file does not exist.

    Raises:
      Exception: any error hit while locating, reading or parsing the file is
        printed and then re-raised unchanged.
    """
    try:
        checkpoint_metadata_filepath = os.path.abspath(
            os.path.join(self.params.checkpoint_dir,
                         CHECKPOINT_METADATA_FILENAME))
        if not os.path.exists(checkpoint_metadata_filepath):
            return None

        # The context manager closes the handle even if read() fails; the
        # previous bare open(...).read() left that to garbage collection.
        with open(checkpoint_metadata_filepath, 'r') as metadata_file:
            contents = metadata_file.read()

        checkpoint = CheckpointState()
        text_format.Merge(contents, checkpoint)
        return checkpoint
    except Exception as e:
        print("Got exception while reading checkpoint metadata", e)
        # Bare raise re-raises the active exception as-is.
        raise
Example #6
Source File: saver.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 5 votes |
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Rewrites the 'checkpoint' state file for `save_dir`.

    The state file holds a CheckpointState proto. This public entry point
    forwards to `_update_checkpoint_state` with `save_relative_paths=False`.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Raises:
      RuntimeError: If any of the model checkpoint paths conflict with the
        file containing CheckpointState.
    """
    _update_checkpoint_state(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        latest_filename=latest_filename,
        save_relative_paths=False)
Example #7
Source File: saver.py From deep_image_model with Apache License 2.0 | 5 votes |
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Rewrites the 'checkpoint' state file for `save_dir`.

    The state file holds a text-format CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Raises:
      RuntimeError: If the model checkpoint path recorded in the proto is the
        same path as the state file itself.
    """
    # Location of the "checkpoint" file the coordinator reads on restore.
    state_file = _GetCheckpointFilename(save_dir, latest_filename)
    state_proto = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

    if state_proto.model_checkpoint_path == state_file:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path." %
                           model_checkpoint_path)

    # Write atomically so a concurrent reader never sees a partial file.
    serialized_state = text_format.MessageToString(state_proto)
    file_io.atomic_write_string_to_file(state_file, serialized_state)
Example #8
Source File: saver.py From deep_image_model with Apache License 2.0 | 5 votes |
def _GetCheckpointFilename(save_dir, latest_filename):
    """Builds the path of the file that stores the CheckpointState proto.

    Args:
      save_dir: Directory used for saving and restoring checkpoints.
      latest_filename: Name of the state file inside `save_dir`, or None to
        use the default name "checkpoint".

    Returns:
      `save_dir` joined with the state file name.
    """
    state_name = "checkpoint" if latest_filename is None else latest_filename
    return os.path.join(save_dir, state_name)
Example #9
Source File: saver.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Rewrites the 'checkpoint' state file for `save_dir`.

    The state file holds a text-format CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Raises:
      RuntimeError: If the model checkpoint path recorded in the proto is the
        same path as the state file itself.
    """
    # Location of the "checkpoint" file the coordinator reads on restore.
    state_file = _GetCheckpointFilename(save_dir, latest_filename)
    state_proto = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

    if state_proto.model_checkpoint_path == state_file:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path." %
                           model_checkpoint_path)

    # Write atomically so a concurrent reader never sees a partial file.
    serialized_state = text_format.MessageToString(state_proto)
    file_io.atomic_write_string_to_file(state_file, serialized_state)
Example #10
Source File: saver.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _GetCheckpointFilename(save_dir, latest_filename):
    """Returns where the CheckpointState proto for `save_dir` is stored.

    Args:
      save_dir: The directory for saving and restoring checkpoints.
      latest_filename: File name, relative to `save_dir`, that holds the
        CheckpointState. None selects the default name "checkpoint".

    Returns:
      The joined path of `save_dir` and the chosen file name.
    """
    if latest_filename is not None:
        return os.path.join(save_dir, latest_filename)
    return os.path.join(save_dir, "checkpoint")
Example #11
Source File: saver.py From keras-lambda with MIT License | 5 votes |
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Rewrites the 'checkpoint' state file for `save_dir`.

    The state file holds a text-format CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Raises:
      RuntimeError: If the model checkpoint path recorded in the proto is the
        same path as the state file itself.
    """
    # Location of the "checkpoint" file the coordinator reads on restore.
    state_file = _GetCheckpointFilename(save_dir, latest_filename)
    state_proto = generate_checkpoint_state_proto(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)

    if state_proto.model_checkpoint_path == state_file:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path." %
                           model_checkpoint_path)

    # Write atomically so a concurrent reader never sees a partial file.
    serialized_state = text_format.MessageToString(state_proto)
    file_io.atomic_write_string_to_file(state_file, serialized_state)
Example #12
Source File: saver.py From lambda-packs with MIT License | 5 votes |
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Rewrites the 'checkpoint' state file for `save_dir`.

    The state file holds a CheckpointState proto. This public entry point
    forwards to `_update_checkpoint_state` with `save_relative_paths=False`.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Raises:
      RuntimeError: If any of the model checkpoint paths conflict with the
        file containing CheckpointState.
    """
    _update_checkpoint_state(
        save_dir,
        model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths,
        latest_filename=latest_filename,
        save_relative_paths=False)
Example #13
Source File: saver.py From lambda-packs with MIT License | 5 votes |
def _GetCheckpointFilename(save_dir, latest_filename):
    """Builds the path of the file that stores the CheckpointState proto.

    Args:
      save_dir: Directory used for saving and restoring checkpoints.
      latest_filename: Name of the state file inside `save_dir`, or None to
        use the default name "checkpoint".

    Returns:
      `save_dir` joined with the state file name.
    """
    state_name = "checkpoint" if latest_filename is None else latest_filename
    return os.path.join(save_dir, state_name)
Example #14
Source File: saver.py From lingvo with Apache License 2.0 | 5 votes |
def _GetState(self):
    """Loads the CheckpointState recorded in `self._state_file`.

    Returns:
      A CheckpointState proto. It is returned unpopulated when the state file
      does not exist.
    """
    loaded_state = CheckpointState()
    if not file_io.file_exists(self._state_file):
        return loaded_state
    text_format.Merge(
        file_io.read_file_to_string(self._state_file), loaded_state)
    return loaded_state
Example #15
Source File: test_utils.py From deepchem with MIT License | 5 votes |
def testParseCheckpoint(self):
    """ParseCheckpoint resolves both a state proto and a plain file path."""
    # Case 1: the file holds a text-format CheckpointState proto, so the
    # stored model_checkpoint_path is expected back.
    state = checkpoint_state_pb2.CheckpointState()
    state.model_checkpoint_path = 'my-checkpoint'
    with tempfile.NamedTemporaryFile(mode='w+') as proto_file:
        proto_file.write(text_format.MessageToString(state))
        proto_file.file.flush()
        self.assertEqual(
            utils.ParseCheckpoint(proto_file.name), 'my-checkpoint')

    # Case 2: the file is not a proto, so the path itself is expected back.
    with tempfile.NamedTemporaryFile(mode='w+') as plain_file:
        plain_file.write('This is not a CheckpointState proto.')
        plain_file.file.flush()
        self.assertEqual(
            utils.ParseCheckpoint(plain_file.name), plain_file.name)
Example #16
Source File: saver.py From keras-lambda with MIT License | 4 votes |
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative paths stored in the proto are joined onto `checkpoint_dir`.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise (the file
      does not exist, reading raised `errors.OpError`, or parsing raised
      `text_format.ParseError`).

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path
        set.
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename).decode("utf-8")
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                # Interpolate eagerly: unlike logging calls, ValueError does
                # not apply %-formatting to extra positional arguments.
                raise ValueError("Invalid checkpoint state loaded from %s" %
                                 checkpoint_dir)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            # Iterate over a snapshot so entries can be reassigned by index.
            for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except errors.OpError as e:
        # It's ok if the file cannot be read
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
Example #17
Source File: saver.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative paths stored in the proto are joined onto `checkpoint_dir`.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise (the file
      does not exist, reading raised `errors.OpError`, or parsing raised
      `text_format.ParseError`).

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path
        set.
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename)
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                # Interpolate eagerly: unlike logging calls, ValueError does
                # not apply %-formatting to extra positional arguments.
                raise ValueError("Invalid checkpoint state loaded from %s" %
                                 checkpoint_dir)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            # Iterate over a snapshot so entries can be reassigned by index.
            for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except errors.OpError as e:
        # It's ok if the file cannot be read
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    except text_format.ParseError as e:
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
Example #18
Source File: saver.py From keras-lambda with MIT License | 4 votes |
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. The argument itself is never modified.

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.
    """
    # Work on a private copy: the append and the relative-path rewrite below
    # must not leak into the caller's list.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

    if (not all_model_checkpoint_paths or
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        all_model_checkpoint_paths = [
            p if os.path.isabs(p) else os.path.relpath(p, save_dir)
            for p in all_model_checkpoint_paths
        ]

    return CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
Example #19
Source File: saver.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def _update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None,
                             save_relative_paths=False):
    """Updates the content of the 'checkpoint' file.

    This updates the checkpoint file containing a CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. None is treated as an empty list.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file.

    Raises:
      RuntimeError: If any of the model checkpoint paths conflict with the
        file containing CheckpointState.
    """
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir,
                                                       latest_filename)
    if save_relative_paths:
        if os.path.isabs(model_checkpoint_path):
            rel_model_checkpoint_path = os.path.relpath(
                model_checkpoint_path, save_dir)
        else:
            rel_model_checkpoint_path = model_checkpoint_path
        # The parameter defaults to None; iterating None would raise
        # TypeError, so fall back to an empty list here.
        rel_all_model_checkpoint_paths = [
            os.path.relpath(p, save_dir) if os.path.isabs(p) else p
            for p in (all_model_checkpoint_paths or [])
        ]
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            rel_model_checkpoint_path,
            all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
    else:
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            model_checkpoint_path,
            all_model_checkpoint_paths=all_model_checkpoint_paths)

    if coord_checkpoint_filename == ckpt.model_checkpoint_path:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path."
                           % model_checkpoint_path)

    # Preventing potential read/write race condition by *atomically* writing
    # to a file.
    file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                        text_format.MessageToString(ckpt))
Example #20
Source File: saver.py From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License | 4 votes |
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. The argument itself is never modified.

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.
    """
    # Work on a private copy: the append and the relative-path rewrite below
    # must not leak into the caller's list.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

    if (not all_model_checkpoint_paths or
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        all_model_checkpoint_paths = [
            p if os.path.isabs(p) else os.path.relpath(p, save_dir)
            for p in all_model_checkpoint_paths
        ]

    return CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
Example #21
Source File: saver.py From deep_image_model with Apache License 2.0 | 4 votes |
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative paths stored in the proto are joined onto `checkpoint_dir`.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise (the file
      does not exist, reading raised `errors.OpError`, or parsing raised
      `text_format.ParseError`).

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path
        set.
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename).decode("utf-8")
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                # Interpolate eagerly: unlike logging calls, ValueError does
                # not apply %-formatting to extra positional arguments.
                raise ValueError("Invalid checkpoint state loaded from %s" %
                                 checkpoint_dir)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            # Iterate over a snapshot so entries can be reassigned by index.
            for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except errors.OpError as e:
        # It's ok if the file cannot be read
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
Example #22
Source File: saver.py From deep_image_model with Apache License 2.0 | 4 votes |
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. The argument itself is never modified.

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.
    """
    # Work on a private copy: the append and the relative-path rewrite below
    # must not leak into the caller's list.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

    if (not all_model_checkpoint_paths or
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        all_model_checkpoint_paths = [
            p if os.path.isabs(p) else os.path.relpath(p, save_dir)
            for p in all_model_checkpoint_paths
        ]

    return CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
Example #23
Source File: saver.py From auto-alt-text-lambda-api with MIT License | 4 votes |
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative paths stored in the proto are joined onto `checkpoint_dir`.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise (the file
      does not exist, reading raised `errors.OpError`, or parsing raised
      `text_format.ParseError`).

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path
        set.
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename).decode("utf-8")
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                # Interpolate eagerly: unlike logging calls, ValueError does
                # not apply %-formatting to extra positional arguments.
                raise ValueError("Invalid checkpoint state loaded from %s" %
                                 checkpoint_dir)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            # Iterate over a snapshot so entries can be reassigned by index.
            for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except errors.OpError as e:
        # It's ok if the file cannot be read
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
Example #24
Source File: saver.py From auto-alt-text-lambda-api with MIT License | 4 votes |
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. The argument itself is never modified.

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.
    """
    # Work on a private copy: the append and the relative-path rewrite below
    # must not leak into the caller's list.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

    if (not all_model_checkpoint_paths or
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        all_model_checkpoint_paths = [
            p if os.path.isabs(p) else os.path.relpath(p, save_dir)
            for p in all_model_checkpoint_paths
        ]

    return CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
Example #25
Source File: saver.py From lambda-packs with MIT License | 4 votes |
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState proto, returns
    it. Relative paths stored in the proto are joined onto `checkpoint_dir`.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None otherwise (the file
      does not exist, reading raised `errors.OpError`, or parsing raised
      `text_format.ParseError`).

    Raises:
      ValueError: if the checkpoint read doesn't have model_checkpoint_path
        set.
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                       latest_filename)
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if file_io.file_exists(coord_checkpoint_filename):
            file_content = file_io.read_file_to_string(
                coord_checkpoint_filename)
            ckpt = CheckpointState()
            text_format.Merge(file_content, ckpt)
            if not ckpt.model_checkpoint_path:
                # Interpolate eagerly: unlike logging calls, ValueError does
                # not apply %-formatting to extra positional arguments.
                raise ValueError("Invalid checkpoint state loaded from %s" %
                                 checkpoint_dir)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(ckpt.model_checkpoint_path):
                ckpt.model_checkpoint_path = os.path.join(
                    checkpoint_dir, ckpt.model_checkpoint_path)
            # Iterate over a snapshot so entries can be reassigned by index.
            for i, p in enumerate(list(ckpt.all_model_checkpoint_paths)):
                if not os.path.isabs(p):
                    ckpt.all_model_checkpoint_paths[i] = os.path.join(
                        checkpoint_dir, p)
    except errors.OpError as e:
        # It's ok if the file cannot be read
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    except text_format.ParseError as e:
        logging.warning("%s: %s", type(e).__name__, e)
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    return ckpt
Example #26
Source File: saver.py From lambda-packs with MIT License | 4 votes |
def _update_checkpoint_state(save_dir,
                             model_checkpoint_path,
                             all_model_checkpoint_paths=None,
                             latest_filename=None,
                             save_relative_paths=False):
    """Updates the content of the 'checkpoint' file.

    This updates the checkpoint file containing a CheckpointState proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. None is treated as an empty list.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file.

    Raises:
      RuntimeError: If any of the model checkpoint paths conflict with the
        file containing CheckpointState.
    """
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir,
                                                       latest_filename)
    if save_relative_paths:
        if os.path.isabs(model_checkpoint_path):
            rel_model_checkpoint_path = os.path.relpath(
                model_checkpoint_path, save_dir)
        else:
            rel_model_checkpoint_path = model_checkpoint_path
        # The parameter defaults to None; iterating None would raise
        # TypeError, so fall back to an empty list here.
        rel_all_model_checkpoint_paths = [
            os.path.relpath(p, save_dir) if os.path.isabs(p) else p
            for p in (all_model_checkpoint_paths or [])
        ]
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            rel_model_checkpoint_path,
            all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
    else:
        ckpt = generate_checkpoint_state_proto(
            save_dir,
            model_checkpoint_path,
            all_model_checkpoint_paths=all_model_checkpoint_paths)

    if coord_checkpoint_filename == ckpt.model_checkpoint_path:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state. Please use a different save path."
                           % model_checkpoint_path)

    # Preventing potential read/write race condition by *atomically* writing
    # to a file.
    file_io.atomic_write_string_to_file(coord_checkpoint_filename,
                                        text_format.MessageToString(ckpt))
Example #27
Source File: saver.py From lambda-packs with MIT License | 4 votes |
def generate_checkpoint_state_proto(save_dir,
                                    model_checkpoint_path,
                                    all_model_checkpoint_paths=None):
    """Generates a checkpoint state proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: List of strings. Paths to all
        not-yet-deleted checkpoints, sorted from oldest to newest. If this is
        a non-empty list, the last element must be equal to
        model_checkpoint_path. These paths are also saved in the
        CheckpointState proto. The argument itself is never modified.

    Returns:
      CheckpointState proto with model_checkpoint_path and
      all_model_checkpoint_paths updated to either absolute paths or
      relative paths to the current save_dir.
    """
    # Work on a private copy: the append and the relative-path rewrite below
    # must not leak into the caller's list.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])

    if (not all_model_checkpoint_paths or
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.info(
            "%s is not in all_model_checkpoint_paths. Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)

    # Relative paths need to be rewritten to be relative to the "save_dir"
    # if model_checkpoint_path already contains "save_dir".
    if not os.path.isabs(save_dir):
        if not os.path.isabs(model_checkpoint_path):
            model_checkpoint_path = os.path.relpath(model_checkpoint_path,
                                                    save_dir)
        all_model_checkpoint_paths = [
            p if os.path.isabs(p) else os.path.relpath(p, save_dir)
            for p in all_model_checkpoint_paths
        ]

    return CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)