Python tensorflow.core.framework.node_def_pb2.NodeDef() Examples

The following are 30 code examples of tensorflow.core.framework.node_def_pb2.NodeDef(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.core.framework.node_def_pb2, or try the search function.
Example #1
Source File: pb_wrapper.py    From onnx-tensorflow with Apache License 2.0 6 votes vote down vote up
def __init__(self,
               node=None,
               name=None,
               inputs=None,
               outputs=None,
               attr=None,
               domain=None,
               op_type=None):
    """Build the node wrapper from explicit fields or from a protobuf node.

    Args:
      node: Optional source node. When None, the wrapper is populated from
        the remaining keyword arguments. An ``OnnxNode``/``NodeProto`` is
        loaded via ``self._load_onnx_node``; a ``NodeDef`` via
        ``self._load_tf_node``.
      name: Node name used when ``node`` is None (defaults to "").
      inputs: Input names used when ``node`` is None (defaults to []).
      outputs: Output names used when ``node`` is None; when falsy they are
        computed with ``self.get_outputs_names()``.
      attr: Attribute mapping used when ``node`` is None (defaults to {}).
      domain: Operator domain used when ``node`` is None (defaults to "").
      op_type: Operator type used when ``node`` is None (defaults to "").

    Raises:
      TypeError: If ``node`` is given but is not an OnnxNode, NodeProto or
        NodeDef.
    """
    if node is None:
      # No source node: populate the fields directly from the arguments.
      self.node = None
      self.name = name or ""
      self.inputs = inputs or []
      self.attr = attr or {}
      self.domain = domain or ""
      self.op_type = op_type or ""
      # Assigned last; get_outputs_names() may read the fields set above.
      self.outputs = outputs or self.get_outputs_names()
    elif isinstance(node, (OnnxNode, NodeProto)):
      self._load_onnx_node(node)
    elif isinstance(node, NodeDef):
      self._load_tf_node(node)
    else:
      # Falling through silently would leave the wrapper without any of its
      # attributes and fail much later with an AttributeError.
      raise TypeError("Unsupported node type: %s" % type(node).__name__)
Example #2
Source File: optimize_for_inference_lib.py    From lambda-packs with MIT License 6 votes vote down vote up
def values_from_const(node_def):
  """Return the tensor held by a Const node as a numpy ndarray.

  Args:
    node_def: NodeDef whose op must be "Const"; the TensorProto in its
      "value" attribute is converted.

  Returns:
    The ndarray produced by ``tensor_util.MakeNdarray``.

  Raises:
    ValueError: When ``node_def.op`` is anything other than "Const".
  """
  is_const = node_def.op == "Const"
  if not is_const:
    message = ("Node named '%s' should be a Const op for values_from_const." %
               node_def.name)
    raise ValueError(message)
  return tensor_util.MakeNdarray(node_def.attr["value"].tensor)
Example #3
Source File: optimize_for_inference_lib.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def node_from_map(node_map, name):
  """Fetch a NodeDef from ``node_map`` by name.

  The name is normalized with ``node_name_from_input`` before the lookup.

  Args:
    node_map: Mapping from node name to NodeDef.
    name: Name (or input string) identifying the wanted node.

  Returns:
    The NodeDef stored under the normalized name.

  Raises:
    ValueError: If the normalized name is not a key of ``node_map``. The
      message reports ``name`` exactly as it was passed in.
  """
  key = node_name_from_input(name)
  if key in node_map:
    return node_map[key]
  raise ValueError("No node named '%s' found in map." % name)
Example #4
Source File: ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def _as_node_def_input(self):
    """Build the string a NodeDef "input" entry uses to refer to this Tensor.

    Output 0 is referenced by the bare operation name; any other output is
    referenced as "<op_name>:<value_index>".

    Raises:
      ValueError: If the producing Operation's name is empty.

    Returns:
      The input reference string.
    """
    op_name = self._op.name
    if not op_name:
      raise ValueError("Operation was not named: %s" % self._op)
    index = self._value_index
    return op_name if index == 0 else "%s:%d" % (op_name, index)
Example #5
Source File: optimize_for_inference_lib.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def values_from_const(node_def):
  """Convert the "value" tensor of a Const NodeDef into a numpy ndarray.

  Args:
    node_def: NodeDef expected to have op "Const".

  Returns:
    Numpy ndarray built by ``tensor_util.MakeNdarray`` from the node's
    "value" attribute.

  Raises:
    ValueError: If the node's op is not "Const".
  """
  if node_def.op == "Const":
    const_tensor = node_def.attr["value"].tensor
    return tensor_util.MakeNdarray(const_tensor)
  raise ValueError(
      "Node named '%s' should be a Const op for values_from_const." %
      node_def.name)
Example #6
Source File: optimize_for_inference_lib.py    From lambda-packs with MIT License 6 votes vote down vote up
def node_from_map(node_map, name):
  """Return the NodeDef that ``name`` refers to.

  ``name`` is first passed through ``node_name_from_input``; the result is
  the key used in ``node_map``.

  Args:
    node_map: Dictionary keyed by node name.
    name: Name or input string of the node to look up.

  Returns:
    The matching NodeDef.

  Raises:
    ValueError: If no entry exists for the normalized name.
  """
  lookup_key = node_name_from_input(name)
  found = lookup_key in node_map
  if not found:
    raise ValueError("No node named '%s' found in map." % name)
  return node_map[lookup_key]
Example #7
Source File: ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _as_node_def_input(self):
    """Return the NodeDef "input" string that designates this Tensor.

    The first output (index 0) is named by the operation name alone; other
    outputs append ":<index>".

    Raises:
      ValueError: If the producing Operation has no name.

    Returns:
      A string.
    """
    producer = self._op
    if not producer.name:
      raise ValueError("Operation was not named: %s" % producer)
    if self._value_index != 0:
      return "%s:%d" % (producer.name, self._value_index)
    return producer.name
Example #8
Source File: lib.py    From parallax with Apache License 2.0 6 votes vote down vote up
def _len_node_outputs(self, node_def):
        """Count the output tensors of ``node_def`` according to its OpDef.

        Each output arg of ``self._op_defs[node_def.op]`` contributes:
          * ``attr[number_attr].i`` tensors when ``number_attr`` is set,
          * ``len(attr[type_list_attr].list.type)`` tensors when
            ``type_list_attr`` is set,
          * exactly one tensor otherwise.

        Args:
            node_def: The ``node_def_pb2.NodeDef`` to inspect.

        Returns:
            The total number of outputs.
        """
        assert isinstance(node_def, node_def_pb2.NodeDef), 'node_def type %s' % type(node_def)
        attrs = node_def.attr

        def _outputs_for(arg):
            if arg.number_attr:
                # Homogeneous sequence: length comes from an int attr.
                return attrs[arg.number_attr].i
            if arg.type_list_attr:
                # Heterogeneous sequence: one tensor per listed type.
                return len(attrs[arg.type_list_attr].list.type)
            return 1

        return sum(_outputs_for(arg) for arg in self._op_defs[node_def.op].output_arg)
Example #9
Source File: lib.py    From parallax with Apache License 2.0 6 votes vote down vote up
def extend_mapping_from_nodedef(self, single_nodedef, replica_nodedef):
        """Record ``replica_nodedef`` as a replica of ``single_nodedef``.

        The single node's op name, and each of its output tensor names
        ("<name>:<i>"), are mapped to the corresponding replica names. Every
        key in ``self._mapping`` holds a list that accumulates replicas.
        """
        assert isinstance(single_nodedef, node_def_pb2.NodeDef), \
                'single nodedef type is %s' % type(single_nodedef)
        assert isinstance(replica_nodedef, node_def_pb2.NodeDef), \
                'replica nodedef type is %s' % type(replica_nodedef)

        def _append_mapping(key, replica_name):
            replicas = self._mapping.setdefault(key, [])
            assert isinstance(replicas, list)
            replicas.append(replica_name)

        src_name = single_nodedef.name
        dst_name = replica_nodedef.name
        _append_mapping(src_name, dst_name)

        num_outputs = self._len_node_outputs(single_nodedef)
        for idx in range(num_outputs):
            _append_mapping('%s:%d' % (src_name, idx), '%s:%d' % (dst_name, idx))
Example #10
Source File: tf_replicate_model_fn.py    From tf.fashionAI with Apache License 2.0 6 votes vote down vote up
def _local_device_setter(worker_device, ps_devices, ps_strategy):
  """Return a device function that sends variable ops to parameter servers.

  Ops of type 'Variable', 'VariableV2' or 'VarHandleOp' are assigned to
  ``ps_devices[ps_strategy(op)]``; every other op goes to ``worker_device``.
  In both cases the op's own device string is merged into the chosen spec
  with ``DeviceSpec.merge_from``.

  Args:
    worker_device: Device string for non-variable ops (may be None/empty).
    ps_devices: Parameter-server device strings, indexed by ``ps_strategy``.
    ps_strategy: Callable mapping an op to an index into ``ps_devices``.

  Returns:
    A function taking an Operation or NodeDef and returning a device string.
  """
  variable_op_types = ['Variable', 'VariableV2', 'VarHandleOp']

  def local_device_chooser(op):
    requested = framework_device.DeviceSpec.from_string(op.device or '')
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def

    if node_def.op not in variable_op_types:
      base = framework_device.DeviceSpec.from_string(worker_device or '')
    else:
      base = framework_device.DeviceSpec.from_string(
          '{}'.format(ps_devices[ps_strategy(op)]))
    base.merge_from(requested)
    return base.to_string()

  return local_device_chooser
Example #11
Source File: ops.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def get_stats_for_node_def(graph, node, statistic_type):
  """Compute a statistic for ``node`` using the registered stats function.

  The registry key is "<node.op>,<statistic_type>". If no function is
  registered for that key, an empty ``OpStats(statistic_type)`` is returned.

  Args:
    graph: A Graph object that's been set up with the node's graph; passed
      through to the stats function.
    node: A NodeDef describing the operator.
    statistic_type: A string identifying the statistic we're interested in.

  Returns:
    An OpStats object containing information about resource usage.
  """
  registry_key = node.op + "," + statistic_type
  try:
    return _stats_registry.lookup(registry_key)(graph, node)
  except LookupError:
    # NOTE(review): this also catches LookupError/KeyError raised *inside*
    # the registered stats function, not only a missing registration.
    return OpStats(statistic_type)
Example #12
Source File: tensorflow_graph.py    From MMdnn with MIT License 6 votes vote down vote up
def build(self):
        """Populate the layer maps from ``self.model.node`` and connect inputs.

        Each node is wrapped in a TensorflowGraphNode keyed by its name. An
        input that is not registered yet, neither verbatim nor by the part
        before its first ':', gets a stub "NoOp" NodeDef named after the full
        input string. Every input is then linked to its consumer with
        ``tf_make_connection``, and finally the base class ``build`` runs.
        """
        for layer in self.model.node:
            self.layer_map[layer.name] = TensorflowGraphNode(layer)
            self.layer_name_map[layer.name] = layer.name
            for pred in layer.input:
                needs_stub = (pred not in self.layer_map
                              and pred.split(':')[0] not in self.layer_map)
                if needs_stub:
                    stub = NodeDef()
                    stub.name = pred
                    stub.op = "NoOp"
                    self.layer_map[pred] = TensorflowGraphNode(stub)
                    self.layer_name_map[pred] = pred

                self.tf_make_connection(pred, layer.name)

        super(TensorflowGraph, self).build()
Example #13
Source File: quantize_graph.py    From tensorflow-for-poets-2 with Apache License 2.0 6 votes vote down vote up
def quantize_nodes_recursively(self, current_node):
    """The entry point for quantizing nodes to eight bit and back.

    Visits the inputs of ``current_node`` depth-first, processing each node at
    most once (tracked in ``self.already_visited``). If the node's op is
    exactly Conv2D, BiasAdd or MatMul, each direct input and then the node
    itself are passed to ``self.quantize_node``. Otherwise a copy of the node
    is added to the output graph unchanged.

    Args:
      current_node: NodeDef to process.
    """
    if self.already_visited[current_node.name]:
      return
    self.already_visited[current_node.name] = True
    for input_node_name in current_node.input:
      input_node_name = node_name_from_input(input_node_name)
      input_node = self.nodes_map[input_node_name]
      self.quantize_nodes_recursively(input_node)
    # Exact match on purpose: a substring test would also select ops such as
    # "Add", "Mul" or "Conv" (substrings of "BiasAdd", "MatMul", "Conv2D").
    nodes_to_quantize = ("Conv2D", "BiasAdd", "MatMul")
    if current_node.op in nodes_to_quantize:
      for input_name in current_node.input:
        input_name = node_name_from_input(input_name)
        input_node = self.nodes_map[input_name]
        self.quantize_node(input_node)
      self.quantize_node(current_node)
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)
Example #14
Source File: quantize_graph.py    From MobileNet with Apache License 2.0 6 votes vote down vote up
def quantize_nodes_recursively(self, current_node):
    """The entry point for quantizing nodes to eight bit and back.

    Visits the inputs of ``current_node`` depth-first, processing each node at
    most once (tracked in ``self.already_visited``). If the node's op is
    exactly Conv2D, BiasAdd or MatMul, each direct input and then the node
    itself are passed to ``self.quantize_node``. Otherwise a copy of the node
    is added to the output graph unchanged.

    Args:
      current_node: NodeDef to process.
    """
    if self.already_visited[current_node.name]:
      return
    self.already_visited[current_node.name] = True
    for input_node_name in current_node.input:
      input_node_name = node_name_from_input(input_node_name)
      input_node = self.nodes_map[input_node_name]
      self.quantize_nodes_recursively(input_node)
    # Exact match on purpose: a substring test would also select ops such as
    # "Add", "Mul" or "Conv" (substrings of "BiasAdd", "MatMul", "Conv2D").
    nodes_to_quantize = ("Conv2D", "BiasAdd", "MatMul")
    if current_node.op in nodes_to_quantize:
      for input_name in current_node.input:
        input_name = node_name_from_input(input_name)
        input_node = self.nodes_map[input_name]
        self.quantize_node(input_node)
      self.quantize_node(current_node)
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)
Example #15
Source File: graph_rewrite_util.py    From tfjs-to-tf with MIT License 6 votes vote down vote up
def is_fused_op(node: 'NodeDef', op_name: Text, activation: Text) -> bool:
    """
    Return whether a node represents a fused TF operation.

    The node must have op ``_Fused<op_name>`` and a ``fused_ops`` attribute
    listing exactly two ops: ``BiasAdd``/``BiasAddV1`` followed by
    ``activation``.

    Args:
        node: Node definition
        op_name: Fused operation name (e.g. 'MatMul')
        activation: Name of the fused activation function (e.g. 'Relu');
            accepted as either ``str`` or ``bytes``

    Returns:
        `True`, iff the node is a fused operation with the given activation
    """
    if node.op != f'_Fused{op_name}' or 'fused_ops' not in node.attr:
        return False
    # The attribute's entries are byte strings (compared against b'BiasAdd'
    # below); a str activation would never compare equal, so encode it.
    if isinstance(activation, str):
        activation = activation.encode('utf-8')
    fused_ops = node.attr['fused_ops'].list.s
    return (len(fused_ops) == 2
            and fused_ops[0] in (b'BiasAdd', b'BiasAddV1')
            and fused_ops[1] == activation)
