Python tensorflow.python.framework.graph_util.extract_sub_graph() Examples

The following are 14 code examples of tensorflow.python.framework.graph_util.extract_sub_graph(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.framework.graph_util, or try the search function.
Example #1
Source File: strip_unused_lib.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Each node whose name appears in `input_node_names` is swapped for a
  "Placeholder" node of the same name; every other node is deep-copied
  unchanged. The rewritten graph is then handed to extract_sub_graph().

  Args:
    input_graph_def: The graph whose nodes should be pruned.
    input_node_names: A list of node names to treat as graph inputs.
    output_node_names: A list of output node names passed to
        extract_sub_graph().
    placeholder_type_enum: Either a single AttrValue type enum used for
        every placeholder, or a list holding one enum per entry of
        `input_node_names` (matched by position).

  Returns:
    The GraphDef returned by extract_sub_graph().
  """
  # Whether the dtype is given per input or shared does not change while
  # walking the graph, so decide it once.
  per_input_dtypes = isinstance(placeholder_type_enum, list)

  # Turning the input nodes into placeholders drops their own inputs, so
  # anything that only fed them is stripped by extract_sub_graph().
  rewritten_graph_def = graph_pb2.GraphDef()
  for original in input_graph_def.node:
    if node_is_kept_as_is := original.name not in input_node_names:
      rewritten_graph_def.node.extend([copy.deepcopy(original)])
      continue

    stub = node_def_pb2.NodeDef()
    stub.op = "Placeholder"
    stub.name = original.name
    if per_input_dtypes:
      dtype_enum = placeholder_type_enum[
          input_node_names.index(original.name)]
    else:
      dtype_enum = placeholder_type_enum
    stub.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_enum))
    # Carry over the shape annotation when the original node had one.
    if "_output_shapes" in original.attr:
      stub.attr["_output_shapes"].CopyFrom(original.attr["_output_shapes"])
    rewritten_graph_def.node.extend([stub])

  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)
Example #2
Source File: strip_unused_lib.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Nodes named in `input_node_names` are replaced by "Placeholder" nodes of
  the same name; all other nodes are deep-copied. The result is then passed
  through extract_sub_graph().

  Args:
    input_graph_def: The graph whose nodes should be pruned.
    input_node_names: A list of node names to treat as graph inputs.
    output_node_names: A list of output node names passed to
        extract_sub_graph().
    placeholder_type_enum: The AttrValue type enum stored as the "dtype"
        attribute of every placeholder.

  Returns:
    The GraphDef returned by extract_sub_graph().
  """
  # Input nodes become placeholders (which have no inputs of their own), so
  # whatever only fed them is stripped out by extract_sub_graph().
  rewritten_graph_def = tf.GraphDef()
  for original in input_graph_def.node:
    if original.name not in input_node_names:
      rewritten_graph_def.node.extend([copy.deepcopy(original)])
      continue

    stub = tf.NodeDef()
    stub.name = original.name
    stub.op = "Placeholder"
    dtype_attr = tf.AttrValue(type=placeholder_type_enum)
    stub.attr["dtype"].CopyFrom(dtype_attr)
    # Carry over the shape annotation when the original node had one.
    if "_output_shapes" in original.attr:
      stub.attr["_output_shapes"].CopyFrom(original.attr["_output_shapes"])
    rewritten_graph_def.node.extend([stub])

  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)
Example #3
Source File: graph_util_test.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def testExtractSubGraph(self):
    """Builds a five-node graph and checks the sub-graph extracted for "n3".

    The assertions verify that the first four nodes of the result are
    n1, n2, n3 and n5, in that order.
    """
    # (node name, inputs) pairs, added to the graph in this order.
    node_specs = [
        ("n1", ["n5"]),
        # Takes the first output of the n1 node as its input.
        ("n2", ["n1:0"]),
        # A control input: not needed by the kernel itself, it only enforces
        # execution order between nodes.
        ("n3", ["^n2"]),
        # No inputs, and nothing refers to it.
        ("n4", None),
        # n5 <-> n1 form a loop; having loops in the graph is fine.
        ("n5", ["n1"]),
    ]

    graph_def = tf.GraphDef()
    for node_name, node_inputs in node_specs:
        node = graph_def.node.add()
        node.name = node_name
        if node_inputs:
            node.input.extend(node_inputs)

    sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
    for position, expected_name in enumerate(["n1", "n2", "n3", "n5"]):
        self.assertEqual(expected_name, sub_graph.node[position].name)
