Python tensorflow.compat.v1.while_loop() Examples

The following are 22 code examples of tensorflow.compat.v1.while_loop(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_cond_in_loop():
    """Checks a tf.while_loop whose body contains a tf.cond.

    The body ignores its incoming loop variable and rebinds `x` to the
    constant 7, so every iteration computes cond(7 < 10) -> 10 + 20 = 30 and
    returns 30 * 7 = 210. Starting from 21 (< 100), one iteration runs and
    the loop exits with 210.
    """
    graph = tf.Graph()
    with graph.as_default():
        def body(x):
            # Shadows the incoming loop variable with a constant.
            x = tf.constant(7)
            z = tf.constant(20)  # not used by the body
            res = tf.cond(tf.less(x, 10), lambda: tf.add(
                10, 20), lambda: tf.square(10))
            return tf.multiply(res, x)

        x = tf.constant(21)
        def condition(x):
            return tf.less(x, 100)

        r = tf.while_loop(condition, body, loop_vars=[x])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    # check_equal is a helper not shown here; presumably it compares tf_out
    # with the output of the converted graph.
    check_equal(graph, tf_out)
Example #2
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_loop_in_cond():
    """Checks a tf.cond whose true branch contains a tf.while_loop.

    fn1 counts 0 -> 10 in a while loop and returns (20 + 10) * 10 = 300.
    fn2 returns 10 + 20. Neither branch uses its `a`/`b` arguments.
    """
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            i = tf.constant(0)

            def cd(i): return tf.less(i, 10)

            def bd(i): return tf.add(i, 1)
            res = tf.while_loop(cd, bd, [i])
            return tf.multiply(tf.add(20, res), 10)

        def fn2(a, b):
            return tf.add(10, 20)

        x = tf.constant(7)
        y = tf.constant(20)
        z = tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        with tf.Session() as sess:
            # NOTE(review): this feed overrides the constants in the TF run
            # only; check_equal below receives no feed values. The results
            # still agree because fn1 ignores its arguments and the unfed
            # pred (7 < 20) is also True.
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)
Example #3
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_nested_loop():
    """Checks a tf.while_loop whose body contains another tf.while_loop.

    The inner loop doubles c from 2 while c < 10 (2 -> 4 -> 8 -> 16). The
    outer body returns relu(x + inner_result).
    """
    graph = tf.Graph()
    with graph.as_default():

        def body(x):
            def nest_body(c):
                return tf.multiply(c, 2)
            def cd(c): return tf.less(c, 10)
            c = tf.constant(2)
            res = tf.while_loop(cd, nest_body, loop_vars=[c])
            return tf.nn.relu(x + res)

        # NOTE(review): with x starting at 3, greater(3, 100) is False, so the
        # outer body never executes at runtime and tf_out is just 3. Only the
        # graph structure of the nested loop is exercised. Confirm this is
        # intended rather than tf.less.
        def condition(x):
            return tf.greater(x, 100)
        x = tf.constant(3)
        r = tf.while_loop(condition, body, loop_vars=[x])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
Example #4
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_loop_conditions():
    """Checks a while loop whose condition combines several comparisons."""
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(1)
        k = tf.constant(5)

        # NOTE(review): for the initial (1, 1, 5) the condition is
        # equal(not_equal(2 < 10, 5 < 100), 5 >= 2) = equal(False, True) =
        # False, so the body never runs and tf_out is [1, 1, 5]. Only the
        # graph structure of the condition is exercised.
        def c(i, j, k): return \
            tf.equal(tf.not_equal(tf.less(i + j, 10),
                                  tf.less(j * k, 100)),
                     tf.greater_equal(k, i + j))

        def b(i, j, k): return [i+j, j+k, k+1]
        r = tf.while_loop(c, b, loop_vars=[i, j, k])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
