Python tensorflow.python.ops.array_ops.placeholder() Examples

The following are 30 code examples of tensorflow.python.ops.array_ops.placeholder(). Each example notes its original project and source file. You may also want to check out all available functions/classes of the module tensorflow.python.ops.array_ops, or try the search function.
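As a quick orientation, here is a minimal sketch of the TF 1.x graph-mode pattern these examples share: a placeholder declares a graph input, and a value is supplied through feed_dict at run time (the tensor names below are illustrative).

import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops, math_ops

# Declare a graph input with a partially known shape.
x = array_ops.placeholder(dtypes.float32, shape=(None, 3), name='x')
y = math_ops.reduce_sum(x)

with session.Session() as sess:
  # Any (k, 3) value is compatible with the (None, 3) placeholder shape.
  print(sess.run(y, feed_dict={x: np.ones((2, 3))}))  # 6.0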
Example #1
Source File: function.py    From auto-alt-text-lambda-api with MIT License
def create_op(self, op_type, inputs, data_types, **kwargs):
    for i, x in enumerate(inputs):
      if x.graph is not self:
        # Referring to a tensor from other graph.
        if x in self._captured:
          # Captured already.
          inputs[i] = self._captured[x]
        else:
          # Substitute with a placeholder.
          self.extra_inputs.append(x)
          ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
          inputs[i] = ph
          self._captured[x] = ph
          self.extra_args.append(ph)
    return super(_FuncGraph, self).create_op(op_type, inputs, data_types,
                                             **kwargs) 
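This create_op override implements implicit capture: when the function body refers to a tensor from the outer graph, the tensor is replaced by a placeholder inside the function graph and recorded in extra_inputs/extra_args. A hedged sketch of the behavior this enables via the TF 1.x Defun decorator (identifiers are illustrative):

import tensorflow as tf
from tensorflow.python.framework import function

outer = tf.constant(2.0)  # defined in the outer graph

@function.Defun(tf.float32)
def scale(x):
  # Referencing `outer` here exercises the capture path above: the
  # function graph substitutes a placeholder and feeds `outer` in.
  return x * outer

y = scale(tf.constant(3.0))  # 6.0 when evaluated in a session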
Example #2
Source File: backend.py    From lambda-packs with MIT License
def ndim(x):
  """Returns the number of axes in a tensor, as an integer.

  Arguments:
      x: Tensor or variable.

  Returns:
      Integer (scalar), number of axes.

  Examples:
  ```python
      >>> from keras import backend as K
      >>> input = K.placeholder(shape=(2, 4, 5))
      >>> val = np.array([[1, 2], [3, 4]])
      >>> kvar = K.variable(value=val)
      >>> K.ndim(input)
      3
      >>> K.ndim(kvar)
      2
  ```
  """
  dims = x.get_shape()._dims
  if dims is not None:
    return len(dims)
  return None 
Example #3
Source File: io_ops.py    From auto-alt-text-lambda-api with MIT License
def placeholder(dtype, axes, name=None):
  """Create a placeholder for a labeled tensor.

  For example:

    lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])

  See tf.placeholder for more details.

  Args:
    dtype: The type of elements in the tensor to be fed.
    axes: sequence of strings (denoting axes of unknown size) and/or objects
convertible to lt.Axis to label the result.
    name: Optional op name.

  Returns:
    Placeholder labeled tensor.
  """
  with ops.name_scope(name, 'lt_placeholder', []) as scope:
    axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis
                      for axis in axes])
    shape = [axis.size for axis in axes.values()]
    tensor = array_ops.placeholder(dtype, shape, name=scope)
    return core.LabeledTensor(tensor, axes) 
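A hedged usage sketch for the labeled placeholder, assuming the tf.contrib.labeled_tensor import path (the axis names are illustrative):

import numpy as np
import tensorflow as tf
lt = tf.contrib.labeled_tensor  # assumed import path

# 'batch' has unknown size; 'channel' carries explicit coordinate labels.
image = lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])

with tf.Session() as sess:
  # The underlying tf.placeholder is reachable via `.tensor` for feeding.
  sess.run(image.tensor, feed_dict={image.tensor: np.zeros((2, 3))})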
Example #4
Source File: exporter_test.py    From vehicle_counting_tensorflow with MIT License
def test_rewrite_nn_resize_op(self):
    g = tf.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
      y = array_ops.placeholder(dtypes.float32, shape=(8, 20, 20, 8))
      s = ops.nearest_neighbor_upsampling(x, 2)
      t = s + y
      exporter.rewrite_nn_resize_op()

    resize_op_found = False
    for op in g.get_operations():
      if op.type == 'ResizeNearestNeighbor':
        resize_op_found = True
        self.assertEqual(op.inputs[0], x)
        self.assertEqual(op.outputs[0].consumers()[0], t.op)
        break

    self.assertTrue(resize_op_found) 
Example #5
Source File: input_fn_utils.py    From auto-alt-text-lambda-api with MIT License
def build_parsing_serving_input_fn(feature_spec, default_batch_size=1):
  """Build an input_fn appropriate for serving, expecting fed tf.Examples.

  Creates an input_fn that expects a serialized tf.Example fed into a string
  placeholder.  The function parses the tf.Example according to the provided
  feature_spec, and returns all parsed Tensors as features.  This input_fn is
  for use at serving time, so the labels return value is always None.

  Args:
    feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`.
    default_batch_size: the number of query examples expected per batch.

  Returns:
    An input_fn suitable for use in serving.
  """
  def input_fn():
    """An input_fn that expects a serialized tf.Example."""
    serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
                                                  shape=[default_batch_size],
                                                  name='input_example_tensor')
    inputs = {'examples': serialized_tf_example}
    features = parsing_ops.parse_example(serialized_tf_example, feature_spec)
    labels = None  # these are not known in serving!
    return InputFnOps(features, labels, inputs)
  return input_fn 
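A sketch of wiring this builder up, assuming a simple, hypothetical feature spec:

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops

# Hypothetical spec: one dense float feature and one sparse string feature.
feature_spec = {
    'age': parsing_ops.FixedLenFeature([1], dtypes.float32),
    'query': parsing_ops.VarLenFeature(dtypes.string),
}
serving_input_fn = build_parsing_serving_input_fn(feature_spec)
features, labels, inputs = serving_input_fn()  # labels is always None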
Example #6
Source File: session_ops.py    From lambda-packs with MIT License
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Delete the tensor of a given tensor handle. The tensor is produced
  in a previous run() and stored in the state of the session.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle and the second is a deletion operation.
  """
  handle_device = TensorHandle._get_device_name(handle)
  with ops.device(handle_device):
    holder = array_ops.placeholder(dtypes.string)
    deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name)
  return (holder, deleter) 
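A hedged sketch of the handle lifecycle this participates in, using the public TF 1.x session-handle API:

import tensorflow as tf

with tf.Session() as sess:
  # Persist a tensor in session state and obtain a string handle for it.
  handle = sess.run(tf.get_session_handle(tf.constant([1.0, 2.0])))
  holder, deleter = tf.delete_session_tensor(handle.handle)
  # Feeding the handle string to the placeholder frees the stored tensor.
  sess.run(deleter, feed_dict={holder: handle.handle})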
Example #7
Source File: util.py    From auto-alt-text-lambda-api with MIT License
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.
  The placeholder is named using the function placeholder_name (with no
  tensor argument).

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of t is preserved. "" means the root scope.
  Returns:
    A newly created tf.placeholder.
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) 
Example #8
Source File: util.py    From auto-alt-text-lambda-api with MIT License
def make_placeholder_from_tensor(t, scope=None):
  """Create a `tf.placeholder` for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.

  Args:
    t: a `tf.Tensor` whose name will be used to create the placeholder
      (see function placeholder_name).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of `t` is preserved. `""` means the root scope.
  Returns:
    A newly created `tf.placeholder`.
  Raises:
    TypeError: if `t` is not `None` or a `tf.Tensor`.
  """
  return tf_array_ops.placeholder(
      dtype=t.dtype, shape=t.get_shape(), name=placeholder_name(
          t, scope=scope)) 
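A short usage sketch, assuming the tf.contrib.graph_editor import path:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge  # assumed import path

t = tf.constant([[1.0, 2.0]], name='t')
# The placeholder inherits t's dtype and shape, with a derived name.
ph = ge.make_placeholder_from_tensor(t)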
Example #9
Source File: tensor_signature.py    From auto-alt-text-lambda-api with MIT License
def create_placeholders_from_signatures(signatures):
  """Creates placeholders from given signatures.

  Args:
    signatures: Dict of `TensorSignature` objects or single `TensorSignature`,
      or `None`.

  Returns:
    Dict of `tf.placeholder` objects or single `tf.placeholder`, or `None`.
  """
  if signatures is None:
    return None
  if not isinstance(signatures, dict):
    return signatures.get_placeholder()
  return {
      key: signatures[key].get_placeholder()
      for key in signatures} 
Example #10
Source File: inception_v2_test.py    From auto-alt-text-lambda-api with MIT License
def testUnknownImageShape(self):
    ops.reset_default_graph()
    batch_size = 2
    height, width = 224, 224
    num_classes = 1000
    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
    with self.test_session() as sess:
      inputs = array_ops.placeholder(
          dtypes.float32, shape=(batch_size, None, None, 3))
      logits, end_points = inception_v2.inception_v2(inputs, num_classes)
      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      pre_pool = end_points['Mixed_5c']
      feed_dict = {inputs: input_np}
      variables.global_variables_initializer().run()
      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024]) 
Example #11
Source File: resnet_v1_test.py    From auto-alt-text-lambda-api with MIT License
def create_test_input(batch_size, height, width, channels):
  """Create test input tensor.

  Args:
    batch_size: The number of images per batch or `None` if unknown.
    height: The height of each image or `None` if unknown.
    width: The width of each image or `None` if unknown.
    channels: The number of channels per image or `None` if unknown.

  Returns:
    Either a placeholder `Tensor` of dimension
    [batch_size, height, width, channels], if any of the inputs are `None`,
    or a constant `Tensor` with the mesh grid values along the spatial
    dimensions.
  """
  if None in [batch_size, height, width, channels]:
    return array_ops.placeholder(dtypes.float32,
                                 (batch_size, height, width, channels))
  else:
    return math_ops.to_float(
        np.tile(
            np.reshape(
                np.reshape(np.arange(height), [height, 1]) + np.reshape(
                    np.arange(width), [1, width]), [1, height, width, 1]),
            [batch_size, 1, 1, channels])) 
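The constant branch builds a mesh grid whose (i, j) entry equals i + j, tiled over batch and channels; a small NumPy check of that arithmetic:

import numpy as np

height, width = 2, 3
grid = (np.reshape(np.arange(height), [height, 1]) +
        np.reshape(np.arange(width), [1, width]))
# grid == [[0, 1, 2],
#          [1, 2, 3]]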
Example #12
Source File: backend.py    From lambda-packs with MIT License
def set_value(x, value):
  """Sets the value of a variable, from a Numpy array.

  Arguments:
      x: Tensor to set to a new value.
      value: Value to set the tensor to, as a Numpy array
          (of the same shape).
  """
  value = np.asarray(value)
  tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
  if hasattr(x, '_assign_placeholder'):
    assign_placeholder = x._assign_placeholder
    assign_op = x._assign_op
  else:
    assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape)
    assign_op = x.assign(assign_placeholder)
    x._assign_placeholder = assign_placeholder
    x._assign_op = assign_op
  get_session().run(assign_op, feed_dict={assign_placeholder: value}) 
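A usage sketch with the public Keras backend API (the variable here is illustrative):

import numpy as np
from keras import backend as K

kvar = K.variable(np.zeros((2, 2)))
# Repeated calls reuse the cached assign placeholder and assign op.
K.set_value(kvar, np.ones((2, 2)))
K.get_value(kvar)  # array of ones, shape (2, 2)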
Example #13
Source File: resnet_v2_test.py    From auto-alt-text-lambda-api with MIT License
def create_test_input(batch_size, height, width, channels):
  """Create test input tensor.

  Args:
    batch_size: The number of images per batch or `None` if unknown.
    height: The height of each image or `None` if unknown.
    width: The width of each image or `None` if unknown.
    channels: The number of channels per image or `None` if unknown.

  Returns:
    Either a placeholder `Tensor` of dimension
    [batch_size, height, width, channels], if any of the inputs are `None`,
    or a constant `Tensor` with the mesh grid values along the spatial
    dimensions.
  """
  if None in [batch_size, height, width, channels]:
    return array_ops.placeholder(dtypes.float32,
                                 (batch_size, height, width, channels))
  else:
    return math_ops.to_float(
        np.tile(
            np.reshape(
                np.reshape(np.arange(height), [height, 1]) + np.reshape(
                    np.arange(width), [1, width]), [1, height, width, 1]),
            [batch_size, 1, 1, channels])) 
Example #14
Source File: function.py    From lambda-packs with MIT License
def create_op(self, op_type, inputs, data_types, **kwargs):
    for i, x in enumerate(inputs):
      if x.graph is not self:
        # Referring to a tensor from other graph.
        if x in self._captured:
          # Captured already.
          inputs[i] = self._captured[x]
        elif self._capture_by_value:
          inputs[i] = self._add_tensor_and_parents(x)
        else:
          # Substitute with a placeholder.
          self.extra_inputs.append(x)
          ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
          # pylint: disable=protected-access
          ph._handle_shape = x._handle_shape
          ph._handle_dtype = x._handle_dtype
          # pylint: enable=protected-access
          inputs[i] = ph
          self._captured[x] = ph
          self.extra_args.append(ph)
    return super(_ExperimentalFuncGraph, self).create_op(op_type, inputs,
                                                         data_types, **kwargs) 
Example #15
Source File: function.py    From lambda-packs with MIT License
def _add_op_and_parents(self, op):
    op_def = function._get_op_def(op)
    if op_def.is_stateful:
      raise ValueError("Cannot capture a stateful node by value.")
    elif op.type in ("Placeholder", "PlaceholderV2"):
      raise ValueError("Cannot capture a placeholder by value.")

    captured_inputs = [self._add_tensor_and_parents(x) for x in op.inputs]

    captured_op = self.create_op(op.type, captured_inputs,
                                 [o.dtype for o in op.outputs],
                                 name=op.name, attrs=op.node_def.attr,
                                 op_def=op_def)

    for t, captured_t in zip(op.outputs, captured_op.outputs):
      self._captured[t] = captured_t

    return captured_op 
Example #16
Source File: backend.py    From lambda-packs with MIT License
def function(inputs, outputs, updates=None, **kwargs):
  """Instantiates a Keras function.

  Arguments:
      inputs: List of placeholder tensors.
      outputs: List of output tensors.
      updates: List of update ops.
      **kwargs: Not used with TensorFlow.

  Returns:
      A callable that computes and returns the output values as Numpy arrays.
  """
  if kwargs:
    msg = [
        'Expected no kwargs, you passed %s' % len(kwargs),
        'kwargs passed to function are ignored with TensorFlow backend'
    ]
    warnings.warn('\n'.join(msg))
  return Function(inputs, outputs, updates=updates) 
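A usage sketch of the returned callable (shapes are illustrative):

import numpy as np
from keras import backend as K

x = K.placeholder(shape=(None, 4))
f = K.function([x], [x * 2.0])
f([np.ones((3, 4))])  # [array of 2.0s with shape (3, 4)]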
Example #17
Source File: util.py    From lambda-packs with MIT License
def make_placeholder_from_tensor(t, scope=None):
  """Create a `tf.placeholder` for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.

  Args:
    t: a `tf.Tensor` whose name will be used to create the placeholder
      (see function placeholder_name).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of `t` is preserved. `""` means the root scope.
  Returns:
    A newly created `tf.placeholder`.
  Raises:
    TypeError: if `t` is not `None` or a `tf.Tensor`.
  """
  return tf_array_ops.placeholder(
      dtype=t.dtype, shape=t.get_shape(), name=placeholder_name(
          t, scope=scope)) 
Example #18
Source File: util.py    From lambda-packs with MIT License
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
  """Create a tf.placeholder for the Graph Editor.

  Note that the correct graph scope must be set by the calling function.
  The placeholder is named using the function placeholder_name (with no
  tensor argument).

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of t is preserved. "" means the root scope.
  Returns:
    A newly created tf.placeholder.
  """
  return tf_array_ops.placeholder(
      dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) 
Example #19
Source File: session_ops.py    From auto-alt-text-lambda-api with MIT License
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Delete the tensor of a given tensor handle. The tensor is produced
  in a previous run() and stored in the state of the session.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle and the second is a deletion operation.
  """
  handle_device = TensorHandle._get_device_name(handle)
  with ops.device(handle_device):
    holder = array_ops.placeholder(dtypes.string)
    deleter = gen_data_flow_ops._delete_session_tensor(holder, name=name)
  return (holder, deleter) 
Example #20
Source File: tensor_signature.py    From lambda-packs with MIT License
def create_placeholders_from_signatures(signatures):
  """Creates placeholders from given signatures.

  Args:
    signatures: Dict of `TensorSignature` objects or single `TensorSignature`,
      or `None`.

  Returns:
    Dict of `tf.placeholder` objects or single `tf.placeholder`, or `None`.
  """
  if signatures is None:
    return None
  if not isinstance(signatures, dict):
    return signatures.get_placeholder()
  return {
      key: signatures[key].get_placeholder()
      for key in signatures} 
Example #21
Source File: feature_column.py    From lambda-packs with MIT License
def make_place_holder_tensors_for_base_features(feature_columns):
  """Returns placeholder tensors for inference.

  Args:
    feature_columns: An iterable containing all the feature columns. All items
      should be instances of classes derived from _FeatureColumn.
  Returns:
    A dict mapping feature keys to SparseTensors (sparse columns) or
    placeholder Tensors (dense columns).
  """
  # Get dict mapping features to FixedLenFeature or VarLenFeature values.
  dict_for_parse_example = create_feature_spec_for_parsing(feature_columns)
  placeholders = {}
  for column_name, column_type in dict_for_parse_example.items():
    if isinstance(column_type, parsing_ops.VarLenFeature):
      # Sparse placeholder for sparse tensors.
      placeholders[column_name] = array_ops.sparse_placeholder(
          column_type.dtype, name="Placeholder_{}".format(column_name))
    else:
      # Simple placeholder for dense tensors.
      placeholders[column_name] = array_ops.placeholder(
          column_type.dtype,
          shape=(None, column_type.shape[0]),
          name="Placeholder_{}".format(column_name))
  return placeholders 
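A hedged sketch using contrib feature-column constructors (assuming the tf.contrib.layers import path; column names are illustrative):

from tensorflow.contrib import layers  # assumed import path

feature_columns = [
    layers.real_valued_column('age', dimension=1),
    layers.sparse_column_with_hash_bucket('query', hash_bucket_size=10),
]
placeholders = make_place_holder_tensors_for_base_features(feature_columns)
# {'age': dense placeholder, 'query': sparse placeholder}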
Example #22
Source File: io_ops.py    From lambda-packs with MIT License
def placeholder(dtype, axes, name=None):
  """Create a placeholder for a labeled tensor.

  For example:

    lt.placeholder(tf.float32, ['batch', ('channel', ['r', 'g', 'b'])])

  See tf.placeholder for more details.

  Args:
    dtype: The type of elements in the tensor to be fed.
    axes: sequence of strings (denoting axes of unknown size) and/or objects
convertible to lt.Axis to label the result.
    name: Optional op name.

  Returns:
    Placeholder labeled tensor.
  """
  with ops.name_scope(name, 'lt_placeholder', []) as scope:
    axes = core.Axes([(axis, None) if isinstance(axis, string_types) else axis
                      for axis in axes])
    shape = [axis.size for axis in axes.values()]
    tensor = array_ops.placeholder(dtype, shape, name=scope)
    return core.LabeledTensor(tensor, axes) 
Example #23
Source File: exporter_test.py    From vehicle_counting_tensorflow with MIT License
def _save_checkpoint_from_mock_model(self,
                                       checkpoint_path,
                                       use_moving_averages,
                                       enable_quantization=False):
    g = tf.Graph()
    with g.as_default():
      mock_model = FakeModel()
      preprocessed_inputs, true_image_shapes = mock_model.preprocess(
          tf.placeholder(tf.float32, shape=[None, None, None, 3]))
      predictions = mock_model.predict(preprocessed_inputs, true_image_shapes)
      mock_model.postprocess(predictions, true_image_shapes)
      if use_moving_averages:
        tf.train.ExponentialMovingAverage(0.0).apply()
      tf.train.get_or_create_global_step()
      if enable_quantization:
        graph_rewriter_config = graph_rewriter_pb2.GraphRewriter()
        graph_rewriter_config.quantization.delay = 500000
        graph_rewriter_fn = graph_rewriter_builder.build(
            graph_rewriter_config, is_training=False)
        graph_rewriter_fn()
      saver = tf.train.Saver()
      init = tf.global_variables_initializer()
      with self.test_session() as sess:
        sess.run(init)
        saver.save(sess, checkpoint_path) 
Example #24
Source File: summaries.py    From lambda-packs with MIT License
def tf_spec_summary(spec,
                    inputs=None,
                    input_shape=None,
                    input_type=dtypes.float32):
  """Output a summary of the specification.

  This prints a list of left-most tensor operations and summarized the
  variables found in the right branches. This kind of representation
  is particularly useful for networks that are generally structured
  like pipelines.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor)
      input_shape: optional shape of input
      input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_parameter_summary(outputs) 
Example #25
Source File: summaries.py    From lambda-packs with MIT License
def tf_spec_print(spec,
                  inputs=None,
                  input_shape=None,
                  input_type=dtypes.float32):
  """Print a tree representing the spec.

  Args:
      spec: specification
      inputs: input to the spec construction (usually a Tensor)
      input_shape: optional shape of input
      input_type: type of the input tensor
  """

  if inputs is None:
    inputs = array_ops.placeholder(input_type, input_shape)
  outputs = specs.create_net(spec, inputs)
  tf_print(outputs) 
Example #26
Source File: stepper_cli_test.py    From auto-alt-text-lambda-api with MIT License
def setUp(self):
    self.a = variables.Variable(10.0, name="a")
    self.b = variables.Variable(20.0, name="b")

    self.c = math_ops.add(self.a, self.b, name="c")  # Should be 30.0.
    self.d = math_ops.subtract(self.a, self.c, name="d")  # Should be -20.0.
    self.e = math_ops.multiply(self.c, self.d, name="e")  # Should be -600.0.

    self.ph = array_ops.placeholder(dtypes.float32, shape=(2, 2), name="ph")
    self.f = math_ops.multiply(self.e, self.ph, name="f")

    self.opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
        self.e, name="opt")

    self.sess = session.Session()

    self.sess.run(self.a.initializer)
    self.sess.run(self.b.initializer) 
Example #27
Source File: local_cli_wrapper_test.py    From auto-alt-text-lambda-api with MIT License
def setUp(self):
    self._tmp_dir = tempfile.mktemp()

    self.v = variables.Variable(10.0, name="v")
    self.delta = constant_op.constant(1.0, name="delta")
    self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")

    self.ph = array_ops.placeholder(dtypes.float32, name="ph")
    self.xph = array_ops.transpose(self.ph, name="xph")
    self.m = constant_op.constant(
        [[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
    self.y = math_ops.matmul(self.m, self.xph, name="y")

    self.sess = session.Session()

    # Initialize variable.
    self.sess.run(self.v.initializer) 
Example #28
Source File: batch_ops_test.py    From lambda-packs with MIT License
def testBasicUnbatch(self):
    """Tests that batch and unbatch work together."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          grad_timeout_micros=0, batching_queue="")
      computation = batched[0] + 1
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3]) 
Example #29
Source File: batch_ops_test.py    From lambda-packs with MIT License
def testBasicUnbatchDecorated(self):
    """Tests that the batch_function decorator works."""
    with self.test_session() as sess:
      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return in_t + 1
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3]) 
Example #30
Source File: batch_ops_test.py    From lambda-packs with MIT License
def testUnbatchGrad(self):
    """Tests that batch and unbatch are differentiable."""
    with self.test_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=1000000,
          batching_queue="")
      computation = batched[0] * batched[0]
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
      grad = gradients_impl.gradients(result, inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([grad], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([grad], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [4])