Example #4
Source File: quantize_graph.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #5
Source File: quantize_graph.py    From tensorflow-for-poets-2 with Apache License 2.0 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #6
Source File: quantize_graph.py    From MobileNet with Apache License 2.0 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #7
Source File: quantize_graph.py    From sketch-to-react-native with MIT License 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #8
Source File: graph_xform.py    From tensorlang with Apache License 2.0 5 votes vote down vote up
def strip_meta_graph(meta_graph_def, node_names, var_names):
  """Prunes `meta_graph_def` in place to the nodes needed for `node_names`.

  The initializers of variables listed in `var_names` are added to the set
  of nodes to keep, while-context entries whose pivot name is not a prefix
  of any kept name are deleted, and the graph is replaced by the result of
  extract_sub_graph().

  Args:
    meta_graph_def: The meta graph to modify. Its "while_context"
        collection and its graph_def are updated in place.
    node_names: Names of the nodes to keep. This sequence is copied, not
        mutated.
    var_names: Names of variables whose initializer nodes must be kept too.
  """
  kept_names = node_names[:]
  collections = meta_graph_def.collection_def

  # Keep the initializers of the requested variables as well.
  var_def = variable_pb2.VariableDef()
  for var_col_name in ("variables", "trainable_variables"):
    for serialized_var in collections[var_col_name].bytes_list.value:
      var_def.ParseFromString(serialized_var)
      # TODO(adamb) Should remove non-matching variables from the collection.
      if var_def.variable_name in var_names:
        kept_names.append(var_def.initializer_name)

  # Walk the while contexts back to front so deleting an entry does not
  # shift the indices still to be visited.
  wc_def = control_flow_pb2.WhileContextDef()
  wc_values = collections["while_context"].bytes_list.value
  for wc_ix in reversed(range(len(wc_values))):
    wc_def.ParseFromString(wc_values[wc_ix])
    pivot_name = wc_def.pivot_name
    still_used = any(name.startswith(pivot_name) for name in kept_names)
    if not still_used:
      del wc_values[wc_ix]

  full_graph_def = meta_graph_def.graph_def
  eprint("only keeping", kept_names, "from",
         [n.name for n in full_graph_def.node])
  pruned_graph_def = graph_util.extract_sub_graph(full_graph_def, kept_names)
  meta_graph_def.graph_def.CopyFrom(pruned_graph_def)
Example #9
Source File: quantize_graph.py    From pokemon-mini with Apache License 2.0 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #10
Source File: quantize_graph.py    From AudioNet with MIT License 5 votes vote down vote up
def remove_dead_nodes(self, output_names):
    """Replaces self.output_graph with its sub-graph for `output_names`.

    The current output graph and `output_names` are passed to
    extract_sub_graph(); the result is stored back on the instance.
    """
    self.output_graph = graph_util.extract_sub_graph(
        self.output_graph, output_names)