Example #5
Source File: attacks.py (from interval-bound-propagation, Apache License 2.0), 6 votes
def _build(self, inputs, labels):
    """Runs the inner attack repeatedly, keeping successful adversarial inputs.

    At most `self._num_restarts` restarts are performed. The loop stops early
    once every example in the batch has been attacked successfully.

    Side effects: sets `self._attack`, `self._success` and `self._logits`.

    Returns:
      The selected adversarial batch (`self._attack`).
    """

    def should_continue(step, unused_best_attack, solved):
      # Stop as soon as all examples are solved or the restart budget is spent.
      under_budget = step < self._num_restarts
      everything_solved = tf.reduce_all(solved)
      return tf.logical_and(under_budget, tf.logical_not(everything_solved))

    def restart(step, best_attack, solved):
      candidate = self._inner_attack(inputs, labels)
      candidate_success = self._inner_attack.success
      # On step 0 the candidate is always adopted; afterwards only where it
      # succeeded.
      take_candidate = tf.logical_or(tf.equal(step, 0), candidate_success)
      next_step = step + 1
      next_attack = tf.where(take_candidate, candidate, best_attack)
      next_solved = tf.logical_or(solved, candidate_success)
      return (next_step, next_attack, next_solved)

    initial_vars = [
        tf.constant(0, dtype=tf.int32),
        inputs,
        tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool),
    ]
    _, self._attack, self._success = tf.while_loop(
        should_continue, restart,
        back_prop=False, parallel_iterations=1,
        loop_vars=initial_vars)
    self._logits = self._eval_fn(self._attack, mode='final')
    return self._attack
Example #6
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_loop_3_vars():
    """Checks a while loop carrying three interdependent scalar loop vars.

    Each step updates (i, j, k) -> (i + 1, j * k, k + i) while i < 10.
    """
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(1)
        j0 = tf.constant(2)
        k0 = tf.constant(4)

        def c(i, j, k): return i < 10

        def b(i, j, k): return [i+1, j * k, k + i]
        r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)
Example #7
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 6 votes
def test_loop_2_vars():
    """Checks a while loop with one counter and one loop-invariant tensor.

    i counts 0 -> 10 and the 2x2 ones tensor j passes through unchanged.
    Only the counter, post-processed by adding 1337, is fetched.
    """
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(0)
        j0 = tf.ones([2, 2])

        def c(i, j): return i < 10

        def b(i, j): return [tf.add(i, 1), j]

        i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
        # Adds an op after the loop exit (i2 is not used further).
        i1 += tf.constant(1337)

        with tf.Session() as sess:
            tf_out = sess.run(i1)

    check_equal(graph, tf_out)
Example #8
Source File: beam_search_v1.py (from models, Apache License 2.0), 6 votes
def search(self, initial_ids, initial_cache):
    """Beam search for sequences with highest scores.

    The search state is advanced by `self._search_step` inside a single
    tf.while_loop for as long as `self._continue_search` holds.

    Returns:
      (finished_seq, finished_scores). For any batch item with no finished
      sequence, the alive sequences and alive log probabilities are returned
      in their place.
    """
    state, state_shapes = self._create_initial_state(initial_ids, initial_cache)

    final = tf.while_loop(
        self._continue_search,
        self._search_step,
        loop_vars=[state],
        shape_invariants=[state_shapes],
        parallel_iterations=1,
        back_prop=False)[0]

    def prefer_finished(finished_value, alive_value):
      # Corner case: a batch item may have no finished sequence at all; fall
      # back to its alive beam in that case.
      has_finished = tf.reduce_any(final[_StateKeys.FINISHED_FLAGS], 1)
      return tf.where(has_finished, finished_value, alive_value)

    best_seq = prefer_finished(final[_StateKeys.FINISHED_SEQ],
                               final[_StateKeys.ALIVE_SEQ])
    best_scores = prefer_finished(final[_StateKeys.FINISHED_SCORES],
                                  final[_StateKeys.ALIVE_LOG_PROBS])
    return best_seq, best_scores
