Python tensorflow.gather() Examples

The following are 30 code examples of tensorflow.gather(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: box_list_ops.py    From object_detector_app with MIT License 6 votes vote down vote up
def prune_small_boxes(boxlist, min_side, scope=None):
  """Drops every box whose height or width is below `min_side`.

  Args:
    boxlist: BoxList holding N boxes.
    min_side: smallest width AND height (inclusive) a box may have and still
      be kept.
    scope: name scope.

  Returns:
    The result of `gather` on the boxes whose sides are both >= min_side.
  """
  with tf.name_scope(scope, 'PruneSmallBoxes'):
    box_heights, box_widths = height_width(boxlist)
    wide_enough = tf.greater_equal(box_widths, min_side)
    tall_enough = tf.greater_equal(box_heights, min_side)
    keep_mask = tf.logical_and(wide_enough, tall_enough)
    # tf.where yields [M, rank] coordinates; flattening to [M] assumes the
    # mask is rank-1 (one entry per box).
    keep_indices = tf.reshape(tf.where(keep_mask), [-1])
    return gather(boxlist, keep_indices)
Example #2
Source File: neural_gpu.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def memory_run(step, nmaps, mem_size, batch_size, vocab_size,
               global_step, do_training, update_mem, decay_factor, num_gpus,
               target_emb_weights, output_w, gpu_targets_tn, it):
  """Run memory.

  Queries the memory with the step-`it` slice of `step` and embeds the
  returned ids with `target_emb_weights`. When `do_training` is non-zero,
  the result is sometimes mixed with (dropped-out) embeddings of the gold
  targets. The mixing probability decays with `global_step` to a floor of 0.2.

  `batch_size` and `output_w` are accepted but not used here.

  Returns:
    A tuple (mem, mem_loss, update_mem); mem is reshaped to [-1, 1, 1, nmaps].
  """
  query = step[:, 0, it, :]
  gold_ids = gpu_targets_tn[:, it, 0]
  res_ids, mask, mem_loss = memory_call(
      query, gold_ids, nmaps, mem_size, vocab_size, num_gpus, update_mem)
  # Embed the retrieved ids and scale each row by the first mask column.
  res_emb = tf.gather(target_emb_weights, res_ids)
  res_emb = res_emb * tf.expand_dims(mask[:, 0], 1)

  # Mix gold and original in the first steps, 20% later.
  gold_emb = tf.nn.dropout(tf.gather(target_emb_weights, gold_ids), 0.7)
  step_f = tf.cast(global_step, tf.float32)
  p_gold = 1.0 - step_f / (1000. * decay_factor)
  p_gold = tf.maximum(p_gold, 0.2) * do_training
  mixed = tf.cond(tf.less(tf.random_uniform([]), p_gold),
                  lambda: p_gold * gold_emb + (1.0 - p_gold) * res_emb,
                  lambda: res_emb)
  mem = tf.reshape(mixed, [-1, 1, 1, nmaps])
  return mem, mem_loss, update_mem
Example #3
Source File: common_attention.py    From fine-lm with MIT License 6 votes vote down vote up
def add_positional_embedding(x, max_length, name, positions=None):
  """Adds a learned positional embedding to `x`.

  Args:
    x: a Tensor with shape [batch, length, depth]
    max_length: an integer; number of rows in the embedding variable.
    name: name of the embedding variable.
    positions: an optional tensor with shape [batch, length] of position ids
      used to look up rows of the embedding table.

  Returns:
    a Tensor the same shape as x.
  """
  _, seq_len, depth = common_layers.shape_list(x)
  table = tf.cast(tf.get_variable(name, [max_length, depth]), x.dtype)
  if positions is not None:
    return x + tf.gather(table, tf.to_int32(positions))
  # Without explicit positions use rows 0..seq_len-1. When the sequence is
  # at least max_length long, the table is zero-padded up to seq_len rows.
  rows = tf.cond(
      tf.less(seq_len, max_length),
      lambda: tf.slice(table, [0, 0], [seq_len, -1]),
      lambda: tf.pad(table, [[0, seq_len - max_length], [0, 0]]))
  return x + tf.expand_dims(rows, 0)
Example #4
Source File: spatial_transformer.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Batch Spatial Transformer Layer

    Repeats every input image `num_transforms` times along the batch axis and
    passes the repeated batch, together with `thetas`, to `transformer`.

    Parameters
    ----------

    U : float
        tensor of inputs [num_batch,height,width,num_channels]
    thetas : float
        a set of transformations for each input [num_batch,num_transforms,6]
    out_size : int
        the size of the output [out_height,out_width]

    Returns: float
        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
    """
    with tf.variable_scope(name):
        # The first two dims of thetas must be statically known (int()).
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # Batch index of every (image, transform) pair in row-major order,
        # e.g. num_batch=2, num_transforms=3 -> [0, 0, 0, 1, 1, 1].
        # `range` replaces the Python-2-only `xrange` (NameError on Python 3).
        # The explicit int32 dtype keeps the index tensor integral even when
        # the list is empty; an empty Python list would default to float32.
        repeat_idx = tf.constant(
            [i for i in range(num_batch) for _ in range(num_transforms)],
            dtype=tf.int32)
        input_repeated = tf.gather(U, repeat_idx)
        return transformer(input_repeated, thetas, out_size)
Example #5
Source File: probclass.py    From imgcomp-cvpr with GNU General Public License v3.0 6 votes vote down vote up
def __init__(self, pc: _Network3D, config, centers, sess, freqs_resolution=1e9):
        """
        Builds the graph mapping a context of symbols to next-symbol probabilities.

        :param pc: probability classifier network; its class supplies get_context_shape
        :param config: configuration passed to get_context_shape
        :param centers: tensor indexed by the symbols to obtain q
        :param sess: Must be set at the latest before using get_pr or get_freqs
        :param freqs_resolution: factor applied to the probabilities before casting to int64
        """
        self.pc_class = pc.__class__
        self.config = config
        self.input_ctx_shape = self.pc_class.get_context_shape(config)
        self.input_ctx = tf.placeholder(tf.int64, self.input_ctx_shape)  # symbols!
        # Add a leading batch axis, then a trailing axis for the 3D conv.
        ctx = tf.expand_dims(tf.expand_dims(self.input_ctx, 0), -1)
        # Here, in contrast to pc.bitcost(...), q does not need to be padded, as it is part of some context.
        # Logits will be a 1111L vector, i.e., prediction of the next pixel
        q = tf.gather(centers, ctx)
        logits = pc.logits(q, is_training=False)
        self.pr = tf.nn.softmax(logits)
        scaled_pr = self.pr * freqs_resolution
        self.freqs = tf.squeeze(tf.cast(scaled_pr, tf.int64))
        self.sess = sess

        self._get_freqs = None
Example #6
Source File: shape_utils.py    From object_detector_app with MIT License 6 votes vote down vote up
def clip_tensor(t, length):
  """Keeps only the first `length` entries of `t` along dimension 0.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: an integer, or a tensor usable as the limit of tf.range, giving
      the first dimension after clipping; assumes length <= t.shape[0].

  Returns:
    clipped_t: the clipped tensor. When `length` is a Python integer, its
      first dimension is also set statically to `length`.
  """
  first_rows = tf.gather(t, tf.range(length))
  if _is_tensor(length):
    return first_rows
  # A plain integer lets us pin the static size of dimension 0.
  return _set_dim_0(first_rows, length)
Example #7
Source File: box_list_ops.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def filter_field_value_equals(boxlist, field, value, scope=None):
  """Keeps only the boxes whose `field` entry equals `value`.

  Args:
    boxlist: BoxList holding N boxes.
    field: name of the extra field to compare.
    value: scalar value to compare against.
    scope: name scope.

  Returns:
    The result of `gather` on the matching boxes (M <= N boxes).

  Raises:
    ValueError: if boxlist is not a BoxList or lacks the requested field.
  """
  with tf.name_scope(scope, 'FilterFieldValueEquals'):
    if not isinstance(boxlist, box_list.BoxList):
      raise ValueError('boxlist must be a BoxList')
    if not boxlist.has_field(field):
      raise ValueError('boxlist must contain the specified field')
    matches = tf.equal(boxlist.get_field(field), value)
    keep = tf.reshape(tf.where(matches), [-1])
    return gather(boxlist, keep)
Example #8
Source File: common_attention.py    From fine-lm with MIT License 6 votes vote down vote up
def get_shifted_center_blocks(x, indices):
  """Gathers 2d blocks and shifts each one right by one along its length axis.

  Args:
    x: A tensor with shape [batch, heads, height, width, depth]
    indices: The indices to gather blocks

  Returns:
    x_shifted: the gathered blocks with axis 3 shifted right by one position
      (a zero slot is prepended and the last slot dropped).
  """
  center_blocks = gather_blocks_2d(x, indices)
  # Pad one zero in front of axis 3, then drop the final slot on that axis.
  padded = tf.pad(center_blocks, [[0, 0], [0, 0], [0, 0], [1, 0], [0, 0]])
  return padded[:, :, :, :-1, :]
Example #9
Source File: common_layers.py    From fine-lm with MIT License 6 votes vote down vote up
def convert_gradient_to_tensor(x):
  """Forward pass: return `x` unchanged.

  The intended contract is an identity whose gradient is converted to a
  dense `Tensor`. The gradient of `tf.concat` is expensive when dy is an
  `IndexedSlices`: with no GPU implementation it runs on CPU. That happens
  when the concat output is later fed to `tf.gather`. Wrapping it as
  `convert_gradient_to_tensor(tf.concat(x))` selects the cheaper dense
  gradient.

  NOTE(review): as written here the body is a plain identity. Any gradient
  conversion must come from a decorator or wrapper not visible in this
  snippet (presumably a custom-gradient definition upstream); confirm before
  relying on it.

  Args:
    x: A `Tensor`.

  Returns:
    `x` itself.
  """
  return x
Example #10
Source File: memory.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def data(self, rows=None):
    """Access a batch of episodes from the memory.

    Padding elements after the length of each episode are unspecified and might
    contain old data.

    Args:
      rows: 1-D tensor of episode indices to select; defaults to all
        `self._capacity` episodes.

    Returns:
      A pair (episode, length). `episode` is a list holding one gathered
      tensor per buffer in `self._buffers`. `length` holds the matching
      entries of `self._length`.
    """
    if rows is None:
      rows = tf.range(self._capacity)
    assert rows.shape.ndims == 1
    episode = []
    for buf in self._buffers:
      episode.append(tf.gather(buf, rows))
    length = tf.gather(self._length, rows)
    return episode, length
Example #11
Source File: next_frame.py    From fine-lm with MIT License 6 votes vote down vote up
def scheduled_sample(self,
                       ground_truth_x,
                       generated_x,
                       batch_size,
                       num_ground_truth):
    """Sample batch with specified mix of groundtruth and generated data points.

    A random permutation of 0..batch_size-1 is split into two groups. Row i
    of the result comes from ground_truth_x if i falls in the first
    `num_ground_truth` entries of the permutation, otherwise from
    generated_x. tf.dynamic_stitch places every row back at its own index.

    Args:
      ground_truth_x: tensor of ground-truth data points.
      generated_x: tensor of generated data points.
      batch_size: batch size
      num_ground_truth: number of ground-truth examples to include in batch.
    Returns:
      New batch with num_ground_truth sampled from ground_truth_x and the rest
      from generated_x.
    """
    perm = tf.random_shuffle(tf.range(batch_size))
    gt_slots = tf.gather(perm, tf.range(num_ground_truth))
    gen_slots = tf.gather(perm, tf.range(num_ground_truth, batch_size))
    picked = [tf.gather(ground_truth_x, gt_slots),
              tf.gather(generated_x, gen_slots)]
    return tf.dynamic_stitch([gt_slots, gen_slots], picked)
Example #12
Source File: shape_utils.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def clip_tensor(t, length):
  """Keeps only the first `length` entries of `t` along dimension 0.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: an integer, or a tensor usable as the limit of tf.range, giving
      the first dimension after clipping; assumes length <= t.shape[0].

  Returns:
    clipped_t: the clipped tensor. When `length` is a Python integer, its
      first dimension is also set statically to `length`.
  """
  head = tf.gather(t, tf.range(length))
  if _is_tensor(length):
    return head
  # A plain integer lets us pin the static size of dimension 0.
  return _set_dim_0(head, length)
Example #13
Source File: bulk_component.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def build_cross_entropy_loss(logits, gold):
  """Constructs a mean cross entropy from logits and sparse gold labels.

  Supports skipping rows where the gold label is the magic -1 value.

  Args:
    logits: float Tensor of scores with one row per entry of `gold`.
    gold: rank-1 int Tensor of gold class ids (not one-hot). Rows whose id
      is -1 are ignored.

  Returns:
    cost, correct, total: the cross entropy averaged over the valid labels
        (0 when there are none), the number of correctly predicted labels,
        and the number of valid labels.
  """
  # Row indices of all labels that are not the -1 "skip" marker. Flattening
  # the [N, 1] coordinates from tf.where relies on gold being rank-1.
  valid = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  gold = tf.gather(gold, valid)
  logits = tf.gather(logits, valid)
  correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  total = tf.size(gold)
  # Same computation as the deprecated flipped contrib op, via the
  # keyword-argument core API.
  losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.cast(gold, tf.int64), logits=logits)
  # With no valid labels the old 0 / 0 division produced NaN, which poisons
  # training. The sum is 0 in that case, so clamping the denominator yields 0.
  denom = tf.cast(tf.maximum(total, 1), tf.float32)
  cost = tf.reduce_sum(losses) / denom
  return cost, correct, total
Example #14
Source File: prediction_model.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
  """Sample batch with specified mix of ground truth and generated data points.

  A random permutation of 0..batch_size-1 is split into two groups. Row i
  of the result comes from ground_truth_x if i falls in the first
  `num_ground_truth` entries of the permutation, otherwise from generated_x.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size (converted with int()).
    num_ground_truth: number of ground-truth examples to include in batch.
  Returns:
    New batch with num_ground_truth sampled from ground_truth_x and the rest
    from generated_x.
  """
  n = int(batch_size)
  perm = tf.random_shuffle(tf.range(n))
  gt_slots = tf.gather(perm, tf.range(num_ground_truth))
  gen_slots = tf.gather(perm, tf.range(num_ground_truth, n))
  gt_rows = tf.gather(ground_truth_x, gt_slots)
  gen_rows = tf.gather(generated_x, gen_slots)
  return tf.dynamic_stitch([gt_slots, gen_slots], [gt_rows, gen_rows])
Example #15
Source File: memory.py    From DOTA_models with Apache License 2.0 6 votes vote down vote up
def get_hint_pool_idxs(self, normalized_query):
    """Get small set of idxs to compute nearest neighbor queries on.

    This is an expensive look-up on the whole memory that is used to
    avoid more expensive operations later on.

    Args:
      normalized_query: A Tensor of shape [None, key_dim].

    Returns:
      A Tensor of shape [None, choose_k] of indices in memory
      that are closest to the queries. Each hash table contributes one
      slice, and the slices are concatenated along axis 1.
    """
    # get hash of query vecs
    slot_idxs = self.get_hash_slots(normalized_query)

    # grab mem idxs in the hash slots
    hint_pool_idxs = []
    for i, idxs in enumerate(slot_idxs):
      mem_idxs = tf.gather(self.hash_slots[i], idxs)
      # Clamp into the valid memory range [0, memory_size - 1].
      clamped = tf.maximum(tf.minimum(mem_idxs, self.memory_size - 1), 0)
      hint_pool_idxs.append(clamped)

    return tf.concat(axis=1, values=hint_pool_idxs)
Example #16
Source File: box_list_ops.py    From object_detector_app with MIT License 6 votes vote down vote up
def filter_field_value_equals(boxlist, field, value, scope=None):
  """Keeps only the boxes whose `field` entry equals `value`.

  Args:
    boxlist: BoxList holding N boxes.
    field: name of the extra field to compare.
    value: scalar value to compare against.
    scope: name scope.

  Returns:
    The result of `gather` on the matching boxes (M <= N boxes).

  Raises:
    ValueError: if boxlist is not a BoxList or lacks the requested field.
  """
  with tf.name_scope(scope, 'FilterFieldValueEquals'):
    if not isinstance(boxlist, box_list.BoxList):
      raise ValueError('boxlist must be a BoxList')
    if not boxlist.has_field(field):
      raise ValueError('boxlist must contain the specified field')
    is_match = tf.equal(boxlist.get_field(field), value)
    match_indices = tf.reshape(tf.where(is_match), [-1])
    return gather(boxlist, match_indices)
Example #17
Source File: memory.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def data(self, rows=None):
    """Access a batch of episodes from the memory.

    Padding elements after the length of each episode are unspecified and might
    contain old data.

    Args:
      rows: 1-D tensor of episode indices to select; defaults to all
        `self._capacity` episodes.

    Returns:
      A pair (episode, length). `episode` is a list holding one gathered
      tensor per buffer in `self._buffers`. `length` holds the matching
      entries of `self._length`.
    """
    selected = tf.range(self._capacity) if rows is None else rows
    assert selected.shape.ndims == 1
    episode = [tf.gather(buf, selected) for buf in self._buffers]
    return episode, tf.gather(self._length, selected)
Example #18
Source File: box_list_ops.py    From object_detector_app with MIT License 5 votes vote down vote up
def prune_outside_window(boxlist, window, scope=None):
  """Prunes bounding boxes that fall even partially outside a given window.

  Only boxes lying entirely inside the window survive. See also
  clip_to_window, which only prunes boxes that fall completely outside the
  window and clips any boxes that partially overflow.

  Args:
    boxlist: a BoxList holding M_in boxes.
    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
      of the window
    scope: name scope.

  Returns:
    pruned_boxlist: the result of `gather` on the M_out <= M_in kept boxes.
    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
     in the input boxlist.
  """
  with tf.name_scope(scope, 'PruneOutsideWindow'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
    # One column per window edge: True where the box crosses that edge.
    crosses_edge = tf.concat([
        tf.less(y_min, win_y_min), tf.less(x_min, win_x_min),
        tf.greater(y_max, win_y_max), tf.greater(x_max, win_x_max)
    ], 1)
    fully_inside = tf.logical_not(tf.reduce_any(crosses_edge, 1))
    valid_indices = tf.reshape(tf.where(fully_inside), [-1])
    return gather(boxlist, valid_indices), valid_indices
Example #19
Source File: common_attention.py    From fine-lm with MIT License 5 votes vote down vote up
def _generate_relative_positions_embeddings(length, depth,
                                            max_relative_position, name):
  """Generates tensor of size [length, length, depth].

  Looks up a learned `depth`-sized embedding for every entry of the relative
  positions matrix. The table has 2 * max_relative_position + 1 rows; the
  matrix entries are presumably ids in that range (confirm against
  _generate_relative_positions_matrix).
  """
  with tf.variable_scope(name):
    rel_ids = _generate_relative_positions_matrix(
        length, max_relative_position)
    num_ids = 2 * max_relative_position + 1
    table = tf.get_variable("embeddings", [num_ids, depth])
    return tf.gather(table, rel_ids)
Example #20
Source File: box_list_ops.py    From object_detector_app with MIT License 5 votes vote down vote up
def prune_completely_outside_window(boxlist, window, scope=None):
  """Prunes bounding boxes that fall completely outside of the given window.

  The function clip_to_window prunes bounding boxes that fall
  completely outside the window, but also clips any bounding boxes that
  partially overflow. This function does not clip partially overflowing boxes.

  Args:
    boxlist: a BoxList holding M_in boxes.
    window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
      of the window
    scope: name scope.

  Returns:
    pruned_boxlist: the result of `gather` on the M_out <= M_in kept boxes.
    valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
     in the input boxlist.
  """
  with tf.name_scope(scope, 'PruneCompleteleyOutsideWindow'):
    y_min, x_min, y_max, x_max = tf.split(
        value=boxlist.get(), num_or_size_splits=4, axis=1)
    win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
    # A box is entirely outside if it lies fully beyond any one window edge.
    beyond_edge = tf.concat([
        tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max),
        tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min)
    ], 1)
    overlaps = tf.logical_not(tf.reduce_any(beyond_edge, 1))
    valid_indices = tf.reshape(tf.where(overlaps), [-1])
    return gather(boxlist, valid_indices), valid_indices
Example #21
Source File: box_list_ops.py    From object_detector_app with MIT License 5 votes vote down vote up
def boolean_mask(boxlist, indicator, fields=None, scope=None):
  """Select boxes from BoxList according to indicator and return new BoxList.

  Keeps the boxes marked True by `indicator`. Every extra field in the
  boxlist is masked the same way along its first dimension, unless `fields`
  restricts which ones are carried over.

  Args:
    boxlist: BoxList holding N boxes
    indicator: a rank-1 boolean tensor
    fields: (optional) list of fields to also gather from.  If None (default),
      all fields are gathered from.  Pass an empty fields list to only gather
      the box coordinates.
    scope: name scope.

  Returns:
    subboxlist: a BoxList corresponding to the subset of the input BoxList
      specified by indicator
  Raises:
    ValueError: if `indicator` is not a rank-1 boolean tensor, or if a
      requested field is missing from boxlist.
  """
  with tf.name_scope(scope, 'BooleanMask'):
    if indicator.shape.ndims != 1:
      raise ValueError('indicator should have rank 1')
    if indicator.dtype != tf.bool:
      raise ValueError('indicator should be a boolean tensor')
    subboxlist = box_list.BoxList(tf.boolean_mask(boxlist.get(), indicator))
    field_names = boxlist.get_extra_fields() if fields is None else fields
    for field_name in field_names:
      if not boxlist.has_field(field_name):
        raise ValueError('boxlist must contain all specified fields')
      masked_field = tf.boolean_mask(boxlist.get_field(field_name), indicator)
      subboxlist.add_field(field_name, masked_field)
    return subboxlist
Example #22
Source File: box_list_ops.py    From object_detector_app with MIT License 5 votes vote down vote up
def gather(boxlist, indices, fields=None, scope=None):
  """Gather boxes from BoxList according to indices and return new BoxList.

  By default the result carries the selected boxes plus every extra field of
  the boxlist, each indexed along its first dimension. Pass `fields` to carry
  over only a subset of the extra fields.

  Args:
    boxlist: BoxList holding N boxes
    indices: a rank-1 tensor of type int32 / int64
    fields: (optional) list of fields to also gather from.  If None (default),
      all fields are gathered from.  Pass an empty fields list to only gather
      the box coordinates.
    scope: name scope.

  Returns:
    subboxlist: a BoxList corresponding to the subset of the input BoxList
    specified by indices
  Raises:
    ValueError: if indices is not rank-1, if indices is neither int32 nor
      int64, or if a requested field is not contained in boxlist.
  """
  with tf.name_scope(scope, 'Gather'):
    if len(indices.shape.as_list()) != 1:
      raise ValueError('indices should have rank 1')
    if indices.dtype not in (tf.int32, tf.int64):
      raise ValueError('indices should be an int32 / int64 tensor')
    subboxlist = box_list.BoxList(tf.gather(boxlist.get(), indices))
    field_names = boxlist.get_extra_fields() if fields is None else fields
    for field_name in field_names:
      if not boxlist.has_field(field_name):
        raise ValueError('boxlist must contain all specified fields')
      picked = tf.gather(boxlist.get_field(field_name), indices)
      subboxlist.add_field(field_name, picked)
    return subboxlist
Example #23
Source File: probclass.py    From imgcomp-cvpr with GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, pc: _Network3D, ae, sess):
        """
        Builds the graph computing the total bit cost of a fed batch of symbols.

        :param pc: Probability classifier network
        :param ae: Autoencoder; supplies the centers variable and the pad value
        :param sess: session to run
        """
        # Symbols fed at run time; rank 4 with all dims dynamic (NCHW).
        self.input_ph_symbols = tf.placeholder(tf.int64, shape=(None, None, None, None))

        # get q from symbols by looking up each symbol's center
        q = tf.gather(ae.get_centers_variable(), self.input_ph_symbols)
        pad = pc.auto_pad_value(ae)
        per_symbol_cost = pc.bitcost(q, self.input_ph_symbols, is_training=False,
                                     pad_value=pad)
        # Sum over every symbol into a single value.
        self.bit_cost_total = tf.reduce_sum(per_symbol_cost)
        self.sess = sess
Example #24
Source File: common_layers.py    From fine-lm with MIT License 5 votes vote down vote up
def embedding(x,
              vocab_size,
              dense_size,
              name=None,
              reuse=None,
              multiplier=1.0,
              symbol_dropout_rate=0.0,
              embedding_var=None,
              dtype=tf.float32):
  """Embed x of type int64 into dense vectors, reducing to max 4 dimensions.

  If the embedded result has rank 5, axis 3 is squeezed away (it is assumed
  to have size 1). Results of lower rank are returned as-is.
  """
  with tf.variable_scope(
      name, default_name="embedding", values=[x], reuse=reuse, dtype=dtype):
    if embedding_var is None:
      embedding_var = tf.get_variable("kernel", [vocab_size, dense_size])
    # On the backwards pass, we want to convert the gradient from
    # an indexed-slices to a regular tensor before sending it back to the
    # parameter server. This avoids excess computation on the parameter server.
    if not tf.contrib.eager.in_eager_mode():
      embedding_var = convert_gradient_to_tensor(embedding_var)
    x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
    emb_x = gather(embedding_var, x, dtype)
    if multiplier != 1.0:
      emb_x *= multiplier
    rank = len(emb_x.shape.as_list())
    if rank < 5:
      return emb_x
    assert rank == 5
    # If we had an extra channel dimension, assume it's 1, i.e. shape[3] == 1.
    return tf.squeeze(emb_x, 3)
Example #25
Source File: common_layers.py    From fine-lm with MIT License 5 votes vote down vote up
def flatten4d3d(x):
  """Reshapes [d0, d1, d2, d3] into [d0, d1 * d2, d3] (merges axes 1 and 2)."""
  dims = shape_list(x)
  merged = dims[1] * dims[2]
  return tf.reshape(x, [dims[0], merged, dims[3]])


# TODO(noam): remove this function after TPUs do gather faster. 
Example #26
Source File: expert_utils.py    From fine-lm with MIT License 5 votes vote down vote up
def dispatch(self, inp):
    """Send the inputs to the experts.

    Flattens the batch and length axes together, then picks rows with
    `self._flat_indices`.

    Args:
      inp: a `Tensor` of shape `[batch, length, depth]`
    Returns:
      a tensor with shape [batch, num_experts, expert_capacity, depth]
    """
    num_rows = self._batch * self._length
    flat_inp = tf.reshape(inp, [num_rows, -1])
    return tf.gather(flat_inp, self._flat_indices)
Example #27
Source File: expert_utils.py    From fine-lm with MIT License 5 votes vote down vote up
def dispatch(self, inp):
    """Create one input Tensor for each expert.

    The `Tensor` for an expert `i` contains the slices of `inp` corresponding
    to the batch elements `b` where `gates[b, i] > 0`.

    Args:
      inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
    Returns:
      a list of `num_experts` `Tensor`s with shapes
        `[expert_batch_size_i, <extra_input_dims>]`.
    """
    # Reorder rows per self._batch_index, then cut into one part per expert.
    routed = tf.gather(inp, self._batch_index)
    return tf.split(routed, self._part_sizes_tensor, 0,
                    num=self._num_experts)
Example #28
Source File: simulated_batch_env.py    From fine-lm with MIT License 5 votes vote down vote up
def _reset_non_empty(self, indices):
    """Reset the batch of environments.

    Resets the history buffer at `indices`. Then, after the reset has run,
    copies the last frame of every history entry into `self._observ`. Finally,
    after that assignment has run, reads the observations back.

    Args:
      indices: The batch indices of the environments to reset; defaults to all.

    Returns:
      Batch tensor of the new observations.
    """
    reset_op = self.history_buffer.reset(indices)
    with tf.control_dependencies([reset_op]):
      latest_frames = self.history_buffer.get_all_elements()[:, -1, ...]
      assign_op = self._observ.assign(latest_frames)
      with tf.control_dependencies([assign_op]):
        return tf.gather(self._observ.read_value(), indices)
Example #29
Source File: simulated_batch_env.py    From fine-lm with MIT License 5 votes vote down vote up
def reset(self, indices):
    """Overwrites the history entries at `indices` with initial observations.

    Args:
      indices: batch indices of the history entries to reset.

    Returns:
      The value of `self._history_buff`, read after the scatter update runs.
    """
    all_initial = self.get_initial_observations()
    initial_frames = tf.gather(all_initial, indices)
    update_op = tf.scatter_update(self._history_buff, indices, initial_frames)
    with tf.control_dependencies([update_op]):
      return self._history_buff.read_value()
Example #30
Source File: tf_atari_wrappers.py    From fine-lm with MIT License 5 votes vote down vote up
def _reset_non_empty(self, indices):
    """Resets the wrapped environments at `indices` and stores their encodings.

    The new observations from the wrapped batch env are encoded with
    `self.autoencoder_model` and scattered into `self._observ`. Once that
    update has run, the entries of `self.observ` at `indices` are returned.
    """
    # AUTO_REUSE lets `encode` reuse variables that already exist in the
    # current scope instead of failing on a second call.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      raw_observ = self._batch_env._reset_non_empty(indices)  # pylint: disable=protected-access
      encoded = self.autoencoder_model.encode(raw_observ)
      update_op = tf.scatter_update(self._observ, indices, encoded)
      with tf.control_dependencies([update_op]):
        return tf.gather(self.observ, indices)