Example #11
Source File: strip_unused_lib.py    From keras-lambda with MIT License 5 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Nodes named in `input_node_names` are replaced by "Placeholder" nodes of
  the same name; all other nodes are deep-copied. The result is then passed
  through extract_sub_graph().

  Args:
    input_graph_def: The graph whose nodes should be pruned.
    input_node_names: A list of node names to treat as graph inputs.
    output_node_names: A list of output node names passed to
        extract_sub_graph().
    placeholder_type_enum: Either a single AttrValue type enum used for
        every placeholder, or a list holding one enum per entry of
        `input_node_names` (matched by position).

  Returns:
    The GraphDef returned by extract_sub_graph().
  """

  def _dtype_enum_for(node_name):
    # A list supplies one enum per input, looked up by the position of the
    # name in input_node_names; anything else is shared by all inputs.
    if isinstance(placeholder_type_enum, list):
      return placeholder_type_enum[input_node_names.index(node_name)]
    return placeholder_type_enum

  # Input nodes become placeholders (which have no inputs of their own), so
  # whatever only fed them is stripped out by extract_sub_graph().
  rewritten_graph_def = graph_pb2.GraphDef()
  for source_node in input_graph_def.node:
    if source_node.name in input_node_names:
      replacement = node_def_pb2.NodeDef()
      replacement.name = source_node.name
      replacement.op = "Placeholder"
      replacement.attr["dtype"].CopyFrom(
          attr_value_pb2.AttrValue(type=_dtype_enum_for(source_node.name)))
      # Carry over the shape annotation when the original node had one.
      if "_output_shapes" in source_node.attr:
        replacement.attr["_output_shapes"].CopyFrom(
            source_node.attr["_output_shapes"])
    else:
      replacement = copy.deepcopy(source_node)
    rewritten_graph_def.node.extend([replacement])

  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)
Example #12
Source File: strip_unused_lib.py    From lambda-packs with MIT License 4 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Nodes named in `input_node_names` are replaced by "Placeholder" nodes of
  the same name; all other nodes are deep-copied. The result is then passed
  through extract_sub_graph().

  Args:
    input_graph_def: The graph whose nodes should be pruned.
    input_node_names: A list of node names to treat as graph inputs.
    output_node_names: A list of output node names passed to
        extract_sub_graph().
    placeholder_type_enum: Either a single AttrValue type enum used for
        every placeholder, or a list holding one enum per entry of
        `input_node_names` (matched by position).

  Returns:
    The `GraphDef` returned by extract_sub_graph().

  Raises:
    ValueError: If any element of `input_node_names` contains ":", i.e. it
      looks like a tensor name rather than an operation name.
    KeyError: If any element of `input_node_names` is not found in the graph.
  """
  # Reject tensor-style names up front, before touching the graph.
  for candidate in input_node_names:
    if ":" in candidate:
      raise ValueError(
          "Name '%s' appears to refer to a Tensor, not a Operation."
          % candidate)

  # Names are ticked off as the matching nodes are encountered; whatever is
  # left afterwards was never seen in the graph.
  not_found = set(input_node_names)
  use_per_input_dtype = isinstance(placeholder_type_enum, list)

  # Input nodes become placeholders (which have no inputs of their own), so
  # whatever only fed them is stripped out by extract_sub_graph().
  rewritten_graph_def = graph_pb2.GraphDef()
  for original in input_graph_def.node:
    if original.name not in input_node_names:
      rewritten_graph_def.node.extend([copy.deepcopy(original)])
      continue

    not_found.remove(original.name)
    stub = node_def_pb2.NodeDef()
    stub.op = "Placeholder"
    stub.name = original.name
    if use_per_input_dtype:
      dtype_enum = placeholder_type_enum[
          input_node_names.index(original.name)]
    else:
      dtype_enum = placeholder_type_enum
    stub.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype_enum))
    # Carry over the shape annotation when the original node had one.
    if "_output_shapes" in original.attr:
      stub.attr["_output_shapes"].CopyFrom(original.attr["_output_shapes"])
    rewritten_graph_def.node.extend([stub])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" % not_found)

  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)
