Python tensorflow.python.ops.control_flow_ops.while_loop() Examples

The following are 30 code examples of tensorflow.python.ops.control_flow_ops.while_loop(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.control_flow_ops, or try the search function.
Example #1
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 6 votes vote down vote up
def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Builds the infeed enqueue ops together with a matching dequeue function.

    Returns:
      A 4-tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`. The first, third and fourth entries
      come straight from `_invoke_input_fn_and_record_structure`; `dequeue_fn`
      is a zero-argument callable defined here.
    """
    # Invoking the input_fn is what records the feature/label structure that
    # `dequeue_fn` later relies on to unflatten the dequeued tensors.
    recorded = self._invoke_input_fn_and_record_structure()
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = recorded

    self._validate_input_pipeline()

    def dequeue_fn():
      """Dequeues the infeed tensors and restores their recorded structure."""
      # Host and device sides must agree on the infeed core (this matters in
      # the model-parallel case), so logical core 0 of each replica is used.
      dequeued = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      recorder = self._inputs_structure_recorder
      return recorder.unflatten_features_and_labels(dequeued)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
Example #2
Source File: tpu_estimator.py    From embedding-as-service with MIT License 6 votes vote down vote up
def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Creates the infeed enqueue ops and the function that dequeues them.

    Returns:
      A 4-tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`; all but `dequeue_fn` are forwarded
      unchanged from `_invoke_input_fn_and_record_structure`.
    """
    # The input_fn structure gets recorded as a side effect of this call; the
    # dequeue side needs it to rebuild features and labels.
    invoke_result = self._invoke_input_fn_and_record_structure()
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = invoke_result

    self._validate_input_pipeline()

    def dequeue_fn():
      """Returns the dequeued tensors arranged in the recorded structure."""
      # Infeed always happens on logical core 0 of each replica so that the
      # host-side and device-side computations agree on the core.
      flat_values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      structure_recorder = self._inputs_structure_recorder
      return structure_recorder.unflatten_features_and_labels(flat_values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
Example #3
Source File: tpu_estimator.py    From embedding-as-service with MIT License 6 votes vote down vote up
def _wrap_computation_in_while_loop(device, op_fn):
  """Runs the ops built by `op_fn` repeatedly inside a tf.while_loop.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable; its return value is used as the control
      dependencies of every loop step.

  Returns:
    The result of `control_flow_ops.while_loop` over a single counter that
    starts at 0 and keeps looping while it is below the iterations-per-loop
    value.
  """
  loop_limit_var = _create_or_get_iterations_per_loop()

  def _step(counter):
    # The increment is gated on the ops returned by `op_fn`.
    with ops.control_dependencies(op_fn()):
      return counter + 1

  def _below_limit(counter):
    return counter < loop_limit

  with ops.device(device):
    loop_limit = array_ops.identity(loop_limit_var)
    # parallel_iterations=1 basically turns off parallel execution inside
    # the while_loop.
    return control_flow_ops.while_loop(
        _below_limit, _step, [constant_op.constant(0)],
        parallel_iterations=1)
Example #4
Source File: control_flow_ops_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Checks gradients through a while_loop that gathers from a fed input.

    For float32 and float64, the loop copies each element of a placeholder
    into a dynamically sized TensorArray, the packed result is summed, and the
    gradient of that sum with respect to the placeholder is evaluated. Feeding
    [1, 3, 2] must give a sum of 6 and a gradient of all ones.
    """
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = tf.placeholder(dtype=dtype)
        initial_outputs = tf.TensorArray(dtype=dtype, dynamic_size=True,
                                         size=1)
        initial_i = tf.constant(0, dtype=dtypes.int32)

        def Cond(i, _):
          # Keep looping until every element of `inputs` has been visited.
          return i < tf.size(inputs)  # pylint: disable=cell-var-from-loop

        def Body(i, outputs):
          x = tf.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = tf.while_loop(Cond, Body, [initial_i, initial_outputs])

        outputs = tf.reduce_sum(outputs.pack())
        r = tf.gradients([outputs], [inputs])[0]
        # Presumably `r` is an IndexedSlices (per the test name); convert it
        # to a dense tensor so it can be fetched and compared.
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        # `assertEquals` is a deprecated alias (removed in Python 3.12).
        self.assertEqual(o, 6)
        self.assertAllEqual(grad, [1] * 3)
Example #5
Source File: tpu_estimator.py    From embedding-as-service with MIT License 6 votes vote down vote up
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Loops over the ops from `op_fn` until a stopping signal is observed.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable returning a mapping with an 'ops' entry
      (used as control dependencies) and a 'signals' entry.

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def _should_continue(scalar_signal):
    stop_requested = _StopSignals.should_stop(scalar_signal)
    return math_ops.logical_not(stop_requested)

  def _run_step(unused_scalar_signal):
    step = op_fn()
    step_ops = step['ops']
    step_signals = step['signals']
    # The next scalar signal carries control dependencies on this step's ops.
    with ops.control_dependencies(step_ops):
      return _StopSignals.as_scalar_stopping_signal(step_signals)

  with ops.device(device):
    # parallel_iterations=1 basically turns off parallel execution inside
    # the while_loop.
    return control_flow_ops.while_loop(
        _should_continue,
        _run_step,
        [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
Example #6
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 6 votes vote down vote up
def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Produces the infeed enqueue ops plus a dequeue function for the TPU.

    Returns:
      A 4-tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`. Everything except `dequeue_fn` is
      passed through from `_invoke_input_fn_and_record_structure`.
    """
    # Calling into the input_fn here records its structure, which the
    # dequeue function needs when it unflattens the tensors.
    input_fn_outputs = self._invoke_input_fn_and_record_structure()
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = input_fn_outputs

    self._validate_input_pipeline()

    def dequeue_fn():
      """Fetches the infeed tensors and re-nests them as features/labels."""
      # Logical core 0 of each replica is the agreed infeed core for both
      # the host-side and device-side computations.
      raw_values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      recorder = self._inputs_structure_recorder
      return recorder.unflatten_features_and_labels(raw_values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
Example #7
Source File: control_flow_ops_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testIndexedSlicesGradient(self):
    """Trains through a while_loop whose body performs an embedding lookup.

    The test makes no explicit assertion; it passes when ten runs of the
    momentum optimizer's train op complete without raising.
    """
    with ops.Graph().as_default():
      embedding_matrix = tf.get_variable(
          "embedding_matrix", [5, 5],
          initializer=tf.random_normal_initializer())

      def _keep_looping(step, _):
        return step < 5

      def _accumulate(step, total):
        looked_up = embedding_ops.embedding_lookup(
            embedding_matrix + 0.0, [0])
        total = total + tf.reduce_sum(looked_up)
        return step + 1, total

      initial_vars = [tf.constant(0), tf.constant(0.0)]
      _, cost = control_flow_ops.while_loop(
          _keep_looping, _accumulate, initial_vars)
      train_op = momentum.MomentumOptimizer(0.1, 0.9).minimize(cost)
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])
Example #8
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop(device, op_fn):
  """Executes `op_fn`'s ops once per iteration of a counting tf.while_loop.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable whose result becomes the control
      dependencies of each iteration.

  Returns:
    The result of `control_flow_ops.while_loop`; the single loop variable is
    a counter that starts at 0.
  """
  iterations_var = _create_or_get_iterations_per_loop()

  def _advance(step_index):
    # Tie the counter bump to the ops produced by `op_fn`.
    with ops.control_dependencies(op_fn()):
      return step_index + 1

  def _has_steps_left(step_index):
    return step_index < max_steps

  with ops.device(device):
    max_steps = array_ops.identity(iterations_var)
    # With parallel_iterations=1 the while_loop's parallel execution is
    # basically turned off.
    return control_flow_ops.while_loop(
        _has_steps_left, _advance, [constant_op.constant(0)],
        parallel_iterations=1)
Example #9
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Repeats the ops from `op_fn` in a tf.while_loop until told to stop.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable returning a mapping with 'ops' and
      'signals' entries.

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def _keep_running(scalar_signal):
    return math_ops.logical_not(_StopSignals.should_stop(scalar_signal))

  def _iteration(unused_scalar_signal):
    outcome = op_fn()
    ops_to_run, signals = outcome['ops'], outcome['signals']
    # The returned signal depends on the ops built for this iteration.
    with ops.control_dependencies(ops_to_run):
      return _StopSignals.as_scalar_stopping_signal(signals)

  with ops.device(device):
    # With parallel_iterations=1 the while_loop's parallel execution is
    # basically turned off.
    return control_flow_ops.while_loop(
        _keep_running,
        _iteration,
        [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
Example #10
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop(device, op_fn):
  """Wraps `op_fn`'s ops in a tf.while_loop bounded by iterations-per-loop.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable; what it returns is used as control
      dependencies for each iteration's counter increment.

  Returns:
    The result of `control_flow_ops.while_loop` for a counter starting at 0.
  """
  per_loop_var = _create_or_get_iterations_per_loop()

  def _body(count):
    with ops.control_dependencies(op_fn()):
      # Incrementing under the control dependencies ties the ops to the loop.
      return count + 1

  def _cond(count):
    return count < total_iterations

  with ops.device(device):
    total_iterations = array_ops.identity(per_loop_var)
    # Parallel execution in the while_loop is basically turned off by
    # parallel_iterations=1.
    return control_flow_ops.while_loop(
        _cond, _body, [constant_op.constant(0)], parallel_iterations=1)
Example #11
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps `op_fn`'s ops in a tf.while_loop driven by stopping signals.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable returning a mapping; its 'ops' entry is
      used as control dependencies and its 'signals' entry is converted to
      the next scalar stopping signal.

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def _not_stopped(signal):
    should_stop = _StopSignals.should_stop(signal)
    return math_ops.logical_not(should_stop)

  def _loop_body(unused_signal):
    produced = op_fn()
    pending_ops = produced['ops']
    next_signals = produced['signals']
    with ops.control_dependencies(pending_ops):
      return _StopSignals.as_scalar_stopping_signal(next_signals)

  with ops.device(device):
    # Parallel execution in the while_loop is basically turned off by
    # parallel_iterations=1.
    return control_flow_ops.while_loop(
        _not_stopped, _loop_body, [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
Example #12
Source File: session_debug_testlib.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License 6 votes vote down vote up
def testDebugWhileLoopWatchingWholeGraphWorks(self):
    """Runs a small while_loop under the debugger and checks the dumps.

    The loop starts at 10 and adds 2 while the value is below 16, so the
    final result must be 16, with 10 seen at "while/Enter" and 12, 14, 16 at
    "while/NextIteration".
    """
    with session.Session() as sess:
      def _below_limit(value):
        return math_ops.less(value, 16)

      def _add_two(value):
        return math_ops.add(value, 2)

      start = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(_below_limit, _add_two, [start])

      loop_result, dump = self._debug_run_and_get_dump(sess, loop)
      self.assertEqual(16, loop_result)

      enter_values = dump.get_tensors("while/Enter", 0, "DebugIdentity")
      self.assertEqual([[10]], enter_values)

      next_iteration_values = dump.get_tensors(
          "while/NextIteration", 0, "DebugIdentity")
      self.assertEqual([[12], [14], [16]], next_iteration_values)
Example #13
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Builds a tf.while_loop that reruns `op_fn`'s ops until a stop signal.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable returning a mapping with 'ops' and
      'signals' entries.

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def _continue_looping(current_signal):
    stop_now = _StopSignals.should_stop(current_signal)
    return math_ops.logical_not(stop_now)

  def _one_pass(unused_current_signal):
    fn_result = op_fn()
    run_these = fn_result['ops']
    raw_signals = fn_result['signals']
    # Make the emitted signal depend on the ops of this pass.
    with ops.control_dependencies(run_these):
      return _StopSignals.as_scalar_stopping_signal(raw_signals)

  with ops.device(device):
    # Setting parallel_iterations=1 basically turns off the parallel
    # execution in while_loop.
    return control_flow_ops.while_loop(
        _continue_looping,
        _one_pass,
        [_StopSignals.NON_STOPPING_SIGNAL],
        parallel_iterations=1)
Example #14
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop(device, op_fn):
  """Builds a counting tf.while_loop that triggers `op_fn`'s ops each pass.

  Args:
    device: Device passed to `ops.device` for placing the loop.
    op_fn: Zero-argument callable; its result is used as the control
      dependencies for the counter increment.

  Returns:
    The result of `control_flow_ops.while_loop`, whose only loop variable is
    a counter initialised to 0.
  """
  configured_iterations = _create_or_get_iterations_per_loop()

  def _increment(n):
    with ops.control_dependencies(op_fn()):
      return n + 1

  def _not_finished(n):
    return n < iteration_count

  with ops.device(device):
    iteration_count = array_ops.identity(configured_iterations)
    # Setting parallel_iterations=1 basically turns off the parallel
    # execution in while_loop.
    return control_flow_ops.while_loop(
        _not_finished,
        _increment,
        [constant_op.constant(0)],
        parallel_iterations=1)
Example #15
Source File: GsganDiscriminator.py    From Texygen with MIT License 6 votes vote down vote up
def predict(self, input_x, h_0=None):
        """Unrolls the recurrent unit over `input_x` and returns the last output.

        Args:
          input_x: Tensor sliced as `[batch, step, vocab]` with per-step size
            `[self.batch_size_scale, 1, self.num_vocabulary]`.
          h_0: Optional initial hidden state for the recurrence. When omitted,
            `self.h0` is used.

        Returns:
          The output-unit result (`o_t`) from the final loop iteration.
        """
        # Previously `h_0` was resolved (from a differently spelled attribute,
        # `self.h_0`) and then ignored: the loop was always seeded with
        # `self.h0`. Default to `self.h0` so the no-argument call is unchanged,
        # and actually honour an explicitly supplied state.
        if h_0 is None:
            h_0 = self.h0

        def _g_recurrence(i, x_t, h_tm1, o_t):
            h_t = self.g_recurrent_unit(x_t, h_tm1)  # hidden_memory_tuple
            o_t = self.g_output_unit(h_t)  # batch x vocab , logits not prob
            x_tp1 = tf.squeeze(tf.slice(input_x, begin=[0, i, 0], size=[self.batch_size_scale, 1, self.num_vocabulary]))
            return i + 1, x_tp1, h_t, o_t

        o_0 = tf.constant(np.zeros(shape=[self.batch_size_scale, self.num_classes]))
        o_0 = tf.cast(o_0, dtype=tf.float32)
        _, _, _, output = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.one_hot, self.start_token), h_0, o_0))

        return output
Example #16
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 6 votes vote down vote up
def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Sets up infeed: returns the enqueue ops and a dequeue function.

    Returns:
      A 4-tuple `(enqueue_ops, dequeue_fn, all_hooks,
      run_infeed_loop_on_coordinator)`, where only `dequeue_fn` is created
      here and the rest come from `_invoke_input_fn_and_record_structure`.
    """
    # This call both builds the enqueue ops and records the input_fn
    # structure used below for unflattening.
    structure_outputs = self._invoke_input_fn_and_record_structure()
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = structure_outputs

    self._validate_input_pipeline()

    def dequeue_fn():
      """Dequeues from the infeed queue and unflattens the result."""
      # Both host and device sides use logical core 0 of each replica for
      # infeed, which keeps them in agreement in the model-parallel case.
      tensors = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      recorder = self._inputs_structure_recorder
      return recorder.unflatten_features_and_labels(tensors)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)
Example #17
Source File: session_debug_testlib.py    From lambda-packs with MIT License 6 votes vote down vote up
def testDebugWhileLoopWatchingWholeGraphWorks(self):
    """Watches the whole graph of a small while_loop and checks the dumps.

    The loop starts at 10 and adds 2 while the value is below 16. The run
    must return 16, and the dump must show 10 at "while/Enter" and
    12, 14, 16 at "while/NextIteration".
    """
    with session.Session() as sess:
      def _under_sixteen(value):
        return math_ops.less(value, 16)

      def _plus_two(value):
        return math_ops.add(value, 2)

      start = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(_under_sixteen, _plus_two, [start])

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=self._debug_urls())
      run_metadata = config_pb2.RunMetadata()
      loop_result = sess.run(
          loop, options=run_options, run_metadata=run_metadata)
      self.assertEqual(16, loop_result)

      dump = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)

      enter_values = dump.get_tensors("while/Enter", 0, "DebugIdentity")
      self.assertEqual([[10]], enter_values)

      next_iteration_values = dump.get_tensors(
          "while/NextIteration", 0, "DebugIdentity")
      self.assertEqual([[12], [14], [16]], next_iteration_values)
Example #18
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 5 votes vote down vote up
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Builds the sharded TPU prediction loop around `model_fn_wrapper`.

  Args:
    ctx: Context object; `num_replicas` and `device_assignment` are read.
    model_fn_wrapper: Provides `convert_to_single_tpu_predict_step`.
    dequeue_fn: Forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(compile_op, dummy_predict_op, host_calls, scaffold,
    predict_hooks)`.
  """
  predict_step_parts = model_fn_wrapper.convert_to_single_tpu_predict_step(
      dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = predict_step_parts

  def _not_stopped(scalar_stopping_signal):
    should_stop = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(should_stop)

  def multi_tpu_predict_steps_on_single_shard():
    # Keep running single predict steps until a stopping signal shows up.
    return training_loop.while_loop(
        _not_stopped,
        single_tpu_predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')

  compile_op, sharded_outputs = tpu.split_compile_and_shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  dummy_predict_op = sharded_outputs[0]
  scaffold = _get_scaffold(captured_scaffold_fn)
  predict_hooks = captured_predict_hooks.get()
  return (compile_op, dummy_predict_op, host_calls, scaffold, predict_hooks)
Example #19
Source File: resample.py    From lambda-packs with MIT License 5 votes vote down vote up
def _repeat_range(counts, name=None):
  """Repeats each index `i` of `counts` exactly `counts[i]` times.

  Example behavior:
  [0, 1, 2, 3] -> [1, 2, 2, 3, 3, 3]

  Args:
    counts: 1D tensor with dtype=int32.
    name: optional name for operation.

  Returns:
    1D tensor with dtype=int32 and dynamic length giving the repeated integers.
  """
  with ops.name_scope(name, 'repeat_range', [counts]) as scope:
    counts = ops.convert_to_tensor(counts, name='counts')
    size = array_ops.shape(counts)[0]

    def _more_to_write(unused_output, index):
      return index < size

    def _write_run(output, index):
      # counts[index:index+1] serves as the shape of the filled run.
      run = array_ops.fill(counts[index:index+1], index)
      return (output.write(index, run), index + 1)

    def _empty_result():
      return array_ops.zeros(shape=[0], dtype=dtypes.int32)

    init_output_array = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        _more_to_write, _write_run, loop_vars=[init_output_array, 0])

    # Concatenate only when something was written; otherwise return an
    # empty int32 vector.
    return control_flow_ops.cond(
        num_writes > 0, output_array.concat, _empty_result, name=scope)
Example #20
Source File: __init__.py    From SIMPLE-NN with GNU General Public License v3.0 5 votes vote down vote up
def repeat(x, counts):
    """
    Repeat each element of x the number of times given by counts.
    counts must be an integer tensor.

    example:
      x = [3.0, 4.0, 5.0, 6.0]
      counts = [3, 1, 0, 2]
      repeat(x, counts)
      >> [3.0, 3.0, 3.0, 4.0, 6.0, 6.0]
    """
    size = array_ops.shape(counts)[0]

    def _more_to_write(_, index):
        return index < size

    def _write_run(output, index):
        # counts[index:index+1] serves as the shape of the filled run.
        run = array_ops.fill(counts[index:index+1], x[index])
        return (output.write(index, run), index + 1)

    def _empty_result():
        return array_ops.zeros(shape=[0], dtype=x.dtype)

    init_output_array = tensor_array_ops.TensorArray(
        dtype=x.dtype, size=size, infer_shape=False)
    output_array, num_writes = control_flow_ops.while_loop(
        _more_to_write, _write_run, loop_vars=[init_output_array, 0])

    # Concatenate only when something was written; otherwise return an
    # empty vector of x's dtype.
    return control_flow_ops.cond(
        num_writes > 0, output_array.concat, _empty_result)
Example #21
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 5 votes vote down vote up
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Builds the sharded TPU prediction loop around `model_fn_wrapper`.

  Args:
    ctx: Context object; `num_replicas` and `device_assignment` are read.
    model_fn_wrapper: Provides `convert_to_single_tpu_predict_step`.
    dequeue_fn: Forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(dummy_predict_op, host_calls, scaffold, predict_hooks)`.
  """
  predict_step_parts = model_fn_wrapper.convert_to_single_tpu_predict_step(
      dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = predict_step_parts

  def _not_stopped(scalar_stopping_signal):
    should_stop = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(should_stop)

  def multi_tpu_predict_steps_on_single_shard():
    # Keep running single predict steps until a stopping signal shows up.
    return training_loop.while_loop(
        _not_stopped,
        single_tpu_predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')

  sharded_outputs = tpu.shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  # Exactly one output is expected back from the shard call.
  (dummy_predict_op,) = sharded_outputs

  scaffold = _get_scaffold(captured_scaffold_fn)
  predict_hooks = captured_predict_hooks.get()
  return dummy_predict_op, host_calls, scaffold, predict_hooks
Example #22
Source File: RankganDiscriminator.py    From Texygen with MIT License 5 votes vote down vote up
def get_rank_score(emb_test, embs_ref):
    """Averages `cosine_distance(emb_test, row)` over the rows of `embs_ref`.

    The number of rows is taken from the static first dimension of
    `embs_ref`; the accumulated sum is divided by that number.
    """
    ref_size = embs_ref.shape.as_list()[0]

    def _rows_remaining(i, _1, _2, _3):
        return i < ref_size

    def _accumulate(i, ret_v, emb_test, embs_ref):
        next_i = i + 1
        running = ret_v + cosine_distance(emb_test, tf.nn.embedding_lookup(embs_ref, i))
        return next_i, running, emb_test, embs_ref

    initial_vars = (
        tf.constant(0, dtype=tf.int32),
        tf.constant(0.0, dtype=tf.float32),
        emb_test,
        embs_ref,
    )
    loop_result = control_flow_ops.while_loop(
        cond=_rows_remaining, body=_accumulate, loop_vars=initial_vars)
    _, ret, _, _ = loop_result
    return ret / ref_size
Example #23
Source File: analyzer_cli_test.py    From keras-lambda with MIT License 5 votes vote down vote up
def setUpClass(cls):
    """Runs a watched while_loop once and prepares the analyzer fixtures.

    A counter loop (0 up to 10) is executed with a DebugIdentity watch on
    "while/Identity" that dumps into a fresh temporary directory. The dump
    is then loaded and the "list_tensors" / "print_tensor" handlers are
    registered on a new command registry.
    """
    cls._dump_root = tempfile.mkdtemp()

    with session.Session() as sess:
      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")
      while_loop = control_flow_ops.while_loop(
          lambda value: math_ops.less(value, 10),
          lambda value: math_ops.add(value, 1),
          [loop_var],
          parallel_iterations=1)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)

      # Add debug tensor watch for "while/Identity".
      watch = run_options.debug_options.debug_tensor_watch_opts.add()
      watch.node_name = "while/Identity"
      watch.output_slot = 0
      watch.debug_ops.append("DebugIdentity")
      watch.debug_urls.append("file://%s" % cls._dump_root)

      # Invoke Session.run().
      run_metadata = config_pb2.RunMetadata()
      sess.run(while_loop, options=run_options, run_metadata=run_metadata)

    cls._debug_dump = debug_data.DebugDumpDir(
        cls._dump_root, partition_graphs=run_metadata.partition_graphs)

    analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
    cls._analyzer = analyzer
    registry = debugger_cli_common.CommandHandlerRegistry()
    cls._registry = registry
    # Each command is registered under its full name plus a short alias.
    for command, alias in (("list_tensors", "lt"), ("print_tensor", "pt")):
      registry.register_command_handler(
          command,
          getattr(analyzer, command),
          analyzer.get_help(command),
          prefix_aliases=[alias])
Example #24
Source File: tpu_estimator.py    From embedding-as-service with MIT License 5 votes vote down vote up
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Sets up repeated prediction steps across the TPU shards.

  Args:
    ctx: Context object; `num_replicas` and `device_assignment` are read.
    model_fn_wrapper: Provides `convert_to_single_tpu_predict_step`.
    dequeue_fn: Forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(compile_op, dummy_predict_op, host_calls, scaffold,
    predict_hooks)`.
  """
  converted = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = converted

  def _keep_predicting(scalar_stopping_signal):
    stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(stop_requested)

  def multi_tpu_predict_steps_on_single_shard():
    # The loop is seeded with the non-stopping signal and runs predict steps
    # for as long as `_keep_predicting` holds.
    initial_inputs = [_StopSignals.NON_STOPPING_SIGNAL]
    return training_loop.while_loop(
        _keep_predicting,
        single_tpu_predict_step,
        inputs=initial_inputs,
        name=b'loop')

  compile_op, shard_outputs = tpu.split_compile_and_shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)

  dummy_predict_op = shard_outputs[0]
  scaffold = _get_scaffold(captured_scaffold_fn)
  predict_hooks = captured_predict_hooks.get()
  return (compile_op, dummy_predict_op, host_calls, scaffold, predict_hooks)
Example #25
Source File: von_mises_fisher.py    From s-vae-tf with MIT License 5 votes vote down vote up
def __while_loop(self, b, a, d, n, seed):
        """Runs an accept/reject loop until every sample slot has been accepted.

        `b`, `a` and `d` are each given a new leading axis and tiled `n` times
        along it. Each iteration draws fresh candidates for every slot and
        keeps them only in slots that are still pending and whose candidate is
        accepted.

        Args:
          b, a, d: Tensors used in the candidate and acceptance formulas below.
          n: Size of the leading sample axis.
          seed: Seed forwarded to the Beta sample and to `random_uniform`.

        Returns:
          The tuple `(e, w)` from the final loop state (note the order).
        """
        def __cond(w, e, bool_mask, b, a, d):
            # Keep iterating while at least one slot is still pending.
            return math_ops.reduce_any(bool_mask)

        def __body(w_, e_, bool_mask, b, a, d):
            # Candidate draw from Beta((m - 1) / 2, (m - 1) / 2), m = self.__mf.
            e = math_ops.cast(Beta((self.__mf - 1) / 2, (self.__mf - 1) / 2).sample(
                shape, seed=seed), dtype=self.dtype)

            u = random_ops.random_uniform(shape, dtype=self.dtype, seed=seed)

            w = (1 - (1 + b) * e) / (1 - (1 - b) * e)
            t = (2 * a * b) / (1 - (1 - b) * e)

            # Acceptance test: (m - 1) * log(t) - t + d > log(u).
            accept = gen_math_ops.greater(((self.__mf - 1) * math_ops.log(t) - t + d), math_ops.log(u))
            reject = gen_math_ops.logical_not(accept)

            # Only slots that are still pending AND accepted take the new
            # values; in those slots the mask becomes `reject`, i.e. False.
            w_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), w, w_)
            e_ = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), e, e_)
            bool_mask = array_ops.where(gen_math_ops.logical_and(bool_mask, accept), reject, bool_mask)

            return w_, e_, bool_mask, b, a, d

        # Sample shape: [n] + batch_shape[:-1] + [1]. `__body` reads `shape`
        # through its closure when the loop is built below.
        shape = array_ops.concat([[n], self.batch_shape_tensor()[:-1], [1]], 0)
        b, a, d = [gen_array_ops.tile(array_ops.expand_dims(e, axis=0), [n] + [1] * len(e.shape)) for e in (b, a, d)]

        # Initial state: w and e are zeros shaped like b, and every slot is
        # pending (mask all True).
        w, e, bool_mask, b, a, d = control_flow_ops.while_loop(__cond, __body,
                                                               [array_ops.zeros_like(b, dtype=self.dtype),
                                                                array_ops.zeros_like(b, dtype=self.dtype),
                                                                array_ops.ones_like(b, dtypes.bool),
                                                                b, a, d])

        return e, w
Example #26
Source File: analyzer_cli_test.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def setUpClass(cls):
    """Creates the debug dump and CLI registry shared by the tests.

    Executes a while_loop that counts from 0 to 10 with a DebugIdentity
    watch on "while/Identity", writing dumps under a new temporary
    directory, then loads the dump and registers the analyzer's
    "list_tensors" and "print_tensor" command handlers.
    """
    cls._dump_root = tempfile.mkdtemp()

    with session.Session() as sess:
      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")

      def _below_ten(counter):
        return math_ops.less(counter, 10)

      def _add_one(counter):
        return math_ops.add(counter, 1)

      while_loop = control_flow_ops.while_loop(
          _below_ten, _add_one, [loop_var], parallel_iterations=1)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      dump_url = "file://%s" % cls._dump_root

      # Add debug tensor watch for "while/Identity".
      tensor_watch = run_options.debug_options.debug_tensor_watch_opts.add()
      tensor_watch.node_name = "while/Identity"
      tensor_watch.output_slot = 0
      tensor_watch.debug_ops.append("DebugIdentity")
      tensor_watch.debug_urls.append(dump_url)

      # Invoke Session.run().
      run_metadata = config_pb2.RunMetadata()
      sess.run(while_loop, options=run_options, run_metadata=run_metadata)

    cls._debug_dump = debug_data.DebugDumpDir(
        cls._dump_root, partition_graphs=run_metadata.partition_graphs)

    cls._analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
    cls._registry = debugger_cli_common.CommandHandlerRegistry()
    handlers = (
        ("list_tensors", cls._analyzer.list_tensors, "lt"),
        ("print_tensor", cls._analyzer.print_tensor, "pt"),
    )
    for command, handler, alias in handlers:
      cls._registry.register_command_handler(
          command,
          handler,
          cls._analyzer.get_help(command),
          prefix_aliases=[alias])
Example #27
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 5 votes vote down vote up
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
  """Creates the compile op and the looping predict op for the TPU shards.

  Args:
    ctx: Context object; `num_replicas` and `device_assignment` are read.
    model_fn_wrapper: Provides `convert_to_single_tpu_predict_step`.
    dequeue_fn: Forwarded to `convert_to_single_tpu_predict_step`.

  Returns:
    A tuple `(compile_op, dummy_predict_op, host_calls, scaffold,
    predict_hooks)`.
  """
  step_and_captures = model_fn_wrapper.convert_to_single_tpu_predict_step(
      dequeue_fn)
  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
   captured_predict_hooks) = step_and_captures

  def _signal_allows_more(scalar_stopping_signal):
    return math_ops.logical_not(
        _StopSignals.should_stop(scalar_stopping_signal))

  def multi_tpu_predict_steps_on_single_shard():
    loop_outputs = training_loop.while_loop(
        _signal_allows_more,
        single_tpu_predict_step,
        inputs=[_StopSignals.NON_STOPPING_SIGNAL],
        name=b'loop')
    return loop_outputs

  compile_and_outputs = tpu.split_compile_and_shard(
      multi_tpu_predict_steps_on_single_shard,
      inputs=[],
      num_shards=ctx.num_replicas,
      outputs_from_all_shards=False,
      device_assignment=ctx.device_assignment)
  # Two results are expected: the compile op and the sharded outputs.
  compile_op, per_shard_outputs = compile_and_outputs

  dummy_predict_op = per_shard_outputs[0]
  scaffold = _get_scaffold(captured_scaffold_fn)
  predict_hooks = captured_predict_hooks.get()
  return (compile_op, dummy_predict_op, host_calls, scaffold, predict_hooks)
Example #28
Source File: metric_learning.py    From tf-slim with Apache License 2.0 5 votes vote down vote up
def compute_gt_cluster_score(pairwise_distances, labels):
  """Compute ground truth facility location score.

  Iterates over the unique class ids in `labels`. For each class, the
  pairwise-distance sub-matrix restricted to that class's members is summed
  along axis 0, and the negated minimum of those sums is added to the score.

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.

  Returns:
    gt_cluster_score: dtypes.float32 score.
  """
  unique_class_ids = array_ops.unique(labels)[0]
  num_classes = array_ops.size(unique_class_ids)
  iteration = array_ops.constant(0)
  gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32)

  def _classes_remaining(step, unused_score):
    return step < num_classes

  def _add_class_score(step, running_score):
    """Adds the score contribution of the class at position `step`."""
    members = array_ops.where(
        math_ops.equal(labels, unique_class_ids[step]))
    # Select the member rows, then (via transpose) the member columns.
    member_rows = array_ops.gather(pairwise_distances, members)
    member_block = array_ops.gather(
        array_ops.transpose(member_rows), members)
    within_cluster = array_ops.transpose(member_block)
    column_sums = math_ops.reduce_sum(within_cluster, axis=0)
    class_score = -1.0 * math_ops.reduce_min(column_sums)
    return step + 1, running_score + class_score

  loop_result = control_flow_ops.while_loop(
      _classes_remaining, _add_class_score, [iteration, gt_cluster_score])
  _, gt_cluster_score = loop_result
  return gt_cluster_score
Example #29
Source File: metric_loss_ops.py    From cluster-loss-tensorflow with BSD 2-Clause "Simplified" License 5 votes vote down vote up
def compute_gt_cluster_score(pairwise_distances, labels):
  """Compute ground truth facility location score.

  Loops over every unique class in `labels`; each class contributes the
  negated minimum, over its members, of the summed distances within the
  class (sums taken along axis 0 of the class's distance sub-matrix).

  Args:
    pairwise_distances: 2-D Tensor of pairwise distances.
    labels: 1-D Tensor of ground truth cluster assignment.

  Returns:
    gt_cluster_score: dtypes.float32 score.
  """
  unique_class_ids = array_ops.unique(labels)[0]
  num_classes = array_ops.size(unique_class_ids)
  iteration = array_ops.constant(0)
  gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32)

  def _not_done(class_index, unused_total):
    return class_index < num_classes

  def _score_one_class(class_index, total):
    """Folds the current class's contribution into `total`."""
    in_class = math_ops.equal(labels, unique_class_ids[class_index])
    class_member_ids = array_ops.where(in_class)
    # Restrict rows first, transpose, restrict again, transpose back.
    rows_only = array_ops.gather(pairwise_distances, class_member_ids)
    rows_and_cols = array_ops.gather(
        array_ops.transpose(rows_only), class_member_ids)
    class_distances = array_ops.transpose(rows_and_cols)
    summed = math_ops.reduce_sum(class_distances, axis=0)
    contribution = -1.0 * math_ops.reduce_min(summed)
    return class_index + 1, total + contribution

  final_state = control_flow_ops.while_loop(
      _not_done, _score_one_class, [iteration, gt_cluster_score])
  _, gt_cluster_score = final_state
  return gt_cluster_score
Example #30
Source File: analyzer_cli_test.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def setUpClass(cls):
    """Builds the shared debug dump, analyzer and command registry.

    A while_loop counting from 0 to 10 is run with a DebugIdentity watch on
    "while/Identity" whose dumps go to a new temporary directory. The
    resulting dump backs a DebugAnalyzer, whose "list_tensors" and
    "print_tensor" handlers are registered with short aliases.
    """
    cls._dump_root = tempfile.mkdtemp()

    with session.Session() as sess:
      loop_var = constant_op.constant(0, name="while_loop_test/loop_var")
      while_loop = control_flow_ops.while_loop(
          lambda current: math_ops.less(current, 10),
          lambda current: math_ops.add(current, 1),
          [loop_var],
          parallel_iterations=1)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)

      # Add debug tensor watch for "while/Identity".
      watch = run_options.debug_tensor_watch_opts.add()
      watch.node_name = "while/Identity"
      watch.output_slot = 0
      watch.debug_ops.append("DebugIdentity")
      watch.debug_urls.append("file://%s" % cls._dump_root)

      # Invoke Session.run().
      run_metadata = config_pb2.RunMetadata()
      sess.run(while_loop, options=run_options, run_metadata=run_metadata)

    cls._debug_dump = debug_data.DebugDumpDir(
        cls._dump_root, partition_graphs=run_metadata.partition_graphs)

    analyzer = analyzer_cli.DebugAnalyzer(cls._debug_dump)
    cls._analyzer = analyzer
    registry = debugger_cli_common.CommandHandlerRegistry()
    cls._registry = registry
    # Each command is registered under its full name plus a short alias.
    for command, alias in (("list_tensors", "lt"), ("print_tensor", "pt")):
      registry.register_command_handler(
          command,
          getattr(analyzer, command),
          analyzer.get_help(command),
          prefix_aliases=[alias])