Example #16
Source File: quantize_graph.py    From sketch-to-react-native with MIT License 6 votes vote down vote up
def quantize_nodes_recursively(self, current_node):
    """The entry point for quantizing nodes to eight bit and back.

    Visits the inputs of ``current_node`` depth-first, processing each node at
    most once (tracked in ``self.already_visited``). If the node's op is
    exactly Conv2D, BiasAdd or MatMul, each direct input and then the node
    itself are passed to ``self.quantize_node``. Otherwise a copy of the node
    is added to the output graph unchanged.

    Args:
      current_node: NodeDef to process.
    """
    if self.already_visited[current_node.name]:
      return
    self.already_visited[current_node.name] = True
    for input_node_name in current_node.input:
      input_node_name = node_name_from_input(input_node_name)
      input_node = self.nodes_map[input_node_name]
      self.quantize_nodes_recursively(input_node)
    # Exact match on purpose: a substring test would also select ops such as
    # "Add", "Mul" or "Conv" (substrings of "BiasAdd", "MatMul", "Conv2D").
    nodes_to_quantize = ("Conv2D", "BiasAdd", "MatMul")
    if current_node.op in nodes_to_quantize:
      for input_name in current_node.input:
        input_name = node_name_from_input(input_name)
        input_node = self.nodes_map[input_name]
        self.quantize_node(input_node)
      self.quantize_node(current_node)
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)
Example #17
Source File: quantize_graph.py    From pokemon-mini with Apache License 2.0 6 votes vote down vote up
def quantize_nodes_recursively(self, current_node):
    """The entry point for quantizing nodes to eight bit and back.

    Visits the inputs of ``current_node`` depth-first, processing each node at
    most once (tracked in ``self.already_visited``). If the node's op is
    exactly Conv2D, BiasAdd or MatMul, each direct input and then the node
    itself are passed to ``self.quantize_node``. Otherwise a copy of the node
    is added to the output graph unchanged.

    Args:
      current_node: NodeDef to process.
    """
    if self.already_visited[current_node.name]:
      return
    self.already_visited[current_node.name] = True
    for input_node_name in current_node.input:
      input_node_name = node_name_from_input(input_node_name)
      input_node = self.nodes_map[input_node_name]
      self.quantize_nodes_recursively(input_node)
    # Exact match on purpose: a substring test would also select ops such as
    # "Add", "Mul" or "Conv" (substrings of "BiasAdd", "MatMul", "Conv2D").
    nodes_to_quantize = ("Conv2D", "BiasAdd", "MatMul")
    if current_node.op in nodes_to_quantize:
      for input_name in current_node.input:
        input_name = node_name_from_input(input_name)
        input_node = self.nodes_map[input_name]
        self.quantize_node(input_node)
      self.quantize_node(current_node)
    else:
      new_node = node_def_pb2.NodeDef()
      new_node.CopyFrom(current_node)
      self.add_output_graph_node(new_node)
