Python tensorflow.compat.v1.batch_gather() Examples

The following are 7 code examples of tensorflow.compat.v1.batch_gather(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: modeling.py    From gpt2-ml with Apache License 2.0 5 votes vote down vote up
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Samples ids after restricting every row to its k highest-probability entries.

    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: optional [vocab_size] one-hot style mask of ids that should never be predicted
                       (padding, for instance). When given, 1e10 is subtracted from the masked logits
                       before the softmax that yields `probs` and the ranking; the logits that are
                       actually sampled from are not masked with it.
    :param num_samples: number of samples drawn per row
    :param k: how many top-ranked entries to keep, either a python int or a [batch_size] tensor
    :return: dict with 'probs' (the softmax output) and 'sample' ([batch_size, num_samples] ids
             mapped back to the original vocabulary order)
    """
    # TODO: figure out how to do this on TPUs; it is very slow right now, probably due to argsort.
    # NOTE: the scope name says 'top_p_sample' even though this is the top-k variant.
    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        if ignore_ids is None:
            ranking_logits = logits
        else:
            ranking_logits = logits - tf.cast(ignore_ids[None], tf.float32) * 1e10
        probs = tf.nn.softmax(ranking_logits, axis=-1)

        # [batch_size, vocab_perm]: vocabulary ids ordered from most to least probable.
        order = tf.argsort(probs, direction='DESCENDING')

        # Rank positions >= k are excluded; result is [batch_size, vocab_perm]
        # (or [1, vocab_perm] when k is a plain int).
        if isinstance(k, int):
            cutoff = k
        else:
            cutoff = k[:, None]
        beyond_top_k = tf.range(vocab_size)[None] >= cutoff

        # Sample in the sorted space, then map the drawn positions back to vocabulary ids.
        sorted_logits = tf.batch_gather(logits, order)
        penalty = tf.cast(beyond_top_k, tf.float32) * 1e10
        drawn_positions = tf.random.categorical(logits=sorted_logits - penalty, num_samples=num_samples)
        sample = tf.batch_gather(order, drawn_positions)

    return {
        'probs': probs,
        'sample': sample,
    }
Example #2
Source File: beam_search.py    From tensor2tensor with Apache License 2.0 4 votes vote down vote up
def fast_tpu_gather(params, indices, name=None):
  """Gathers entries along axis 1 of `params`, batch row by batch row.

  Non-integer `params` are gathered by multiplying a one-hot encoding of
  `indices` with `params` (a batch matmul), which is faster than gather_nd on
  TPU. Integer `params` (sequences to gather from) are routed through
  batch_gather instead so that no accuracy is lost.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
  with tf.name_scope(name):
    dtype = params.dtype

    # Integer inputs skip the one-hot matmul to avoid precision loss: the
    # largest int that bfloat16 can represent in the MXU is 256, which is
    # smaller than the possible id values. Encoding/decoding could potentially
    # make the matmul path work, but the benefit is small right now.
    if dtype.is_integer:
      return tf.batch_gather(params, indices)

    def _one_hot_matmul_gather(values, index):
      """Gathers via a one-hot selector multiplied with the values."""
      cast_needed = dtype != tf.float32
      if cast_needed:
        values = tf.to_float(values)
      full_shape = common_layers.shape_list(values)
      index_shape = common_layers.shape_list(index)
      rank = values.shape.ndims

      # Batch MatMul needs rank-3 values: add a trailing axis to rank-2
      # inputs and flatten every axis after the second for rank > 3.
      if rank == 2:
        values = tf.expand_dims(values, axis=-1)
      elif rank > 3:
        values = tf.reshape(values, [full_shape[0], full_shape[1], -1])

      selector = tf.one_hot(index, full_shape[1], dtype=values.dtype)
      gathered = tf.matmul(selector, values)

      # Undo the rank adjustment made above.
      if rank == 2:
        gathered = tf.squeeze(gathered, axis=-1)
      elif rank > 3:
        full_shape[1] = index_shape[1]
        gathered = tf.reshape(gathered, full_shape)

      if cast_needed:
        gathered = tf.cast(gathered, dtype)
      return gathered

    return _one_hot_matmul_gather(params, indices)
Example #3
Source File: modeling.py    From gpt2-ml with Apache License 2.0 4 votes vote down vote up
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Nucleus (top-p) sampling over the last axis of `logits`.

    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: optional [vocab_size] one-hot style mask of ids that should never be predicted
                       (padding, for instance); 1e10 is subtracted from their logits before the softmax
    :param num_samples: number of samples drawn per row
    :param p: top-p threshold, either a float or a [batch_size] vector. A float above 0.999999
              switches top-p filtering off and samples from the (ignore-masked) logits directly.
    :return: dict with 'probs' (the softmax output) and 'sample' ([batch_size, num_samples] ids)
    """
    # TODO: figure out how to do this on TPUs; it is very slow right now, probably due to argsort.

    def _suppress_ignored():
        # Deliberately rebuilds the expression on every call, like the inline form it replaces.
        if ignore_ids is None:
            return logits
        return logits - tf.cast(ignore_ids[None], tf.float32) * 1e10

    with tf.variable_scope('top_p_sample'):
        _, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(_suppress_ignored(), axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Threshold is effectively 1: skip the sort and sample from everything.
            print("Top-p sampling DISABLED", flush=True)
            unrestricted = tf.random.categorical(
                logits=_suppress_ignored(), num_samples=num_samples, dtype=tf.int32)
            return {
                'probs': probs,
                'sample': unrestricted,
            }

        # [batch_size, vocab_perm]: vocabulary ids ordered from most to least probable.
        order = tf.argsort(probs, direction='DESCENDING')
        sorted_probs = tf.batch_gather(probs, order)
        running_mass = tf.math.cumsum(sorted_probs, axis=-1, exclusive=False)

        # An entry survives while the running mass (itself included) is still below p.
        # The top-ranked entry always survives so that we never cut off everything.
        # Result is [batch_size, vocab_perm].
        if isinstance(p, float):
            threshold = p
        else:
            threshold = p[:, None]
        under_threshold = running_mass < threshold
        is_top_ranked = tf.range(vocab_size)[None] < 1
        exclude_mask = tf.logical_not(tf.logical_or(under_threshold, is_top_ranked))

        # Sample in the sorted space, then map the drawn positions back to vocabulary ids.
        sorted_logits = tf.batch_gather(logits, order)
        penalty = tf.cast(exclude_mask, tf.float32) * 1e10
        drawn_positions = tf.random.categorical(logits=sorted_logits - penalty, num_samples=num_samples)
        sample = tf.batch_gather(order, drawn_positions)

    return {
        'probs': probs,
        'sample': sample,
    }
Example #4
Source File: modeling.py    From gpt2-ml with Apache License 2.0 4 votes vote down vote up
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Runs the Grover model once and samples one new token per batch row.

    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold; when do_topk is set it is cast to int32 and used as k instead
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: sample with top-k instead of top-p
    :return: dict with
             'new_tokens', size [batch_size]
             'new_probs', also size [batch_size]
             'new_cache', size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    grover = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # The flat logits cover every position; only the final position of each row is sampled from.
    _, vocab_size = get_shape_list(grover.logits_flat, expected_rank=2)
    logits_per_position = tf.reshape(grover.logits_flat, [batch_size, -1, vocab_size])
    last_step_logits = logits_per_position[:, -1]

    if not do_topk:
        drawn = _top_p_sample(last_step_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)
    else:
        # NOTE(review): ignore_ids is not forwarded on the top-k path - confirm this is intended.
        drawn = _top_k_sample(last_step_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))

    # One sample per row, so drop the num_samples axis from both outputs.
    new_tokens = tf.squeeze(drawn['sample'], 1)
    sampled_probs = tf.batch_gather(drawn['probs'], drawn['sample'])
    new_probs = tf.squeeze(sampled_probs, 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': grover.new_kvs,
    }
Example #5
Source File: loss.py    From interval-bound-propagation with Apache License 2.0 4 votes vote down vote up
def _build_verified_loss(self, labels):
    """Builds the verified loss from an upper bound on the specification.

    Stores the result in `self._verified_loss`. Without a specification the
    loss and `self._interval_bounds_accuracy` are both set to constant zero.

    Args:
      labels: not used by this method.
    """
    if not self._specification:
      self._verified_loss = tf.constant(0.)
      self._interval_bounds_accuracy = tf.constant(0.)
      return

    # Interval bounds; presumably [batch_size, num_specifications] - the code
    # below reduces and indexes along axis 1.
    bounds = self._get_specification_bounds()

    # Select which specifications contribute to the loss. Mode 'all' keeps the
    # bounds exactly as they are.
    mode = self._interval_bounds_loss_mode
    if mode == 'most':
      bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
    elif mode == 'random':
      num_rows = tf.shape(bounds)[0]
      num_specs = tf.shape(bounds)[1]
      picks = tf.random.uniform(
          [num_rows, self._interval_bounds_loss_n], 0, num_specs,
          dtype=tf.int32)
      bounds = tf.batch_gather(bounds, picks)
    elif mode != 'all':
      assert mode == 'least'
      # Pick the least violated constraint: the smallest non-negative bound.
      # Negative bounds are pushed out of the way by adding _BIG_NUMBER.
      satisfied = tf.cast(bounds < 0., tf.float32)
      smallest_violation = tf.reduce_min(
          bounds + satisfied * _BIG_NUMBER, axis=1, keepdims=True)
      # True unless every bound in the row is negative; the .5 keeps the
      # float comparison of two counts robust.
      has_violations = tf.less(
          tf.reduce_sum(satisfied, axis=1, keepdims=True) + .5,
          tf.cast(tf.shape(bounds)[1], tf.float32))
      largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
      bounds = tf.where(has_violations, smallest_violation, largest_bounds)

    loss_type = self._interval_bounds_loss_type
    if loss_type == 'xent':
      # Append a zero logit as an extra class and make it the target class.
      zero_column = tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)
      xent_logits = tf.concat([bounds, zero_column], axis=1)
      no_target = tf.zeros_like(bounds)
      one_column = tf.ones([tf.shape(bounds)[0], 1], dtype=bounds.dtype)
      xent_targets = tf.concat([no_target, one_column], axis=1)
      per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
          labels=tf.stop_gradient(xent_targets), logits=xent_logits)
      self._verified_loss = tf.reduce_mean(per_example)
    elif loss_type == 'softplus':
      shifted = bounds + self._interval_bounds_hinge_margin
      self._verified_loss = tf.reduce_mean(tf.nn.softplus(shifted))
    else:
      assert loss_type == 'hinge'
      clipped = tf.maximum(bounds, -self._interval_bounds_hinge_margin)
      self._verified_loss = tf.reduce_mean(clipped)
Example #6
Source File: slate_decomp_q_agent.py    From recsim with Apache License 2.0 4 votes vote down vote up
def compute_target_topk_q(reward, gamma, next_actions, next_q_values,
                          next_states, terminals):
  """Computes the target Q value for a greedily built top-K slate.

  This algorithm corresponds to the method "TT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate. Only its
      static slate_size is read here.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  slate_size = next_actions.get_shape().as_list()[1]
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Fill the slate with the documents that have the highest
  # affinity_score * Q value and treat it as if it were the optimal slate.
  weighted_q = next_q_values * scores
  _, topk_optimal_slate = tf.math.top_k(weighted_q, k=slate_size)

  def _on_slate(per_document):
    # Picks the slate entries out of a per-document tensor.
    # Result: [batch_size, slate_size]
    return tf.batch_gather(
        per_document, tf.cast(topk_optimal_slate, dtype=tf.int32))

  slate_q_values = _on_slate(next_q_values)
  slate_scores = _on_slate(scores)

  # Expected Q value of the slate: score-weighted Q values, normalized by the
  # slate scores plus the no-click score.
  weighted_sum = tf.reduce_sum(
      input_tensor=slate_q_values * slate_scores, axis=1)
  normalizer = tf.reduce_sum(
      input_tensor=slate_scores, axis=1) + score_no_click
  next_q_target_topk = weighted_sum / normalizer

  # Terminal steps contribute only the immediate reward.
  discounted = gamma * next_q_target_topk
  not_terminal = 1. - tf.cast(terminals, tf.float32)
  return reward + discounted * not_terminal
Example #7
Source File: slate_decomp_q_agent.py    From recsim with Apache License 2.0 4 votes vote down vote up
def _build_train_op(self):
    """Builds the op that runs one training step on replay data.

    The loss is the mean squared difference between the Q value of the clicked
    document and the target, over batch entries whose click indicators sum to
    one. It is a constant zero when the batch contains no click at all.

    Returns:
      The op returned by `self.optimizer.minimize` for that loss.
    """
    # Shapes, with B = batch, S = slate size, A = number of actions:
    #   click_indicator: [B, S]
    #   q_values:        [B, A]
    #   actions:         [B, S]
    #   slate_q_values:  [B, S]
    #   replay_click_q:  [B]
    click_indicator = self._replay.rewards[:, :, self._click_response_index]
    slate_actions = tf.cast(self._replay.actions, dtype=tf.int32)
    slate_q_values = tf.batch_gather(
        self._replay_net_outputs.q_values, slate_actions)

    # Multiplying by the click indicator keeps only the clicked document's Q.
    replay_click_q = tf.reduce_sum(
        input_tensor=slate_q_values * click_indicator,
        axis=1,
        name='replay_click_q')

    target = tf.stop_gradient(self._build_target_q_op())

    clicks_per_slate = tf.reduce_sum(input_tensor=click_indicator, axis=1)
    has_single_click = tf.equal(clicks_per_slate, 1)
    clicked_rows = tf.squeeze(tf.where(has_single_click), axis=1)
    # clicked_rows is a vector, so tf.gather selects along the batch dimension.
    q_clicked = tf.gather(replay_click_q, clicked_rows)
    target_clicked = tf.gather(target, clicked_rows)

    def _clicked_mse_loss():
      mse = tf.reduce_mean(input_tensor=tf.square(q_clicked - target_clicked))
      if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
          tf.summary.scalar('Loss', mse)

      return mse

    def _zero_loss():
      return tf.constant(0.)

    any_click = tf.greater(tf.reduce_sum(input_tensor=clicks_per_slate), 0)
    loss = tf.cond(
        pred=any_click,
        true_fn=_clicked_mse_loss,
        false_fn=_zero_loss,
        name='')

    return self.optimizer.minimize(loss)