Example #9
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 5 votes
def test_nested_loop_bound():
    """Checks a while loop over placeholder-derived data with nested control flow.

    The outer loop runs 20 times (y from 0 to 20). Its body contains two
    tf.cond ops and an inner tf.while_loop, and it multiplies the carried
    [1, 4] slice `x` by `res` and by a loop-external tensor `outer`.
    """
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        # [1, 4] slice of row 1, columns 4..7.
        x = tf.slice(data, [1, 4], [1, 4])
        outer = x + 5.0  # captured by the loop body from outside the loop
        def body(x, y):
            res = tf.cond(tf.less(y, 10), lambda: tf.add(
                10.0, 20.0), lambda: tf.square(10.0))
            # nested_body is traced when the inner while_loop is built, so it
            # captures the `res` from the cond above, not the rebound values below.
            def nested_body(nx, ny):
                return nx + 1, res + 2.0
            def nested_cond(nx, ny):
                return tf.less(nx, 15)
            nx = tf.constant(0)
            ny = tf.constant(0.0)
            nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
            res = res + nested_res[1]
            z = tf.constant(7)
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            return tf.multiply(res, x * outer), y + 1

        y = tf.constant(0)
        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        with tf.Session() as sess:
            # Feed by tensor name ("data:0").
            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})

    check_equal(graph, tf_out, {dname: np_data})
Example #10
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 5 votes
def test_vanilla_loop_bound():
    """Checks a while loop over placeholder-derived data with two tf.cond ops.

    This is the same setup as test_nested_loop_bound, but without the inner
    while loop. The outer loop runs 20 times (y from 0 to 20).
    """
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        # [1, 4] slice of row 1, columns 4..7.
        x = tf.slice(data, [1, 4], [1, 4])
        outer = x + 5.0  # captured by the loop body from outside the loop
        def body(x, y):
            res = tf.cond(tf.less(y, 10), lambda: tf.add(
                10.0, 20.0), lambda: tf.square(10.0))
            z = tf.constant(7)
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            return tf.multiply(res, x * outer), y + 1

        y = tf.constant(0)
        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        with tf.Session() as sess:
            # Feed by tensor name ("data:0").
            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})

    check_equal(graph, tf_out, {dname: np_data})
Example #11
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 5 votes
def test_callnode_loop_vars():
    """Checks a while loop whose initial loop var is an op result, not a Const.

    The name presumably refers to the converter seeing a call node (here
    add(0, 1)) as the loop input. TODO confirm against the converter.
    """
    graph = tf.Graph()
    with graph.as_default():
        i = tf.add(tf.constant(0), 1)

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            tf_out = sess.run(r)

        # Unlike most tests here, this runs while `graph` is still the default.
        check_equal(graph, tf_out)
Example #12
Source File: test_control_flow.py (from incubator-tvm, Apache License 2.0), 5 votes
def test_vanilla_loop():
    """Checks a single-counter while loop counting 0 -> 10."""
    graph = tf.Graph()
    with graph.as_default():
        # NOTE(review): the "while/" prefix in this constant's name presumably
        # checks that the converter does not mistake it for a node inside the
        # loop. Confirm.
        i = tf.constant(0, name="while/constant")

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            tf_out = sess.run(r)

        # Unlike most tests here, this runs while `graph` is still the default.
        check_equal(graph, tf_out)
