Python tensorflow.python.training.saver.import_meta_graph() Examples

The following is one code example of tensorflow.python.training.saver.import_meta_graph(), taken from an open-source project. You may also want to check out the other functions and classes in the module tensorflow.python.training.saver.
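import_meta_graph() rebuilds a graph from a serialized MetaGraphDef (or a `.meta` file) and returns a Saver that can restore the matching checkpoint. Below is a minimal sketch of that save-then-reimport round trip. It assumes TensorFlow 2.x and uses the public alias `tf.compat.v1.train.import_meta_graph` rather than the private `tensorflow.python.training.saver` path. The helper `save_and_reimport` and the tensor names `v` and `v_value` are hypothetical and chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

# TF1-style graphs and sessions are required for Saver / import_meta_graph in TF 2.x.
tf1 = tf.compat.v1


def save_and_reimport(value=3.0):
  """Hypothetical helper: save a one-variable graph, then rebuild it from its .meta file."""
  ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model")

  # 1. Build and save. Saver.save() also writes <prefix>.meta (the MetaGraphDef).
  with tf1.Graph().as_default():
    v = tf1.Variable(value, name="v")
    tf1.identity(v, name="v_value")  # named read of the variable, fetched later
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
      sess.run(tf1.global_variables_initializer())
      saver.save(sess, ckpt_prefix)

  # 2. Re-import into a fresh graph. import_meta_graph() recreates the ops and
  #    returns a Saver wired to them, which restores the checkpointed values.
  with tf1.Graph().as_default() as g:
    restored_saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf1.Session(graph=g) as sess:
      restored_saver.restore(sess, ckpt_prefix)
      return sess.run("v_value:0")
```

Calling `save_and_reimport(3.0)` should give back 3.0 even though the second graph never defines `v` itself. The loader below follows the same pattern, except that it takes the MetaGraphDef from a SavedModel instead of a `.meta` file.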
Example #1
Source File: loader.py    From deep_image_model with Apache License 2.0
```python
import os

from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat

# _parse_saved_model(), _get_asset_tensors() and _get_legacy_init_op_tensor()
# are private helpers defined elsewhere in the same loader.py module.


def load(sess, tags, export_dir):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session in which to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  # Parse the SavedModel protocol buffer and find the requested meta graph def.
  # Tags are compared as sets, so order and duplicates do not matter.
  saved_model = _parse_saved_model(export_dir)
  found_match = False
  for meta_graph_def in saved_model.meta_graphs:
    if set(meta_graph_def.meta_info_def.tags) == set(tags):
      meta_graph_def_to_load = meta_graph_def
      found_match = True
      break

  if not found_match:
    # Join the tags explicitly; str(tags).strip("[]") leaves braces on a set.
    raise RuntimeError("MetaGraphDef associated with tags " +
                       ", ".join(sorted(set(tags))) +
                       " could not be found in SavedModel")

  # Build a saver by importing the meta graph def to load.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load)

  # Build the checkpoint path where the variables are located.
  variables_path = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.VARIABLES_DIRECTORY),
      compat.as_bytes(constants.VARIABLES_FILENAME))

  # Restore the variables using the built saver in the provided session.
  saver.restore(sess, variables_path)

  # Get asset tensors, if any.
  asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                meta_graph_def_to_load)

  # TODO(sukritiramesh): Add support for a single main op to run upon load,
  # which will supersede the legacy_init_op.
  legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)

  if legacy_init_op_tensor is not None:
    sess.run(fetches=[legacy_init_op_tensor],
             feed_dict=asset_tensors_dictionary)

  return meta_graph_def_to_load
```
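The first step of load() is a plain lookup: pick the MetaGraphDef whose tag set equals the requested tags exactly. The sketch below isolates that rule without TensorFlow, assuming each meta graph is represented simply by its list of tag strings. The function name `find_meta_graph_index` is hypothetical and is not part of the TensorFlow API.

```python
from typing import Iterable, Sequence


def find_meta_graph_index(meta_graph_tags: Sequence[Iterable[str]],
                          tags: Iterable[str]) -> int:
  """Hypothetical helper: index of the first tag list whose set equals `tags`.

  Mirrors the matching loop in load(): comparison is by set equality, so
  order and duplicates are ignored, but subsets and supersets do not match.
  """
  wanted = set(tags)
  for index, candidate in enumerate(meta_graph_tags):
    if set(candidate) == wanted:
      return index
  # Same failure mode as load(), with the tags joined into a readable list.
  raise RuntimeError("MetaGraphDef associated with tags " +
                     ", ".join(sorted(wanted)) +
                     " could not be found in SavedModel")
```

For example, `find_meta_graph_index([["serve"], ["train", "gpu"]], ["gpu", "train"])` returns 1, while requesting `["serve", "gpu"]` raises RuntimeError because `{"serve"}` is only a subset of the requested tags.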