Python tensorflow.compat.v1.scatter_nd() Examples

The following are 6 code examples of tensorflow.compat.v1.scatter_nd(), drawn from open-source projects; the original project and source file are noted above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1.
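Before the project examples, here is a minimal sketch of the call itself: tf.scatter_nd(indices, updates, shape) builds a zero tensor of the given shape and writes each row of updates at the position named by the matching row of indices. The values below are illustrative only.

import tensorflow.compat.v1 as tf

indices = tf.constant([[1], [3]])   # positions along the first axis
updates = tf.constant([9, 10])      # one update per index row
result = tf.scatter_nd(indices, updates, shape=[4])

with tf.Session() as sess:
    print(sess.run(result))  # [ 0  9  0 10]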
Example #1
Source File: expert_utils.py    From tensor2tensor with Apache License 2.0
def restore(self, x):
    """Add padding back to the given tensor.

    Args:
      x (tf.Tensor): of shape [dim_compressed,...]

    Returns:
      a tensor of shape [dim_origin,...] with dim_compressed <= dim_origin.
      The dim is restored from the original reference tensor.
    """
    with tf.name_scope("pad_reduce/restore"):
      x = tf.scatter_nd(
          indices=self.nonpad_ids,
          updates=x,
          shape=tf.concat([self.dim_origin, tf.shape(x)[1:]], axis=0),
      )
    return x 
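The core trick above is that tf.scatter_nd re-expands a compacted tensor by writing its rows back at their original positions; every row not listed in nonpad_ids comes back as zeros, i.e. as padding. A minimal standalone sketch of the same pattern (names and values are illustrative, not tensor2tensor's API):

import tensorflow.compat.v1 as tf

nonpad_ids = tf.constant([[0], [2]])      # rows that survived pad removal
compressed = tf.constant([[1., 2.],
                          [3., 4.]])      # shape [dim_compressed, depth]
dim_origin = tf.constant([4])             # original leading dimension

restored = tf.scatter_nd(
    indices=nonpad_ids,
    updates=compressed,
    shape=tf.concat([dim_origin, tf.shape(compressed)[1:]], axis=0))
# restored == [[1. 2.], [0. 0.], [3. 4.], [0. 0.]]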
Example #2
Source File: lstm_utils.py    From magenta with Apache License 2.0
def set_final(sequence, sequence_length, values, time_major=False):
  """Sets the final values in a batch of sequences, and clears those after."""
  sequence_batch_major = (
      sequence if not time_major else tf.transpose(sequence, [1, 0, 2]))
  final_index = _get_final_index(sequence_length, time_major=False)
  # Mask keeps only the steps strictly before each sequence's final step.
  mask = tf.sequence_mask(
      tf.maximum(0, sequence_length - 1),
      maxlen=sequence_batch_major.shape[1],
      dtype=tf.float32)
  # Zero out the final step and everything after it, then scatter `values`
  # into each sequence's final position.
  sequence_batch_major = (
      tf.expand_dims(mask, axis=-1) * sequence_batch_major +
      tf.scatter_nd(final_index, values, tf.shape(sequence_batch_major)))
  if time_major:
    return tf.transpose(sequence_batch_major, [1, 0, 2])
  return sequence_batch_major 
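The two-part pattern here, mask then scatter, is worth isolating: the mask zeroes each sequence's final step and everything after it, while tf.scatter_nd writes the replacement frames at exactly those final positions. A minimal standalone sketch of the scatter half (shapes and lengths are illustrative):

import tensorflow.compat.v1 as tf

batch, time, depth = 2, 4, 3
lengths = tf.constant([2, 4])
# Pair each batch index with its final valid step: [[0, 1], [1, 3]].
final_index = tf.stack([tf.range(batch), lengths - 1], axis=1)
values = tf.fill([batch, depth], 9.0)

# Writes `values` at each sequence's final step; all other steps stay zero.
final_frames = tf.scatter_nd(final_index, values, [batch, time, depth])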
Example #3
Source File: initializers.py    From compression with Apache License 2.0
def __call__(self, shape, dtype=None, partition_info=None):
    del partition_info  # unused
    assert len(shape) > 2, shape

    # Place a single `gain`-valued tap at the spatial center of the kernel's
    # support; every other position is zero.
    support = tuple(shape[:-2]) + (1, 1)
    indices = [[s // 2 for s in support]]
    updates = tf.constant([self.gain], dtype=dtype)
    kernel = tf.scatter_nd(indices, updates, support)

    # Broadcast the center tap to an identity matrix over the channel dims,
    # so convolving with this kernel is (a scaled) identity map.
    assert shape[-2] == shape[-1], shape
    if shape[-1] != 1:
      kernel *= tf.eye(shape[-1], dtype=dtype)

    return kernel 
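The result is a convolution kernel that acts as a gain-scaled identity: one tap at the spatial center, broadcast to an identity matrix across channels, so convolving an input with it under SAME padding returns the input unchanged. A hedged sketch of what it produces for a 3x3 kernel with 2 channels (values illustrative):

import tensorflow.compat.v1 as tf

# shape = [3, 3, 2, 2]: support is [3, 3, 1, 1], center index [1, 1, 0, 0].
kernel = tf.scatter_nd([[1, 1, 0, 0]], tf.constant([1.0]), [3, 3, 1, 1])
kernel *= tf.eye(2)  # broadcasts to [3, 3, 2, 2]; the center tap becomes eye(2)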
Example #4
Source File: center_net_meta_arch.py    From models with Apache License 2.0
def _pad_to_full_keypoint_dim(keypoint_coords, keypoint_scores, keypoint_inds,
                              num_total_keypoints):
  """Scatter keypoint elements into tensors with full keypoints dimension.

  Args:
    keypoint_coords: a [batch_size, num_instances, num_keypoints, 2] float32
      tensor.
    keypoint_scores: a [batch_size, num_instances, num_keypoints] float32
      tensor.
    keypoint_inds: a list of integers that indicate the keypoint indices for
      this specific keypoint class. These indices are used to scatter into
      tensors that have a `num_total_keypoints` dimension.
    num_total_keypoints: The total number of keypoints that this model predicts.

  Returns:
    A tuple with
    keypoint_coords_padded: a
      [batch_size, num_instances, num_total_keypoints, 2] float32 tensor.
    keypoint_scores_padded: a [batch_size, num_instances, num_total_keypoints]
      float32 tensor.
  """
  batch_size, num_instances, _, _ = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_coords))
  # tf.scatter_nd writes whole slices along leading axes, so move the
  # keypoint axis to the front, scatter, then transpose back below.
  kpt_coords_transposed = tf.transpose(keypoint_coords, [2, 0, 1, 3])
  kpt_scores_transposed = tf.transpose(keypoint_scores, [2, 0, 1])
  kpt_inds_tensor = tf.expand_dims(keypoint_inds, axis=-1)
  kpt_coords_scattered = tf.scatter_nd(
      indices=kpt_inds_tensor,
      updates=kpt_coords_transposed,
      shape=[num_total_keypoints, batch_size, num_instances, 2])
  kpt_scores_scattered = tf.scatter_nd(
      indices=kpt_inds_tensor,
      updates=kpt_scores_transposed,
      shape=[num_total_keypoints, batch_size, num_instances])
  keypoint_coords_padded = tf.transpose(kpt_coords_scattered, [1, 2, 0, 3])
  keypoint_scores_padded = tf.transpose(kpt_scores_scattered, [1, 2, 0])
  return keypoint_coords_padded, keypoint_scores_padded 
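Both this helper and the next one rely on the same idiom: tf.scatter_nd can only write whole slices along the leading axes of the output, so the axis being padded is first transposed to the front, scattered, and then transposed back. A minimal standalone sketch with the keypoint-scores shapes (sizes illustrative):

import tensorflow.compat.v1 as tf

scores = tf.ones([2, 4, 3])                   # [batch, instances, 3 of 6 keypoints]
keypoint_inds = tf.constant([[0], [2], [5]])  # target slots among the 6

scores_t = tf.transpose(scores, [2, 0, 1])    # keypoint axis to the front
scattered = tf.scatter_nd(keypoint_inds, scores_t, [6, 2, 4])
padded = tf.transpose(scattered, [1, 2, 0])   # back to [batch, instances, 6]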
Example #5
Source File: center_net_meta_arch.py    From models with Apache License 2.0
def _pad_to_full_instance_dim(keypoint_coords, keypoint_scores, instance_inds,
                              max_instances):
  """Scatter keypoint elements into tensors with full instance dimension.

  Args:
    keypoint_coords: a [batch_size, num_instances, num_keypoints, 2] float32
      tensor.
    keypoint_scores: a [batch_size, num_instances, num_keypoints] float32
      tensor.
    instance_inds: a list of integers that indicate the instance indices for
      these keypoints. These indices are used to scatter into tensors
      that have a `max_instances` dimension.
    max_instances: The maximum number of instances detected by the model.

  Returns:
    A tuple with
    keypoint_coords_padded: a [batch_size, max_instances, num_keypoints, 2]
      float32 tensor.
    keypoint_scores_padded: a [batch_size, max_instances, num_keypoints]
      float32 tensor.
  """
  batch_size, _, num_keypoints, _ = (
      shape_utils.combined_static_and_dynamic_shape(keypoint_coords))
  # Same transpose-scatter-transpose idiom as above, now along the
  # instance axis instead of the keypoint axis.
  kpt_coords_transposed = tf.transpose(keypoint_coords, [1, 0, 2, 3])
  kpt_scores_transposed = tf.transpose(keypoint_scores, [1, 0, 2])
  instance_inds = tf.expand_dims(instance_inds, axis=-1)
  kpt_coords_scattered = tf.scatter_nd(
      indices=instance_inds,
      updates=kpt_coords_transposed,
      shape=[max_instances, batch_size, num_keypoints, 2])
  kpt_scores_scattered = tf.scatter_nd(
      indices=instance_inds,
      updates=kpt_scores_transposed,
      shape=[max_instances, batch_size, num_keypoints])
  keypoint_coords_padded = tf.transpose(kpt_coords_scattered, [1, 0, 2, 3])
  keypoint_scores_padded = tf.transpose(kpt_scores_scattered, [1, 0, 2])
  return keypoint_coords_padded, keypoint_scores_padded 
Example #6
Source File: seq2seq.py    From magenta with Apache License 2.0
def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with tf.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                       [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledOutputTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))
      sample_ids = tf.cast(sample_ids, tf.bool)

      def maybe_sample():
        """Perform scheduled sampling."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Concatenate outputs with auxiliary inputs, if they exist."""
          if self._auxiliary_input_tas is None:
            return outputs_

          next_time = time + 1
          auxiliary_inputs = tf.nest.map_structure(
              lambda ta: ta.read(next_time), self._auxiliary_input_tas)
          if indices is not None:
            auxiliary_inputs = tf.gather_nd(auxiliary_inputs, indices)
          return tf.nest.map_structure(
              lambda x, y: tf.concat((x, y), -1),
              outputs_, auxiliary_inputs)

        if self._next_inputs_fn is None:
          return tf.where(
              sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
              base_next_inputs)

        where_sampling = tf.cast(
            tf.where(sample_ids), tf.int32)
        where_not_sampling = tf.cast(
            tf.where(tf.logical_not(sample_ids)), tf.int32)
        outputs_sampling = tf.gather_nd(outputs, where_sampling)
        inputs_not_sampling = tf.gather_nd(base_next_inputs,
                                           where_not_sampling)
        sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
            self._next_inputs_fn(outputs_sampling), where_sampling)

        base_shape = tf.shape(base_next_inputs)
        return (tf.scatter_nd(indices=where_sampling,
                              updates=sampled_next_inputs,
                              shape=base_shape)
                + tf.scatter_nd(indices=where_not_sampling,
                                updates=inputs_not_sampling,
                                shape=base_shape))

      all_finished = tf.reduce_all(finished)
      no_samples = tf.logical_not(tf.reduce_any(sample_ids))
      next_inputs = tf.cond(
          tf.logical_or(all_finished, no_samples),
          lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)
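The closing pair of scatter_nd calls is a general select-and-merge idiom: gather rows from two sources at disjoint index sets, scatter both back into the same output shape, and add, so every row of the result is written by exactly one branch. A minimal standalone sketch (values illustrative):

import tensorflow.compat.v1 as tf

sample_ids = tf.constant([True, False, True, False])
a = tf.constant([[1.], [2.], [3.], [4.]])   # "sampled" branch
b = tf.constant([[5.], [6.], [7.], [8.]])   # fallback branch

where_a = tf.cast(tf.where(sample_ids), tf.int32)
where_b = tf.cast(tf.where(tf.logical_not(sample_ids)), tf.int32)

# Disjoint index sets: the two scatters never overlap, so adding merges them.
merged = (tf.scatter_nd(where_a, tf.gather_nd(a, where_a), tf.shape(a)) +
          tf.scatter_nd(where_b, tf.gather_nd(b, where_b), tf.shape(a)))
# merged == [[1.], [6.], [3.], [8.]]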