Example #13
Source File: mnist_benchmark.py (from autograph, Apache License 2.0), 5 votes
def benchmark_handwritten(self):
    """Benchmarks a hand-written training loop built with tf.while_loop.

    It builds a graph that runs `hp.train_steps` optimizer steps inside one
    tf.while_loop. It then times `target`, which evaluates the final loss
    (and therefore the whole loop) and asserts the loss lies in (0.1, 1).
    """
    with tf.Graph().as_default():
      ds, opt, hp, w, b = get_data_and_params()
      iterator = ds.make_one_shot_iterator()

      def loop_body(i, unused_previous_loss_t):
        """Manual implementation of training loop."""
        # Call get_next() inside body or else training happens repeatedly on
        # the first minibatch only.
        x, y = iterator.get_next()
        loss_t = loss_fn(x, y, w, b)
        train_op = opt.minimize(loss_t, var_list=(w, b))
        # Every 100 steps, route i through tf.Print to log the step and loss.
        i = tf.cond(tf.equal(i % 100, 0),
                    lambda: tf.Print(i, [i, loss_t], message='Step, loss: '),
                    lambda: i)

        # Only ops created inside this scope (here `i + 1`) depend on
        # train_op. loss_t was created earlier and carries no such dependency.
        with tf.control_dependencies([train_op]):
          return i + 1, loss_t

      _, final_loss_t = tf.while_loop(
          lambda i, _: i < hp.train_steps,
          loop_body,
          [tf.constant(0), tf.constant(0.0)])

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        def target():
          loss_val = sess.run(final_loss_t)
          assert 0.1 < loss_val < 1, loss_val

        self.time_execution(
            'Handwritten',
            target,
            iter_volume=hp.train_steps,
            iter_unit='training steps')
Example #14
Source File: attacks.py (from interval-bound-propagation, Apache License 2.0), 5 votes
def adapt(self, original_inputs, adversarial_inputs, labels):
    """Runs binary search to find the first misclassified input.

    The search runs along the straight line from `original_inputs`
    (alpha = 0) to `adversarial_inputs` (alpha = 1), independently for each
    batch element, for a fixed 10 bisection steps.

    Returns:
      The interpolated inputs at the chosen alpha per example. That alpha is
      the lower bracket when the attack already succeeds there, otherwise the
      upper bracket.
    """
    batch_size = tf.shape(original_inputs)[0]
    num_search_steps = 10
    # Shape that broadcasts a per-example scalar over all non-batch dims.
    broadcast_shape = [batch_size] + [1] * (len(original_inputs.shape) - 1)

    def keep_searching(step, *_):
      return tf.less(step, num_search_steps)

    def interpolate(alpha):
      alpha = tf.reshape(alpha, broadcast_shape)
      return (adversarial_inputs - original_inputs) * alpha + original_inputs

    def attack_succeeds(alpha):
      logits = self._eval_fn(interpolate(alpha))
      return self._success_fn(self._specification.evaluate(logits))

    def bisect(step, lo, hi):
      mid = (lo + hi) * .5
      hit = attack_succeeds(mid)
      # A hit at the midpoint shrinks the upper bracket; a miss raises the lower.
      next_lo = tf.where(hit, lo, mid)
      next_hi = tf.where(hit, mid, hi)
      return step + 1, next_lo, next_hi

    lo_init = tf.zeros(shape=[batch_size])
    hi_init = tf.ones(shape=[batch_size])
    _, lo, hi = tf.while_loop(
        keep_searching,
        bisect,
        loop_vars=[tf.constant(0.), lo_init, hi_init],
        parallel_iterations=1,
        back_prop=False)
    # If lower is incorrectly classified, pick lower; otherwise pick upper.
    best_alpha = tf.where(attack_succeeds(lo), lo, hi)
    return interpolate(best_alpha)
Example #15
Source File: seq2seq.py (from magenta, Apache License 2.0), 5 votes
def _should_cache_variables():
  """Decides whether a default caching device may be installed.

  Caching is skipped whenever the graph is being built inside a TF1 while
  loop. Train steps may be wrapped in a tf.while_loop, and caching would stop
  later loop iterations from re-reading weights updated by earlier ones.

  Returns:
    False inside a v1 while-loop control-flow context, True otherwise.
  """
  default_graph = tf.get_default_graph()
  flow_context = default_graph._get_control_flow_context()  # pylint: disable=protected-access
  enclosing_loop = control_flow_util.GetContainingWhileContext(flow_context)
  return enclosing_loop is None