Example #13
Source File: strip_unused.py    From TensorFlow_DCIGN with MIT License 4 votes vote down vote up
def strip_unused(input_graph, input_binary, output_graph, input_node_names,
                 output_node_names, placeholder_type_enum):
  """Loads a graph file, prunes unused nodes and writes the result.

  Args:
    input_graph: Path of the GraphDef file to read.
    input_binary: If truthy the file is read in "rb" mode and parsed with
        ParseFromString(); otherwise it is read in "r" mode and merged with
        text_format.Merge().
    output_graph: Path the pruned graph is serialized to.
    input_node_names: Comma-separated names of the nodes to replace with
        placeholders.
    output_node_names: Comma-separated names of the output nodes passed to
        extract_sub_graph().
    placeholder_type_enum: The AttrValue type enum stored as the "dtype"
        attribute of every placeholder.

  Returns:
    -1 if the input file does not exist or `output_node_names` is empty
    (a message is printed in both cases). Nothing is returned explicitly on
    success.
  """
  if not tf.gfile.Exists(input_graph):
    print("Input graph file '" + input_graph + "' does not exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Binary and text serializations need a different open mode and parser.
  input_graph_def = tf.GraphDef()
  if input_binary:
    with tf.gfile.FastGFile(input_graph, "rb") as graph_file:
      input_graph_def.ParseFromString(graph_file.read())
  else:
    with tf.gfile.FastGFile(input_graph, "r") as graph_file:
      text_format.Merge(graph_file.read(), input_graph_def)

  # Input nodes become placeholders (which have no inputs of their own), so
  # whatever only fed them is stripped out by extract_sub_graph().
  placeholder_names = input_node_names.split(",")
  rewritten_graph_def = tf.GraphDef()
  for original in input_graph_def.node:
    if original.name not in placeholder_names:
      rewritten_graph_def.node.extend([copy.deepcopy(original)])
      continue
    stub = tf.NodeDef()
    stub.name = original.name
    stub.op = "Placeholder"
    stub.attr["dtype"].CopyFrom(tf.AttrValue(type=placeholder_type_enum))
    rewritten_graph_def.node.extend([stub])

  output_graph_def = graph_util.extract_sub_graph(
      rewritten_graph_def, output_node_names.split(","))

  with tf.gfile.GFile(output_graph, "wb") as out_file:
    out_file.write(output_graph_def.SerializeToString())
  print("%d ops in the final graph." % len(output_graph_def.node))
Example #14
Source File: strip_unused_lib.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 4 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Prunes a GraphDef down to what the requested outputs depend on.

  Nodes named in `input_node_names` are replaced by "Placeholder" nodes of
  the same name; all other nodes are deep-copied. The result is then passed
  through extract_sub_graph().

  Args:
    input_graph_def: The graph whose nodes should be pruned.
    input_node_names: A list of node names to treat as graph inputs.
    output_node_names: A list of output node names passed to
        extract_sub_graph().
    placeholder_type_enum: Either a single AttrValue type enum used for
        every placeholder, or a list holding one enum per entry of
        `input_node_names` (matched by position).

  Returns:
    The `GraphDef` returned by extract_sub_graph().

  Raises:
    ValueError: If any element of `input_node_names` contains ":", i.e. it
      looks like a tensor name rather than an operation name.
    KeyError: If any element of `input_node_names` is not found in the graph.
  """

  def _dtype_enum_for(node_name):
    # A list supplies one enum per input, looked up by the position of the
    # name in input_node_names; anything else is shared by all inputs.
    if isinstance(placeholder_type_enum, list):
      return placeholder_type_enum[input_node_names.index(node_name)]
    return placeholder_type_enum

  # Reject tensor-style names up front, before touching the graph.
  for candidate in input_node_names:
    if ":" in candidate:
      raise ValueError(
          "Name '%s' appears to refer to a Tensor, not a Operation."
          % candidate)

  # Names are ticked off as the matching nodes are encountered; whatever is
  # left afterwards was never seen in the graph.
  not_found = set(input_node_names)

  # Input nodes become placeholders (which have no inputs of their own), so
  # whatever only fed them is stripped out by extract_sub_graph().
  rewritten_graph_def = graph_pb2.GraphDef()
  for source_node in input_graph_def.node:
    if source_node.name in input_node_names:
      not_found.remove(source_node.name)
      replacement = node_def_pb2.NodeDef()
      replacement.name = source_node.name
      replacement.op = "Placeholder"
      replacement.attr["dtype"].CopyFrom(
          attr_value_pb2.AttrValue(type=_dtype_enum_for(source_node.name)))
      # Carry over the shape annotation when the original node had one.
      if "_output_shapes" in source_node.attr:
        replacement.attr["_output_shapes"].CopyFrom(
            source_node.attr["_output_shapes"])
    else:
      replacement = copy.deepcopy(source_node)
    rewritten_graph_def.node.extend([replacement])

  if not_found:
    raise KeyError("The following input nodes were not found: %s\n" % not_found)

  return graph_util.extract_sub_graph(rewritten_graph_def, output_node_names)