Python tensorflow.compat.v1.assert_equal() Examples

The following are 19 code examples of tensorflow.compat.v1.assert_equal(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
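Before the project examples, here is a minimal sketch of the basic pattern (assuming TensorFlow 2.x with the v1 compatibility API in graph mode; the tensors and message are hypothetical):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.placeholder(tf.int32, shape=[None])
y = tf.placeholder(tf.int32, shape=[None])

# assert_equal returns an op that raises InvalidArgumentError at run time
# if the two tensors are not element-wise equal.
assert_op = tf.assert_equal(tf.shape(x), tf.shape(y),
                            message='x and y must have the same shape')
with tf.control_dependencies([assert_op]):
  total = tf.add(x, y)

with tf.Session() as sess:
  print(sess.run(total, feed_dict={x: [1, 2], y: [3, 4]}))  # [4 6]

In graph mode the returned op has no effect unless it is executed or depended upon, which is why the examples below consistently thread it through tf.control_dependencies.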
Example #1
Source File: anchor_generator.py    From models with Apache License 2.0
def _assert_correct_number_of_anchors(self, anchors_list,
                                        feature_map_shape_list):
    """Assert that correct number of anchors was generated.

    Args:
      anchors_list: A list of box_list.BoxList object holding anchors generated.
      feature_map_shape_list: list of (height, width) pairs in the format
        [(height_0, width_0), (height_1, width_1), ...] that the generated
        anchors must align with.
    Returns:
      Op that raises InvalidArgumentError if the number of generated anchors
        does not match the expected number.
    """
    expected_num_anchors = 0
    actual_num_anchors = 0
    for num_anchors_per_location, feature_map_shape, anchors in zip(
        self.num_anchors_per_location(), feature_map_shape_list, anchors_list):
      expected_num_anchors += (num_anchors_per_location
                               * feature_map_shape[0]
                               * feature_map_shape[1])
      actual_num_anchors += anchors.num_boxes()
    return tf.assert_equal(expected_num_anchors, actual_num_anchors) 
Example #2
Source File: transformer_memory_test.py    From tensor2tensor with Apache License 2.0
def testReset(self):
    batch_size = 2
    key_depth = 3
    val_depth = 5
    memory_size = 4
    memory = transformer_memory.TransformerMemory(
        batch_size, key_depth, val_depth, memory_size)
    vals = tf.random_uniform([batch_size, memory_size, val_depth], minval=1.0)
    logits = tf.random_uniform([batch_size, memory_size], minval=1.0)
    update_op = memory.set(vals, logits)
    reset_op = memory.reset([1])
    mem_vals, mem_logits = memory.get()
    assert_op1 = tf.assert_equal(mem_vals[0], vals[0])
    assert_op2 = tf.assert_equal(mem_logits[0], logits[0])
    with tf.control_dependencies([assert_op1, assert_op2]):
      all_zero1 = tf.reduce_sum(tf.abs(mem_vals[1]))
      all_zero2 = tf.reduce_sum(tf.abs(mem_logits[1]))
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      session.run(update_op)
      session.run(reset_op)
      zero1, zero2 = session.run([all_zero1, all_zero2])
    self.assertAllEqual(0, zero1)
    self.assertAllEqual(0, zero2) 
Example #3
Source File: sequence_ops.py    From trfl with Apache License 2.0
def _reverse_seq(sequence, sequence_lengths=None):
  """Reverse sequence along dim 0.

  Args:
    sequence: Tensor of shape [T, B, ...].
    sequence_lengths: (optional) tensor of shape [B]. If `None`, the whole
      sequence is reversed along dim 0 without regard to per-batch lengths.

  Returns:
    Tensor of same shape as sequence with dim 0 reversed up to sequence_lengths.
  """
  if sequence_lengths is None:
    return tf.reverse(sequence, [0])

  sequence_lengths = tf.convert_to_tensor(sequence_lengths)
  with tf.control_dependencies(
      [tf.assert_equal(sequence.shape[1], sequence_lengths.shape[0])]):
    return tf.reverse_sequence(
        sequence, sequence_lengths, seq_axis=0, batch_axis=1) 
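A quick usage sketch with hypothetical values, assuming the _reverse_seq function above is in scope:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# [T=3, B=2]: two sequences, of lengths 2 and 3, stored column-wise.
sequence = tf.constant([[1, 4], [2, 5], [3, 6]])
lengths = tf.constant([2, 3])

# The embedded tf.assert_equal checks that the batch dims agree before
# tf.reverse_sequence runs.
reversed_seq = _reverse_seq(sequence, lengths)
with tf.Session() as sess:
  print(sess.run(reversed_seq))  # [[2 6] [1 5] [3 4]]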
Example #4
Source File: lstm_models.py    From magenta with Apache License 2.0
def _merge_decode_results(self, decode_results):
    """Merge in the output dimension."""
    output_axis = -1
    assert decode_results
    zipped_results = lstm_utils.LstmDecodeResults(*list(zip(*decode_results)))
    with tf.control_dependencies([
        tf.assert_equal(
            zipped_results.final_sequence_lengths, self.hparams.max_seq_len,
            message='Variable length not supported by '
                    'MultiOutCategoricalLstmDecoder.')]):
      if zipped_results.final_state[0] is None:
        final_state = None
      else:
        final_state = tf.nest.map_structure(
            lambda x: tf.concat(x, axis=output_axis),
            zipped_results.final_state)

      return lstm_utils.LstmDecodeResults(
          rnn_output=tf.concat(zipped_results.rnn_output, axis=output_axis),
          rnn_input=tf.concat(zipped_results.rnn_input, axis=output_axis),
          samples=tf.concat(zipped_results.samples, axis=output_axis),
          final_state=final_state,
          final_sequence_lengths=zipped_results.final_sequence_lengths[0]) 
Example #5
Source File: shape_utils.py    From Object_Detection_Tracking with Apache License 2.0
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when the shapes are static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else: return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b) 
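A hypothetical usage sketch; it assumes the companion combined_static_and_dynamic_shape helper from the same shape_utils module is available:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Fully static shapes: a mismatch raises ValueError at graph-construction
# time, e.g. assert_shape_equal([2, 64, 64, 3], [2, 64, 32, 3]).

# Partially dynamic shapes: the check becomes a runtime tf.assert_equal op.
images = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
masks = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
check_op = assert_shape_equal(
    combined_static_and_dynamic_shape(images),  # [batch tensor, 64, 64, 3]
    combined_static_and_dynamic_shape(masks))
with tf.control_dependencies([check_op]):
  masked = images * masks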
Example #6
Source File: shape_utils.py    From models with Apache License 2.0
def expand_first_dimension(inputs, dims):
  """Expands `K-d` tensor along first dimension to be a `(K+n-1)-d` tensor.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [dims[0], dims[1], ..., dims[-1], D1, ..., D(K-1)].

  Example:
  `inputs` is a tensor with shape [50, 20, 20, 3].
  new_tensor = expand_first_dimension(inputs, [10, 5]).
  new_tensor.shape -> [10, 5, 20, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    dims: List with new dimensions to expand first axis into. The length of
      `dims` is typically 2 or larger.

  Returns:
    a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(K-1)].
  """
  inputs_shape = combined_static_and_dynamic_shape(inputs)
  expanded_shape = tf.stack(dims + inputs_shape[1:])

  # Verify that it is possible to expand the first axis of inputs.
  assert_op = tf.assert_equal(
      inputs_shape[0], tf.reduce_prod(tf.stack(dims)),
      message=('First dimension of `inputs` cannot be expanded into provided '
               '`dims`'))

  with tf.control_dependencies([assert_op]):
    inputs_reshaped = tf.reshape(inputs, expanded_shape)

  return inputs_reshaped 
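A hypothetical call matching the docstring's example:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

inputs = tf.placeholder(tf.float32, shape=[None, 20, 20, 3])
# Unflatten a batch of 50 into 10 groups of 5. The embedded tf.assert_equal
# raises InvalidArgumentError at run time if tf.shape(inputs)[0] != 10 * 5.
reshaped = expand_first_dimension(inputs, [10, 5])  # -> [10, 5, 20, 20, 3]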
Example #7
Source File: shape_utils.py    From models with Apache License 2.0
def assert_shape_equal_along_first_dimension(shape_a, shape_b):
  """Asserts that shape_a and shape_b are the same along the 0th-dimension.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when the shapes are static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
    if shape_a[0] != shape_b[0]:
      raise ValueError('Unequal first dimension {}, {}'.format(
          shape_a[0], shape_b[0]))
    else: return tf.no_op()
  else:
    return tf.assert_equal(shape_a[0], shape_b[0]) 
Example #8
Source File: shape_utils.py    From models with Apache License 2.0
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes
  mismatch.

  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when the shapes are static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else: return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b) 
Example #9
Source File: image_transformations.py    From tensor2robot with Apache License 2.0
def CenterCropImages(images, input_shape,
                     target_shape):
  """Take a central crop of given size from a list of images.

  Args:
    images: List of tensors of shape [batch_size, h, w, c].
    input_shape: Shape [h, w, c] of the input images.
    target_shape: Shape [h, w] of the cropped output.

  Returns:
    crops: List of cropped tensors of shape
      [batch_size, target_shape[0], target_shape[1], c].
  """
  if len(input_shape) != 3:
    raise ValueError(
        'The input shape has to be of the form (height, width, channels) '
        'but has len {}'.format(len(input_shape)))
  if len(target_shape) != 2:
    raise ValueError('The target shape has to be of the form (height, width) '
                     'but has len {}'.format(len(target_shape)))
  if input_shape[0] == target_shape[0] and input_shape[1] == target_shape[1]:
    return list(images)

  # Assert all images have the same shape.
  assert_ops = []
  for image in images:
    assert_ops.append(
        tf.assert_equal(
            input_shape[:2],
            tf.shape(image)[1:3],
            message=('All images must have same width and height '
                     'for CenterCropImages.')))
  offset_y = int(input_shape[0] - target_shape[0]) // 2
  offset_x = int(input_shape[1] - target_shape[1]) // 2
  with tf.control_dependencies(assert_ops):
    crops = [
        tf.image.crop_to_bounding_box(image, offset_y, offset_x,
                                      target_shape[0], target_shape[1])
        for image in images
    ]
  return crops 
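For example (hypothetical shapes, assuming the function above is in scope):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

images = [tf.placeholder(tf.float32, shape=[None, 64, 64, 3])]
# Take the central 32x32 crop; the assert ops verify at run time that every
# image really is 64x64 before tf.image.crop_to_bounding_box is applied.
crops = CenterCropImages(images, input_shape=(64, 64, 3), target_shape=(32, 32))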
Example #10
Source File: data.py    From magenta with Apache License 2.0
def tflite_compat_mel(wav_audio, hparams):
  """EXPERIMENTAL: Log mel spec with ops that can be made TFLite compatible."""
  samples, decoded_sample_rate = tf.audio.decode_wav(
      wav_audio, desired_channels=1)
  samples = tf.squeeze(samples, axis=1)
  # Ensure that we decoded the samples at the expected sample rate.
  with tf.control_dependencies(
      [tf.assert_equal(decoded_sample_rate, hparams.sample_rate)]):
    return tflite_compat_mel_from_samples(samples, hparams) 
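The same sample-rate guard works outside this helper; a minimal sketch, assuming 16 kHz audio is expected:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

wav_bytes = tf.placeholder(tf.string, shape=[])
samples, sample_rate = tf.audio.decode_wav(wav_bytes, desired_channels=1)
# Fail at run time if the file was not recorded at the expected rate.
with tf.control_dependencies([tf.assert_equal(sample_rate, 16000)]):
  samples = tf.squeeze(samples, axis=1)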
Example #11
Source File: common_attention_test.py    From tensor2tensor with Apache License 2.0
def testAddTimingSignalsGivenPositionsEquivalent(self):
    x = tf.zeros([1, 10, 128], dtype=tf.float32)
    positions = tf.expand_dims(tf.range(0, 10, dtype=tf.float32), axis=0)
    # The method add_timing_signal_1d_given_position could be replaced by
    # add_timing_signals_given_positions:
    tf.assert_equal(
        common_attention.add_timing_signal_1d_given_position(x, positions),
        common_attention.add_timing_signals_given_positions(x, [positions])) 
Example #12
Source File: beam_search_test.py    From tensor2tensor with Apache License 2.0
def testStates(self):
    batch_size = 1
    beam_size = 1
    vocab_size = 2
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]])

    expected_states = tf.constant([[[0.]], [[1.]]])

    def symbols_to_logits(ids, _, states):
      pos = tf.shape(ids)[1] - 1
      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[pos])]):
        logits = tf.to_float(tf.log(probabilities[pos, :]))

      states["state"] += 1
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        states=states)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message) 
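The inline-assert trick generalizes to any tf.while_loop body; a small hypothetical sketch that checks a running-sum invariant on each iteration:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def body(i, acc):
  # Intermediate loop values cannot be fetched from outside the loop, so the
  # check rides along as a control dependency on the loop computation itself.
  with tf.control_dependencies([tf.assert_equal(acc, i * (i - 1) // 2)]):
    acc = acc + i
  return i + 1, acc

_, total = tf.while_loop(lambda i, acc: i < 5, body,
                         [tf.constant(0), tf.constant(0)])
with tf.Session() as sess:
  print(sess.run(total))  # 10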
Example #13
Source File: lib_graph.py    From magenta with Apache License 2.0
def compute_loss(self, unreduced_loss):
    """Computes scaled loss based on mask out size."""
    # construct mask to identify zero padding that was introduced to
    # make the batch rectangular
    batch_duration = tf.shape(self.pianorolls)[1]
    indices = tf.to_float(tf.range(batch_duration))
    pad_mask = tf.to_float(
        indices[None, :, None, None] < self.lengths[:, None, None, None])

    # construct mask and its complement, respecting pad mask
    mask = pad_mask * self.masks
    unmask = pad_mask * (1. - self.masks)

    # Compute numbers of variables
    # #timesteps * #variables per timestep
    variable_axis = 3 if self.hparams.use_softmax_loss else 2
    dd = (
        self.lengths[:, None, None, None] * tf.to_float(
            tf.shape(self.pianorolls)[variable_axis]))
    reduced_dd = tf.reduce_sum(dd)

    # Compute numbers of variables to be predicted/conditioned on
    mask_size = tf.reduce_sum(mask, axis=[1, variable_axis], keep_dims=True)
    unmask_size = tf.reduce_sum(unmask, axis=[1, variable_axis], keep_dims=True)

    unreduced_loss *= pad_mask
    if self.hparams.rescale_loss:
      unreduced_loss *= dd / mask_size

    # Compute average loss over entire set of variables
    self.loss_total = tf.reduce_sum(unreduced_loss) / reduced_dd

    # Compute separate losses for masked/unmasked variables
    # NOTE: indexing the pitch dimension with 0 because the mask is constant
    # across pitch. Except in the sigmoid case, but then the pitch dimension
    # will have been reduced over.
    self.reduced_mask_size = tf.reduce_sum(mask_size[:, :, 0, :])
    self.reduced_unmask_size = tf.reduce_sum(unmask_size[:, :, 0, :])

    assert_partition_op = tf.group(
        tf.assert_equal(tf.reduce_sum(mask * unmask), 0.),
        tf.assert_equal(self.reduced_mask_size + self.reduced_unmask_size,
                        reduced_dd))
    with tf.control_dependencies([assert_partition_op]):
      self.loss_mask = (
          tf.reduce_sum(mask * unreduced_loss) / self.reduced_mask_size)
      self.loss_unmask = (
          tf.reduce_sum(unmask * unreduced_loss) / self.reduced_unmask_size)

    # Check which loss to use as objective function.
    self.loss = (
        self.loss_mask if self.hparams.optimize_mask_only else self.loss_total) 
Example #14
Source File: tensor_buffer.py    From privacy with Apache License 2.0
def append(self, value):
    """Appends a new tensor to the end of the buffer.

    Args:
      value: The tensor to append. Must match the shape specified in the
        initializer.

    Returns:
      An op appending the new tensor to the end of the buffer.
    """

    def _double_capacity():
      """Doubles the capacity of the current tensor buffer."""
      padding = tf.zeros_like(self._buffer, self._buffer.dtype)
      new_buffer = tf.concat([self._buffer, padding], axis=0)
      if tf.executing_eagerly():
        with tf.variable_scope(self._name, reuse=True):
          self._buffer = tf.get_variable(
              name='buffer',
              dtype=self._dtype,
              initializer=new_buffer,
              trainable=False)
          return self._buffer, tf.assign(
              self._capacity, tf.multiply(self._capacity, 2))
      else:
        return tf.assign(
            self._buffer, new_buffer,
            validate_shape=False), tf.assign(
                self._capacity, tf.multiply(self._capacity, 2))

    update_buffer, update_capacity = tf.cond(
        pred=tf.equal(self._current_size, self._capacity),
        true_fn=_double_capacity,
        false_fn=lambda: (self._buffer, self._capacity))

    with tf.control_dependencies([update_buffer, update_capacity]):
      with tf.control_dependencies([
          tf.assert_less(
              self._current_size,
              self._capacity,
              message='Appending past end of TensorBuffer.'),
          tf.assert_equal(
              tf.shape(input=value),
              tf.shape(input=self._buffer)[1:],
              message='Appending value of inconsistent shape.')
      ]):
        with tf.control_dependencies(
            [tf.assign(self._buffer[self._current_size, :], value)]):
          return tf.assign_add(self._current_size, 1) 
Example #15
Source File: attention.py    From language with Apache License 2.0
def _prepare_memory(
    memory,
    memory_sequence_length,
    mask,
    check_inner_dims_defined):
  """Convert to tensor and possibly mask `memory`.

  Args:
    memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
    memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
    mask: A mask `Tensor` used to zero out elements of `memory` (typically
      positions beyond each sequence length).
    check_inner_dims_defined: Python boolean.  If `True`, the `memory`
      argument's shape is checked to ensure all but the two outermost
      dimensions are fully defined.

  Returns:
    A (possibly masked), checked, new `memory`.

  Raises:
    ValueError: If `check_inner_dims_defined` is `True` and not
      `memory.shape[2:].is_fully_defined()`.
  """
  memory = tf.contrib.framework.nest.map_structure(
      lambda m: tf.convert_to_tensor(m, name="memory"), memory)
  if memory_sequence_length is not None:
    memory_sequence_length = tf.convert_to_tensor(
        memory_sequence_length, name="memory_sequence_length")
  if check_inner_dims_defined:
    def _check_dims(m):
      if not m.get_shape()[2:].is_fully_defined():
        raise ValueError("Expected memory %s to have fully defined inner dims, "
                         "but saw shape: %s" % (m.name, m.get_shape()))

    tf.contrib.framework.nest.map_structure(_check_dims, memory)

  seq_len_mask = tf.cast(mask,
                         tf.contrib.framework.nest.flatten(memory)[0].dtype)
  seq_len_batch_size = (dimension_value(mask.shape[0]) or tf.shape(mask)[0])

  def _maybe_mask(m, seq_len_mask):
    """Mask the sequence with m."""
    rank = m.get_shape().ndims
    rank = rank if rank is not None else tf.rank(m)
    extra_ones = tf.ones(rank - 2, dtype=tf.int32)
    m_batch_size = dimension_value(m.shape[0]) or tf.shape(m)[0]
    with tf.control_dependencies(
        [tf.assert_equal(seq_len_batch_size, m_batch_size, message="batch")]):
      seq_len_mask = tf.reshape(
          seq_len_mask, tf.concat((tf.shape(seq_len_mask), extra_ones), 0))
      return m * seq_len_mask

  return tf.contrib.framework.nest.map_structure(
      lambda m: _maybe_mask(m, seq_len_mask), memory) 
Example #16
Source File: preprocessors.py    From text-to-text-transfer-transformer with Apache License 2.0
def _wsc_inputs(x):
  """Given an example from SuperGLUE WSC, compute the 'inputs' value.

  The output will look like a fill-in-the-blank with the pronoun blanked out.
  For example, the text
    'Mitchell asked Tom if he could lend some money.'
  would be transformed to
    'Mitchell asked Tom if X could lend some money.'

  Args:
    x: A dict that is an example from the WSC task of SuperGLUE.

  Returns:
    A scalar string tensor.
  """
  words = tf.strings.split([x['text']], sep=' ').values

  # We would need some special logic to handle the case where the pronoun is the
  # first or last word in the text. None of the examples in WSC seem to have
  # this, so we are ignoring these cases.
  with tf.control_dependencies([
      tf.assert_greater(x['span2_index'], 0),
      tf.assert_less(x['span2_index'], tf.size(words)),
  ]):
    pronoun_index = tf.identity(x['span2_index'])

  def create_input():
    with tf.control_dependencies(
        [tf.assert_equal(words[pronoun_index], x['span2_text'])]):
      return tf.strings.join(
          [
              tf.strings.reduce_join(words[:pronoun_index], separator=' '),
              'X',
              tf.strings.reduce_join(
                  words[pronoun_index + 1:], separator=' '),
          ],
          separator=' ',
      )

  # Handle some special cases.
  return tf.case(
      {
          # The issue here is that the pronoun is 'him,"' in the text.
          tf.equal(
              x['text'],
              'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. \"Good for him,\" he said. '
          ):
              lambda:
              'The boy continued to whip the pony , and eventually the pony threw him over. John laughed out quite loud. "Good for X ," he said.',
          # Using the span2_index, we get 'use' instead of 'it'.
          tf.equal(
              x['text'],
              'When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use it , but really for now, what more could they wish for?'
          ):
              lambda:
              'When they had eventually calmed down a bit , and had gotten home, Mr. Farley put the magic pebble in an iron safe . Some day they might want to use X , but really for now, what more could they wish for?'
      },
      default=create_input,
      exclusive=True) 
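Feeding the docstring's example through the function, using a hypothetical WSC-style dict:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = {
    'text': tf.constant('Mitchell asked Tom if he could lend some money.'),
    'span2_index': tf.constant(4),  # index of the pronoun 'he'
    'span2_text': tf.constant('he'),
}
inputs = _wsc_inputs(x)
with tf.Session() as sess:
  print(sess.run(inputs))  # b'Mitchell asked Tom if X could lend some money.'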
Example #17
Source File: lstm_models.py    From magenta with Apache License 2.0
def encode(self, sequence, sequence_length):
    """Hierarchically encodes the input sequences, returning a single embedding.

    Each sequence should be padded per-segment. For example, a sequence with
    three segments [1, 2, 3], [4, 5], [6, 7, 8, 9] and a `max_seq_len` of 12
    should be input as `sequence = [1, 2, 3, 0, 4, 5, 0, 0, 6, 7, 8, 9]` with
    `sequence_length = [3, 2, 4]`.

    Args:
      sequence: A batch of (padded) sequences, sized
        `[batch_size, max_seq_len, input_depth]`.
      sequence_length: A batch of sequence lengths. May be sized
        `[batch_size, level_lengths[0]]` or `[batch_size]`. If the latter,
        each length must either equal `max_seq_len` or 0. In this case, the
        segment lengths are assumed to be constant and the total length will be
        evenly divided amongst the segments.

    Returns:
      embedding: A batch of embeddings, sized `[batch_size, N]`.
    """
    batch_size = int(sequence.shape[0])
    sequence_length = lstm_utils.maybe_split_sequence_lengths(
        sequence_length, np.prod(self._level_lengths[1:]),
        self._total_length)

    for level, (num_splits, h_encoder) in enumerate(
        self._hierarchical_encoders):
      split_seqs = tf.split(sequence, num_splits, axis=1)
      # In the first level, we use the input `sequence_length`. After that,
      # we use the full embedding sequences.
      if level:
        sequence_length = tf.fill(
            [batch_size, num_splits], split_seqs[0].shape[1])
      split_lengths = tf.unstack(sequence_length, axis=1)
      embeddings = [
          h_encoder.encode(s, l) for s, l in zip(split_seqs, split_lengths)]
      sequence = tf.stack(embeddings, axis=1)

    with tf.control_dependencies([tf.assert_equal(tf.shape(sequence)[1], 1)]):
      return sequence[:, 0]


# DECODERS 
Example #18
Source File: beam_search_test.py    From tensor2tensor with Apache License 2.0
def testTPUBeam(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    # The top beam is always selected so we should see the top beam's state
    # at each position, which is the one that's getting 3 added to it each step.
    expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

    def symbols_to_logits(_, i, states):
      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[i])]):
        logits = tf.to_float(tf.log(probabilities[i, :]))

      states["state"] += tf.constant([[3.], [7.]])
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        3.5,
        eos_id=1,
        states=states,
        use_tpu=True)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)
    self.assertAllEqual([[[0, 2, 0, 1], [0, 2, 1, 0]]], final_ids) 
Example #19
Source File: beam_search_test.py    From tensor2tensor with Apache License 2.0
def testStateBeamTwo(self):
    batch_size = 1
    beam_size = 2
    vocab_size = 3
    decode_length = 3

    initial_ids = tf.constant([0] * batch_size)  # GO
    probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
                                 [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]],
                                 [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]])

    # The top beam is always selected so we should see the top beam's state
    # at each position, which is the one that's getting 3 added to it each step.
    expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])

    def symbols_to_logits(ids, _, states):
      pos = tf.shape(ids)[1] - 1

      # We have to assert the values of state inline here since we can't fetch
      # them out of the loop!
      with tf.control_dependencies(
          [tf.assert_equal(states["state"], expected_states[pos])]):
        logits = tf.to_float(tf.log(probabilities[pos, :]))

      states["state"] += tf.constant([[3.], [7.]])
      return logits, states

    states = {
        "state": tf.zeros((batch_size, 1)),
    }
    states["state"] = tf.placeholder_with_default(
        states["state"], shape=(None, 1))

    final_ids, _, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        beam_size,
        decode_length,
        vocab_size,
        0.0,
        eos_id=1,
        states=states)

    with self.test_session() as sess:
      # Catch and fail so that the testing framework doesn't think it's an error
      try:
        sess.run(final_ids)
      except tf.errors.InvalidArgumentError as e:
        raise AssertionError(e.message)