Example #18
Source File: imagenet_utils.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
  """Return a device function placing selected ops on parameter servers.

  Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>`` with
  ``k = ps_strategy(op)``. Every other op is placed on ``worker_device``. In
  both cases the op's own device string (if any) is merged into the chosen
  spec via ``DeviceSpec.merge_from``.

  Args:
    num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
      ``ps_strategy`` is None.
    ps_device_type: Device type used in the PS device string (e.g. 'cpu').
    worker_device: Device string for ops not in ``ps_ops``.
    ps_ops: Op types to place on PS devices. Defaults to
      ['Variable', 'VariableV2', 'VarHandleOp'].
    ps_strategy: Callable mapping an op to a PS device index. Defaults to
      ``device_setter._RoundRobinStrategy(num_devices)``.

  Returns:
    A function that takes an Operation or NodeDef and returns a device string.

  Raises:
    TypeError: If ``ps_strategy`` is not callable.
  """
  if ps_ops is None:
    ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

  if ps_strategy is None:
    ps_strategy = device_setter._RoundRobinStrategy(num_devices)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")

  def _local_device_chooser(op):
    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # Accept either a NodeDef or an object exposing `.node_def`.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if node_def.op in ps_ops:
      ps_device_spec = pydev.DeviceSpec.from_string(
          '/{}:{}'.format(ps_device_type, ps_strategy(op)))

      ps_device_spec.merge_from(current_device)
      return ps_device_spec.to_string()
    else:
      worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
      worker_device_spec.merge_from(current_device)
      return worker_device_spec.to_string()
  return _local_device_chooser