Example #16
Source File: common_layers.py (from tensor2tensor, Apache License 2.0), 5 votes
def should_generate_summaries():
  """Is this an appropriate context to generate summaries.

  Returns:
    a boolean
  """
  name_scope = contrib.framework().get_name_scope()
  if name_scope and "while/" in name_scope:
    # Summaries don't work well within tf.while_loop()
    return False
  if tf.get_variable_scope().reuse:
    # Avoid generating separate summaries for different data shards
    return False
  return True 
Example #17
Source File: post_processing.py (from models, Apache License 2.0), 4 votes
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
  """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).

  Args:
    boxes: a tensor with a shape of [1, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.
    output_size: an int32 tensor of size [1]. Representing the number of
      selected boxes.
    idx: an integer scalar representing induction variable.

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
  # The reshape to [1, num_tiles, _NMS_TILE_SIZE, 4] below requires the
  # anchor count to be a multiple of _NMS_TILE_SIZE (presumably the caller pads).
  num_tiles = tf.shape(boxes)[1] // _NMS_TILE_SIZE

  # Iterates over tiles that can possibly suppress the current tile.
  box_slice = tf.slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                       [1, _NMS_TILE_SIZE, 4])
  # Runs _cross_suppression (defined elsewhere) once for each earlier tile
  # inner_idx = 0 .. idx - 1. Only the updated box_slice is kept.
  _, box_slice, _, _ = tf.while_loop(
      lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
      _cross_suppression, [boxes, box_slice, iou_threshold,
                           tf.constant(0)])

  # Iterates over the current tile to compute self-suppression.
  iou = batch_iou(box_slice, box_slice)
  # mask[0, r, c] is True only when c > r (strict upper triangle), so each
  # pair is considered once and a box never suppresses itself.
  mask = tf.expand_dims(
      tf.reshape(tf.range(_NMS_TILE_SIZE), [1, -1]) > tf.reshape(
          tf.range(_NMS_TILE_SIZE), [-1, 1]), 0)
  # Zero out IoU entries outside the mask or below the threshold.
  iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
  # Repeats _self_suppression (defined elsewhere) while its returned
  # loop_condition flag stays True.
  suppressed_iou, _, _, _ = tf.while_loop(
      lambda _iou, _threshold, loop_condition, _iou_sum: loop_condition,
      _self_suppression,
      [iou, iou_threshold,
       tf.constant(True),
       tf.reduce_sum(iou, [1, 2])])
  # A box is suppressed if any surviving IoU entry remains in its column.
  suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
  # Zero the coordinates of suppressed boxes.
  box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)

  # Uses box_slice to update the input boxes.
  # mask selects tile `idx`: that tile takes box_slice and all others keep
  # their original values.
  mask = tf.reshape(
      tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
  boxes = tf.tile(tf.expand_dims(box_slice, [1]),
                  [1, num_tiles, 1, 1]) * mask + tf.reshape(
                      boxes, [1, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
  boxes = tf.reshape(boxes, [1, -1, 4])

  # Updates output_size.
  # Counts the boxes in this tile with any positive coordinate, i.e. not zeroed.
  output_size += tf.reduce_sum(
      tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1
Example #18
Source File: preprocessors.py (from text-to-text-transfer-transformer, Apache License 2.0), 4 votes
def _span_answer(context, answer_text):
  """Finds start/end indices of answer_text in context after space tokenization.

  Both strings are tokenized by replacing every non-word character with a
  space and splitting on single spaces. If answer_tokens is not a contiguous
  sublist of context_tokens, returns empty string.

  Args:
    context: 0-d string tensor
    answer_text: 0-d string

  Returns:
    A 0-d string tensor 'start: {} end: {}' holding the inclusive token
    indices of the first exact match, or '' if there is no match.
  """
  def space_tok(s):
    """Replace non-word chars with space then split on space."""
    s = tf.strings.regex_replace(s, r'\W', ' ')
    return tf.strings.split(input=[s], sep=' ').values

  def find_subseq(n, h):
    """Finds index of needle subsequence inside haystack.

    Args:
      n: 1-d tensor
      h: 1-d tensor same type as n

    Returns:
      Index of start of n if found; otherwise -1.
    """
    l_n = tf.size(n)
    l_h = tf.size(h)
    i = tf.constant(0)
    # Last start index at which n still fits entirely inside h. It is negative
    # when n is longer than h, meaning there is no candidate position at all.
    end = l_h - l_n

    def mismatch_at(j):
      # Only evaluated for j <= end, so h[j:j + l_n] has exactly l_n elements
      # and the element-wise comparison with n cannot hit a shape mismatch.
      return tf.reduce_any(tf.not_equal(h[j:j + l_n], n))

    def keep_scanning(j):
      # tf.logical_and evaluates both operands in graph mode, so gate the
      # comparison with tf.cond. The candidate at j == end is still checked.
      return tf.cond(tf.less_equal(j, end),
                     lambda: mismatch_at(j),
                     lambda: tf.constant(False))

    # TODO(peterjliu): Replace with craffel@'s more efficient code
    # if necessary: cr/254848350.
    w = tf.while_loop(keep_scanning, lambda j: j + 1, [i])
    # The loop stops either on a match (w <= end) or once every candidate
    # start has been rejected (w > end).
    return tf.cond(tf.greater(w, end), lambda: -1, lambda: w)

  answer_tokens = space_tok(answer_text)
  context_tokens = space_tok(context)
  start = find_subseq(answer_tokens, context_tokens)
  end = start + tf.size(answer_tokens) - 1
  # Just take the first candidate that matches exactly.
  return tf.cond(tf.equal(start, -1),
                 lambda: tf.constant(''),
                 lambda: tf.strings.format('start: {} end: {}', [start, end]))
Example #19
Source File: rnn_benchmark.py (from autograph, Apache License 2.0), 4 votes
def _benchmark_handwritten_dynamic_rnn(self, batch_size, max_seq_len):
    """Times a hand-written, tf.while_loop-based reimplementation of dynamic_rnn.

    Args:
      batch_size: number of sequences per generated batch.
      max_seq_len: passed to the fake-input generator and reported in extras.
    """

    def my_dynamic_rnn(rnn_cell,
                       input_data,
                       initial_state,
                       sequence_length=None):
      """A handwritten reimplementation of dynamic_rnn.

      Args:
        rnn_cell: callable mapping (input_t, state) -> (output_t, new_state).
        input_data: rank-3 input, transposed with perm [1, 0, 2] so the loop
          indexes along axis 0 (presumably [batch, time, depth] -> time-major;
          TODO confirm).
        initial_state: the cell's initial state.
        sequence_length: optional per-example lengths. When None, every
          example runs for the full axis-0 length and no masking is applied.

      Returns:
        (outputs, final_state), with outputs transposed back by [1, 0, 2].
      """
      input_data = tf.transpose(input_data, [1, 0, 2])
      outputs = tf.TensorArray(tf.float32, input_data.shape[0])
      if sequence_length is None:
        max_seq_len = input_data.shape[0]
      else:
        max_seq_len = tf.reduce_max(sequence_length)

      def while_body(i, state, outputs):
        new_output, new_state = rnn_cell(input_data[i], state)
        if sequence_length is None:
          # No lengths given: every step is valid, so there is nothing to mask.
          output, state = new_output, new_state
        else:
          # Past an example's length, emit zeros and keep its previous state.
          output = tf.where(i < sequence_length, new_output,
                            tf.zeros(new_output.shape))
          state = tf.where(i < sequence_length, new_state, state)
        outputs = outputs.write(i, output)
        return i + 1, state, outputs

      def while_cond(i, unused_state, unused_outputs):
        return i < max_seq_len

      _, state, outputs = tf.while_loop(
          while_cond,
          while_body,
          loop_vars=(tf.constant(0), initial_state, outputs))
      return tf.transpose(outputs.stack(), [1, 0, 2]), state

    with tf.Graph().as_default():
      input_data, sequence_lengths = self._generate_fake_rnn_inputs(
          batch_size=batch_size, max_seq_len=max_seq_len)
      rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size)
      graph_output_t = my_dynamic_rnn(rnn_cell, input_data, initial_state,
                                      sequence_lengths)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        def target():
          sess.run(graph_output_t)

        self.time_execution(
            ('Handwritten', batch_size, max_seq_len),
            target,
            iter_volume=batch_size,
            iter_unit='examples',
            extras={
                'max_seq_len': max_seq_len,
                'batch_size': batch_size,
            })
Example #20
Source File: attacks.py (from interval-bound-propagation, Apache License 2.0), 4 votes
def pgd_attack(loss_fn, input_image, epsilon, num_steps,
               optimizer=None,
               project_perturbation=_project_perturbation,
               image_bounds=None, random_init=1.):
  """Projected gradient descent for generating adversarial images.

  Args:
    loss_fn: A callable that takes the perturbed image batch
      (`input_image + perturbation`) and returns the loss, a scalar Tensor,
      which will be minimized.
    input_image: Tensor, a batch of images
    epsilon: float, the L-infinity norm of the maximum allowable perturbation
    num_steps: int, the number of steps of gradient descent
    optimizer: An `UnrolledOptimizer` object. If None (default), a new
      `UnrolledGradientDescent()` is created for this call.
    project_perturbation: A function, which will be used to enforce some
      constraint. It should have the same signature as `_project_perturbation`.
      Note that if you use a custom projection function, you should double-check
      your implementation, since an incorrect implementation will not error,
      and will appear to work fine.
    image_bounds: A pair of floats: minimum and maximum pixel value. If None
      (default), the bounds are assumed to be 0 and 1.
    random_init: Probability of starting from random location rather than
      nominal input image.

  Returns:
    adversarial version of `input_image`, with L-infinity difference less than
      epsilon, which tries to minimize loss_fn.
  """
  if optimizer is None:
    # Build the default per call. A default-argument instance would be
    # created once at definition time and shared by every call.
    optimizer = UnrolledGradientDescent()
  image_bounds = image_bounds or (0., 1.)
  # One random draw per example, broadcast over all non-batch dimensions.
  random_shape = [tf.shape(input_image)[0]] + [1] * (len(input_image.shape) - 1)
  use_random_init = tf.cast(
      tf.random_uniform(random_shape) < float(random_init), tf.float32)
  init_perturbation = use_random_init * tf.random_uniform(
      tf.shape(input_image), minval=-epsilon, maxval=epsilon)
  init_perturbation = project_perturbation(init_perturbation,
                                           epsilon, input_image, image_bounds)
  init_optim_state = optimizer.init_state([init_perturbation])

  def loop_body(i, perturbation, flat_optim_state):
    """Update perturbation to input image."""
    # The loop carries the optimizer state flattened. Restore its structure.
    optim_state = nest.pack_sequence_as(structure=init_optim_state,
                                        flat_sequence=flat_optim_state)
    loss = loss_fn(input_image + perturbation)
    new_perturbation_list, new_optim_state = optimizer.minimize(
        loss, [perturbation], optim_state)
    projected_perturbation = project_perturbation(
        new_perturbation_list[0], epsilon, input_image, image_bounds)
    return i + 1, projected_perturbation, nest.flatten(new_optim_state)

  def cond(i, *_):
    return tf.less(i, num_steps)

  flat_init_optim_state = nest.flatten(init_optim_state)
  _, final_perturbation, _ = tf.while_loop(
      cond,
      loop_body,
      loop_vars=[tf.constant(0.), init_perturbation, flat_init_optim_state],
      parallel_iterations=1,
      back_prop=False)

  adversarial_image = input_image + final_perturbation
  return tf.stop_gradient(adversarial_image)
Example #21
Source File: modeling.py (from gpt2-ml, Apache License 2.0), 4 votes
def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
           do_topk=False):
    """
    V1 version of: sample outputs from a model, and do it all at once
    :param news_config: Configuration used to construct the model
    :param initial_context: [batch_size, seq_length] that we'll start generating with
    :param eos_token: Stop generating if you see this (tf scalar)
    :param min_len: min length of sample
    :param ignore_ids: NEVER GENERATE THESE [vocab_size]. Defaults to a mask
        that is True only for id 0.
    :param p_for_topp: forwarded to initialize_from_context and sample_step
    :param do_topk: forwarded to initialize_from_context and sample_step
    :return: (tokens, probs). tokens is [batch_size, generated_length]
        including the context; probs grows by one column per generated step.
    """
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    if ignore_ids is None:
        ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)

    with tf.name_scope('sample_sequence'):
        # Initial call to get cache
        context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
                                                 p_for_topp=p_for_topp,
                                                 do_topk=do_topk)
        ctx = context_output['tokens']
        cache = context_output['cache']
        probs = context_output['probs']

        def body(ctx, cache, probs):
            """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
            # Feed only the last token; the cache holds state for earlier positions.
            next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
                                       batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
                                       do_topk=do_topk)

            # Update everything
            # Append along axis -2 of the cache, the dimension left as None in
            # the shape invariant below, so it may grow each step.
            new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
            new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
            new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
            return [new_ids, new_cache, new_probs]

        def cond(ctx, cache, probs):
            # ctx = tf.Print(ctx,[tf.shape(ctx)])
            # Stop only when EVERY row's last token is eos_token AND the
            # sequence is longer than min_len.
            is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
            is_len = tf.greater(get_shape_list(ctx)[1], min_len)
            return tf.logical_not(tf.logical_and(is_eos, is_len))

        # NOTE(review): 1025 looks like a hard-coded cap on the total sequence
        # length (context + generated), presumably tied to the model's maximum
        # positions. Confirm against news_config.
        tokens, cache, probs = tf.while_loop(
            cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
            loop_vars=[ctx, cache, probs],
            # Length dimensions are None so the loop vars can grow each step.
            shape_invariants=[tf.TensorShape([batch_size, None]),
                              tf.TensorShape(
                                  [batch_size, news_config.num_hidden_layers, 2,
                                   news_config.num_attention_heads,
                                   None, news_config.hidden_size // news_config.num_attention_heads]),
                              tf.TensorShape([batch_size, None]),
                              ],
            back_prop=False,
        )
    return tokens, probs
Example #22
Source File: visualization.py (from tensor2tensor, Apache License 2.0), 4 votes
def build_model(hparams_set, model_name, data_dir, problem_name, beam_size=1):
  """Builds the graph used to fetch a model's attention weights.

  The model is first applied to fed inputs and targets in EVAL mode, which
  fills its attention tensors. Then, under a reusing variable scope, an
  inference (decoding) graph is added on top of the same inputs.

  Args:
    hparams_set: HParams set to build the model with.
    model_name: Name of model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
        If set to 1 (default) then greedy decoding is used.

  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
            attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  hp = trainer_lib.create_hparams(
      hparams_set, data_dir=data_dir, problem_name=problem_name)
  model_cls = registry.model(model_name)
  model = model_cls(hp, tf.estimator.ModeKeys.EVAL)

  id_shape = (1, None, 1, 1)
  inputs = tf.placeholder(tf.int32, shape=id_shape, name="inputs")
  targets = tf.placeholder(tf.int32, shape=id_shape, name="targets")
  model({"inputs": inputs, "targets": targets})

  # Collect the attention tensors now: they exist once the training-style
  # graph is built, and any collected after decoding would come from inside
  # a tf.while_loop and be unfetchable.
  att_mats = get_att_mats(model)

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    decoded = model.infer({"inputs": inputs}, beam_size=beam_size)
  samples = decoded["outputs"]

  return inputs, targets, samples, att_mats