Example #19
Source File: distgpu_train.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
    """Return a device function placing selected ops on parameter servers.

    Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>``
    with ``k = ps_strategy(op)``. Every other op is placed on
    ``worker_device``. In both cases the op's own device string (if any) is
    merged into the chosen spec via ``DeviceSpec.merge_from``.

    Args:
        num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
            ``ps_strategy`` is None.
        ps_device_type: Device type used in the PS device string.
        worker_device: Device string for ops not in ``ps_ops``.
        ps_ops: Op types to place on PS devices. Defaults to
            ['Variable', 'VariableV2', 'VarHandleOp'].
        ps_strategy: Callable mapping an op to a PS device index. Defaults to
            ``device_setter._RoundRobinStrategy(num_devices)``.

    Returns:
        A function that takes an Operation or NodeDef and returns a device
        string.

    Raises:
        TypeError: If ``ps_strategy`` is not callable.
    """
    if ps_ops is None:
        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    if not six.callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        # Accept either a NodeDef or an object exposing `.node_def`.
        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                  '/{}:{}'.format(ps_device_type, ps_strategy(op)))

            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()
        else:
            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
            worker_device_spec.merge_from(current_device)
            return worker_device_spec.to_string()
    return _local_device_chooser
Example #20
Source File: quantize_graph.py    From MobileNet with Apache License 2.0 5 votes vote down vote up
def eightbitize_placeholder_node(self, current_node):
    """Replaces a placeholder node with a quint8 placeholder node+dequantize.

    Emits a quint8 copy of the placeholder named "<name>_original_input", and
    a "Dequantize" node named "<name>" that reads it. The dequantize node also
    reads "quantized_input_min_value" and "quantized_input_max_value". Both
    names are then registered in ``self.final_node_renames`` so that, after
    processing, the placeholder regains "<name>" and the dequantize node
    becomes "<name>_dequantize".
    """
    name = current_node.name
    original_input_name = name + "_original_input"

    # Quantized copy of the placeholder, temporarily renamed.
    quantized_placeholder = node_def_pb2.NodeDef()
    quantized_placeholder.CopyFrom(current_node)
    set_attr_dtype(quantized_placeholder, "dtype", dtypes.quint8)
    quantized_placeholder.name = original_input_name
    self.add_output_graph_node(quantized_placeholder)

    # Dequantize back to float under the original name so graph traversal
    # keeps finding this node as `name`.
    dequantize_inputs = [
        original_input_name, "quantized_input_min_value",
        "quantized_input_max_value"
    ]
    dequantize_node = create_node("Dequantize", name, dequantize_inputs)
    set_attr_dtype(dequantize_node, "T", dtypes.quint8)
    set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
    self.add_output_graph_node(dequantize_node)

    # Feeding needs the placeholder to be called `name`, so swap the names
    # once all processing is finished.
    self.final_node_renames[quantized_placeholder.name] = name
    self.final_node_renames[dequantize_node.name] = name + "_dequantize"
Example #21
Source File: strip_unused_lib.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.

  Every node named in ``input_node_names`` is first replaced by a Placeholder
  (carrying over its "_output_shapes" attr when present). Nodes that only feed
  those inputs are therefore dropped by ``graph_util.extract_sub_graph``,
  which keeps what ``output_node_names`` depend on. All other nodes are
  deep-copied unchanged.

  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.

  Returns:
    A GraphDef with all unnecessary ops removed.
  """
  def _placeholder_dtype(node_name):
    # A list holds one dtype per entry of input_node_names.
    if isinstance(placeholder_type_enum, list):
      return placeholder_type_enum[input_node_names.index(node_name)]
    return placeholder_type_enum

  def _as_placeholder(node):
    placeholder_node = node_def_pb2.NodeDef()
    placeholder_node.op = "Placeholder"
    placeholder_node.name = node.name
    placeholder_node.attr["dtype"].CopyFrom(
        attr_value_pb2.AttrValue(type=_placeholder_dtype(node.name)))
    if "_output_shapes" in node.attr:
      placeholder_node.attr["_output_shapes"].CopyFrom(
          node.attr["_output_shapes"])
    return placeholder_node

  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      replacement = _as_placeholder(node)
    else:
      replacement = copy.deepcopy(node)
    inputs_replaced_graph_def.node.extend([replacement])

  return graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                      output_node_names)
Example #22
Source File: quantize_graph.py    From MobileNet with Apache License 2.0 5 votes vote down vote up
def create_node(op, name, inputs):
  """Return a new NodeDef with the given op type, name and inputs (in order)."""
  node = node_def_pb2.NodeDef()
  node.op, node.name = op, name
  for input_ref in inputs:
    node.input.append(input_ref)
  return node
Example #23
Source File: tf_utils.py    From video_prediction with MIT License 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
    """Return a device function placing selected ops on parameter servers.

    Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>``
    with ``k = ps_strategy(op)``. Every other op is placed on
    ``worker_device``. In both cases the op's own device string (if any) is
    merged into the chosen spec via ``DeviceSpec.merge_from``.

    Args:
        num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
            ``ps_strategy`` is None.
        ps_device_type: Device type used in the PS device string.
        worker_device: Device string for ops not in ``ps_ops``.
        ps_ops: Op types to place on PS devices. Defaults to
            ['Variable', 'VariableV2', 'VarHandleOp'].
        ps_strategy: Callable mapping an op to a PS device index. Defaults to
            ``device_setter._RoundRobinStrategy(num_devices)``.

    Returns:
        A function that takes an Operation or NodeDef and returns a device
        string.

    Raises:
        TypeError: If ``ps_strategy`` is not callable.
    """
    if ps_ops is None:
        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    if not six.callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        # Accept either a NodeDef or an object exposing `.node_def`.
        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                '/{}:{}'.format(ps_device_type, ps_strategy(op)))

            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()
        else:
            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
            worker_device_spec.merge_from(current_device)
            return worker_device_spec.to_string()

    return _local_device_chooser
Example #24
Source File: cifar10_utils.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
  """Return a device function placing selected ops on parameter servers.

  Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>`` with
  ``k = ps_strategy(op)``. Every other op is placed on ``worker_device``. In
  both cases the op's own device string (if any) is merged into the chosen
  spec via ``DeviceSpec.merge_from``.

  Args:
    num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
      ``ps_strategy`` is None.
    ps_device_type: Device type used in the PS device string (e.g. 'cpu').
    worker_device: Device string for ops not in ``ps_ops``.
    ps_ops: Op types to place on PS devices. Defaults to
      ['Variable', 'VariableV2', 'VarHandleOp'].
    ps_strategy: Callable mapping an op to a PS device index. Defaults to
      ``device_setter._RoundRobinStrategy(num_devices)``.

  Returns:
    A function that takes an Operation or NodeDef and returns a device string.

  Raises:
    TypeError: If ``ps_strategy`` is not callable.
  """
  if ps_ops is None:
    ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

  if ps_strategy is None:
    ps_strategy = device_setter._RoundRobinStrategy(num_devices)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")

  def _local_device_chooser(op):
    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # Accept either a NodeDef or an object exposing `.node_def`.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if node_def.op in ps_ops:
      ps_device_spec = pydev.DeviceSpec.from_string(
          '/{}:{}'.format(ps_device_type, ps_strategy(op)))

      ps_device_spec.merge_from(current_device)
      return ps_device_spec.to_string()
    else:
      worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
      worker_device_spec.merge_from(current_device)
      return worker_device_spec.to_string()
  return _local_device_chooser
Example #25
Source File: imagenet_utils.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
  """Return a device function placing selected ops on parameter servers.

  Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>`` with
  ``k = ps_strategy(op)``. Every other op is placed on ``worker_device``. In
  both cases the op's own device string (if any) is merged into the chosen
  spec via ``DeviceSpec.merge_from``.

  Args:
    num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
      ``ps_strategy`` is None.
    ps_device_type: Device type used in the PS device string (e.g. 'cpu').
    worker_device: Device string for ops not in ``ps_ops``.
    ps_ops: Op types to place on PS devices. Defaults to
      ['Variable', 'VariableV2', 'VarHandleOp'].
    ps_strategy: Callable mapping an op to a PS device index. Defaults to
      ``device_setter._RoundRobinStrategy(num_devices)``.

  Returns:
    A function that takes an Operation or NodeDef and returns a device string.

  Raises:
    TypeError: If ``ps_strategy`` is not callable.
  """
  if ps_ops is None:
    ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

  if ps_strategy is None:
    ps_strategy = device_setter._RoundRobinStrategy(num_devices)
  if not six.callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")

  def _local_device_chooser(op):
    current_device = pydev.DeviceSpec.from_string(op.device or "")

    # Accept either a NodeDef or an object exposing `.node_def`.
    node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
    if node_def.op in ps_ops:
      ps_device_spec = pydev.DeviceSpec.from_string(
          '/{}:{}'.format(ps_device_type, ps_strategy(op)))

      ps_device_spec.merge_from(current_device)
      return ps_device_spec.to_string()
    else:
      worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
      worker_device_spec.merge_from(current_device)
      return worker_device_spec.to_string()
  return _local_device_chooser
Example #26
Source File: tf_graph_util.py    From captcha_trainer with Apache License 2.0 5 votes vote down vote up
def tensor_shape_from_node_def_name(graph, input_name):
    """Return the shape of the tensor referenced by a NodeDef input string.

    Tensor lookups need the "<input>:<port>" form (e.g. 'Mul:0'). GraphDef
    input strings may omit the port, so a name without a colon is treated
    as port 0.
    """
    canonical_name = input_name if ":" in input_name else input_name + ":0"
    return graph.get_tensor_by_name(canonical_name).get_shape()
Example #27
Source File: util.py    From tfjs-to-tf with MIT License 5 votes vote down vote up
def _op_nodes(graph_def: GraphDef) -> List[NodeDef]:
    """Return the nodes of ``graph_def`` for which ``_is_op_node`` is truthy."""
    return list(filter(_is_op_node, graph_def.node))
Example #28
Source File: util.py    From tfjs-to-tf with MIT License 5 votes vote down vote up
def _get_shape(node: NodeDef) -> List[int]:
    """Return the dimension sizes stored in the node's tfjs shape attribute.

    Sizes that are not positive (e.g. -1, or 0) are reported as None.
    """
    dims = node.attr[c.TFJS_ATTR_SHAPE_KEY].shape.dim
    return [dim.size if dim.size > 0 else None for dim in dims]
Example #29
Source File: util.py    From tfjs-to-tf with MIT License 5 votes vote down vote up
def _node_info(node: NodeDef) -> NodeInfo:
    """Summarise ``node`` as a NodeInfo record.

    The shape comes from ``_get_shape``, and the dtype from ``_map_type``
    applied to the tfjs dtype attribute. The tensor name is the node name
    with ":0" appended.
    """
    shape = _get_shape(node)
    dtype = _map_type(node.attr[c.TFJS_ATTR_DTYPE_KEY].type)
    return NodeInfo(name=node.name, shape=shape, dtype=dtype,
                    tensor=node.name + ':0')
Example #30
Source File: train_shadownet_multi.py    From uai-sdk with Apache License 2.0 5 votes vote down vote up
def local_device_setter(num_devices=1,
                        ps_device_type='cpu',
                        worker_device='/cpu:0',
                        ps_ops=None,
                        ps_strategy=None):
    """Return a device function placing selected ops on parameter servers.

    Ops whose type is in ``ps_ops`` are placed on ``/<ps_device_type>:<k>``
    with ``k = ps_strategy(op)``. Every other op is placed on
    ``worker_device``. In both cases the op's own device string (if any) is
    merged into the chosen spec via ``DeviceSpec.merge_from``.

    Args:
        num_devices: Passed to ``device_setter._RoundRobinStrategy`` when
            ``ps_strategy`` is None.
        ps_device_type: Device type used in the PS device string.
        worker_device: Device string for ops not in ``ps_ops``.
        ps_ops: Op types to place on PS devices. Defaults to
            ['Variable', 'VariableV2', 'VarHandleOp'].
        ps_strategy: Callable mapping an op to a PS device index. Defaults to
            ``device_setter._RoundRobinStrategy(num_devices)``.

    Returns:
        A function that takes an Operation or NodeDef and returns a device
        string.

    Raises:
        TypeError: If ``ps_strategy`` is not callable.
    """
    if ps_ops is None:
        ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']

    if ps_strategy is None:
        ps_strategy = device_setter._RoundRobinStrategy(num_devices)
    if not six.callable(ps_strategy):
        raise TypeError("ps_strategy must be callable")

    def _local_device_chooser(op):
        current_device = pydev.DeviceSpec.from_string(op.device or "")

        # Accept either a NodeDef or an object exposing `.node_def`.
        node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
        if node_def.op in ps_ops:
            ps_device_spec = pydev.DeviceSpec.from_string(
                  '/{}:{}'.format(ps_device_type, ps_strategy(op)))

            ps_device_spec.merge_from(current_device)
            return ps_device_spec.to_string()
        else:
            worker_device_spec = pydev.DeviceSpec.from_string(worker_device or "")
            worker_device_spec.merge_from(current_device)
            return worker_device_spec.to_string()
    return _local_device_chooser