Python tensorflow.scatter_update() Examples

The following are 30 code examples of tensorflow.scatter_update(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: yellowfin.py    From MobileNet with Apache License 2.0 6 votes vote down vote up
def curvature_range(self):
    """Build the ops that track the curvature range.

    Writes `self._grad_norm_squared` into a fixed-width window variable,
    takes the min and max over the slots filled so far, and passes both
    through `self._moving_averager`.  The averaged values are stored in
    `self._h_min` and `self._h_max`.

    Returns:
      A single-element list holding the moving-average update op.
    """
    # Window variable with `self._curv_win_width` slots.
    window_init = np.zeros([self._curv_win_width])
    self._curv_win = tf.Variable(
        window_init, dtype=tf.float32, name="curv_win", trainable=False)
    # The newest value goes into slot (global step mod window width).
    slot = self._global_step % self._curv_win_width
    self._curv_win = tf.scatter_update(
        self._curv_win, slot, self._grad_norm_squared)
    # Iterations start from 0, so min(width, step + 1) leading entries of the
    # window are read.
    begin = tf.constant([0])
    filled = tf.minimum(
        tf.constant(self._curv_win_width), self._global_step + 1)
    size = tf.expand_dims(filled, dim=0)
    valid_window = tf.slice(self._curv_win, begin, size)
    self._h_min_t = tf.reduce_min(valid_window)
    self._h_max_t = tf.reduce_max(valid_window)

    with tf.control_dependencies([self._h_min_t, self._h_max_t]):
      avg_op = self._moving_averager.apply([self._h_min_t, self._h_max_t])
      # Read the averages only after the averaging op has been applied.
      with tf.control_dependencies([avg_op]):
        smoothed_min = self._moving_averager.average(self._h_min_t)
        self._h_min = tf.identity(smoothed_min)
        smoothed_max = self._moving_averager.average(self._h_max_t)
        self._h_max = tf.identity(smoothed_max)
    return [avg_op]
Example #2
Source File: in_graph_batch_env.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def reset(self, indices=None):
    """Reset the selected environments and store their initial state.

    Args:
      indices: Batch indices of the environments to reset; all of them when
        None.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    fresh_observ = tf.check_numerics(fresh_observ, 'observ')
    zero_reward = tf.zeros_like(indices, tf.float32)
    not_done = tf.zeros_like(indices, tf.bool)
    # The returned observation depends on all three buffer writes.
    writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, not_done),
    ]
    with tf.control_dependencies(writes):
      return tf.identity(fresh_observ)
Example #3
Source File: optimizer.py    From Parser-v3 with Apache License 2.0 6 votes vote down vote up
def sparse_moving_average(self, variable, unique_indices, accumulant, name='Accumulator', decay=.9):
    """Update a per-row moving average for the rows in `unique_indices`.

    `accumulant` is clipped to [-self.clip, self.clip] and blended into the
    selected rows of the accumulator obtained from `self.get_accumulator`.
    A second accumulator counts how many times each row has been updated and
    determines the blending factor for that row.

    Args:
      variable: Variable the accumulators are associated with; its first
        dimension sizes the iteration counter.
      unique_indices: Row indices to update (presumably without duplicates,
        as the name suggests -- verify against caller).
      accumulant: New values for the selected rows.
      name: Name passed to `self.get_accumulator`.
      decay: Python number; values below 1 select the decayed average,
        otherwise a plain running mean is used.

    Returns:
      Tuple of the updated accumulator and the updated iteration counter.
    """
    
    accumulant = tf.clip_by_value(accumulant, -self.clip, self.clip)
    first_dim = variable.get_shape().as_list()[0]
    accumulator = self.get_accumulator(name, variable)
    # Previous average of the selected rows, read before any update.
    indexed_accumulator = tf.gather(accumulator, unique_indices)
    iteration = self.get_accumulator('{}/iteration'.format(name), variable, shape=[first_dim, 1])
    # This first gather only supplies the shape/dtype for tf.ones_like below.
    indexed_iteration = tf.gather(iteration, unique_indices)
    # Increment the update count of the selected rows, then re-read it.
    iteration = tf.scatter_add(iteration, unique_indices, tf.ones_like(indexed_iteration))
    indexed_iteration = tf.gather(iteration, unique_indices)
    
    if decay < 1:
      # With t the row's update count: decay * (1 - decay^(t-1)) / (1 - decay^t),
      # which is 0 for t == 1.
      current_indexed_decay = decay * (1-decay**(indexed_iteration-1)) / (1-decay**indexed_iteration)
    else:
      # Weight (t-1)/t on the old value, i.e. a running arithmetic mean.
      current_indexed_decay = (indexed_iteration-1) / indexed_iteration
    
    # rows <- d * old, then rows += (1 - d) * new.
    accumulator = tf.scatter_update(accumulator, unique_indices, current_indexed_decay*indexed_accumulator)
    accumulator = tf.scatter_add(accumulator, unique_indices, (1-current_indexed_decay)*accumulant)
    return accumulator, iteration
  
  #============================================================= 
Example #4
Source File: utility.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Batch indices to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  # One zero row per index, with the variable's trailing dimensions.  The
  # zeros are built in the variable's own dtype because tf.zeros defaults to
  # float32, which would not match non-float variables in the scatter update.
  row_shape = [tf.shape(indices)[0]] + variables.shape[1:].as_list()
  zeros = tf.zeros(row_shape, dtype=variables.dtype.base_dtype)
  return tf.scatter_update(variables, indices, zeros)
Example #5
Source File: in_graph_batch_env.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def reset(self, indices=None):
    """Reset the selected environments and store their initial state.

    Args:
      indices: Batch indices of the environments to reset; all of them when
        None.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    fresh_observ = tf.check_numerics(fresh_observ, 'observ')
    zero_reward = tf.zeros_like(indices, tf.float32)
    not_done = tf.zeros_like(indices, tf.bool)
    # The returned observation depends on all three buffer writes.
    writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, not_done),
    ]
    with tf.control_dependencies(writes):
      return tf.identity(fresh_observ)
Example #6
Source File: utility.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def assign_nested_vars(variables, tensors, indices=None):
  """Assign tensors to matching nested tuple of variables.

  Args:
    variables: Nested tuple or list of variables to update.
    tensors: Nested tuple or list of tensors to assign.
    indices: Batch indices to assign to; default to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    # Forward `indices` so nested variables get the same partial update as a
    # top-level variable would, instead of being assigned in full.
    return tf.group(*[
        assign_nested_vars(variable, tensor, indices)
        for variable, tensor in zip(variables, tensors)])
  if indices is None:
    return variables.assign(tensors)
  return tf.scatter_update(variables, indices, tensors)
Example #7
Source File: memory.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #8
Source File: tf_atari_wrappers.py    From training_results_v0.5 with Apache License 2.0 6 votes vote down vote up
def _reset_non_empty(self, indices):
    """Reset the wrapped environments and refill the observation buffer.

    The initial frames come from the wrapped environment's
    `history_observations` attribute when it exists and is not None;
    otherwise the reset output is tiled `self.history` times along its last
    axis.

    Args:
      indices: Indices of the environments to reset.

    Returns:
      The rows of `self.observ` at `indices`, read after the buffer update.
    """
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    history_frames = getattr(self._batch_env, "history_observations", None)
    if history_frames is None:
      # Tile multiples: 1 for every axis except the last, which is repeated
      # `self.history` times.
      rank = tf.size(tf.shape(new_values))
      leading_ones = tf.ones(rank, dtype=tf.int64)[:-1]
      multiples = tf.concat([leading_ones, [self.history]], axis=0)
      initial_frames = tf.tile(new_values, multiples)
    else:
      # Using history buffer frames for initialization, if they are available.
      with tf.control_dependencies([new_values]):
        # Transpose to [batch, height, width, history, channels] and merge
        # history and channels into one dimension.
        reordered = tf.transpose(history_frames, [0, 2, 3, 1, 4])
        initial_frames = tf.reshape(
            reordered, (len(self),) + self.observ_shape)
    write_observ = tf.scatter_update(self._observ, indices, initial_frames)
    with tf.control_dependencies([write_observ]):
      return tf.gather(self.observ, indices)
Example #9
Source File: tf_atari_wrappers.py    From training_results_v0.5 with Apache License 2.0 6 votes vote down vote up
def _reset_non_empty(self, indices):
    """Reset the wrapped environments and refill the observation buffer.

    The initial frames come from the wrapped environment's
    `history_observations` attribute when it exists and is not None;
    otherwise the reset output is tiled `self.history` times along its last
    axis.

    Args:
      indices: Indices of the environments to reset.

    Returns:
      The rows of `self.observ` at `indices`, read after the buffer update.
    """
    # pylint: disable=protected-access
    new_values = self._batch_env._reset_non_empty(indices)
    # pylint: enable=protected-access
    history_frames = getattr(self._batch_env, "history_observations", None)
    if history_frames is None:
      # Tile multiples: 1 for every axis except the last, which is repeated
      # `self.history` times.
      rank = tf.size(tf.shape(new_values))
      leading_ones = tf.ones(rank, dtype=tf.int64)[:-1]
      multiples = tf.concat([leading_ones, [self.history]], axis=0)
      initial_frames = tf.tile(new_values, multiples)
    else:
      # Using history buffer frames for initialization, if they are available.
      with tf.control_dependencies([new_values]):
        # Transpose to [batch, height, width, history, channels] and merge
        # history and channels into one dimension.
        reordered = tf.transpose(history_frames, [0, 2, 3, 1, 4])
        initial_frames = tf.reshape(
            reordered, (len(self),) + self.observ_shape)
    write_observ = tf.scatter_update(self._observ, indices, initial_frames)
    with tf.control_dependencies([write_observ]):
      return tf.gather(self.observ, indices)
Example #10
Source File: memory.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #11
Source File: control_flow_ops_py_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testWhileUpdateVariable_1(self):
    """A while loop overwrites each entry of a variable via scatter_update."""
    with self.test_session():
      select = tf.Variable([3.0, 4.0, 5.0])
      n = tf.constant(0)

      def keep_going(idx):
        return tf.less(idx, 3)

      def step(idx):
        write = tf.scatter_update(select, idx, 10.0)
        next_idx = tf.add(idx, 1)
        barrier = control_flow_ops.group(write)
        # The returned counter carries a dependency on the variable write.
        return [control_flow_ops.with_dependencies([barrier], next_idx)]

      r = tf.while_loop(keep_going, step, [n], parallel_iterations=1)
      self.assertTrue(check_op_order(n.graph))
      tf.global_variables_initializer().run()
      self.assertEqual(3, r.eval())
      final_values = select.eval()
      self.assertAllClose(np.array([10.0, 10.0, 10.0]), final_values)
Example #12
Source File: control_flow_ops_py_test.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def testWhileUpdateVariable_3(self):
    """A while loop carries the scatter_update output as a loop variable."""
    with self.test_session():
      select = tf.Variable([3.0, 4.0, 5.0])
      n = tf.constant(0)

      def keep_going(idx, _):
        return tf.less(idx, 3)

      def step(idx, _):
        written = tf.scatter_update(select, idx, 10.0)
        return [tf.add(idx, 1), written]

      loop_vars = [n, tf.identity(select)]
      r = tf.while_loop(keep_going, step, loop_vars, parallel_iterations=1)
      tf.global_variables_initializer().run()
      result = r[1].eval()
    # These checks run after the session context has exited.
    self.assertTrue(check_op_order(n.graph))
    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)

  # b/24814703 
Example #13
Source File: topn.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN.

    All ops are created under a control dependency on `self.last_ops`, and
    `self.last_ops` is replaced with the ops built here.

    Args:
      ids: Indices into `self.id_to_score` to write.
      scores: Scores to store at `ids`.
    """
    with tf.control_dependencies(self.last_ops):
      # Record every score in the full id -> score table.
      scatter_op = tf.scatter_update(self.id_to_score, ids, scores)
      # Mask of incoming scores that beat the value held in sl_scores[0].
      larger_scores = tf.greater(scores, self.sl_scores[0])

      def shortlist_insert():
        # Keep only the candidates that passed the threshold test.
        larger_ids = tf.boolean_mask(tf.to_int64(ids), larger_scores)
        larger_score_values = tf.boolean_mask(scores, larger_scores)
        # presumably returns shortlist slots plus the ids/scores to place
        # there -- confirm against the top_n_insert op definition.
        shortlist_ids, new_ids, new_scores = self.ops.top_n_insert(
            self.sl_ids, self.sl_scores, larger_ids, larger_score_values)
        u1 = tf.scatter_update(self.sl_ids, shortlist_ids, new_ids)
        u2 = tf.scatter_update(self.sl_scores, shortlist_ids, new_scores)
        return tf.group(u1, u2)

      # We only need to insert into the shortlist if there are any
      # scores larger than the threshold.
      cond_op = tf.cond(
          tf.reduce_any(larger_scores), shortlist_insert, tf.no_op)
      # NOTE(review): no op is created inside this block; it only rebinds
      # self.last_ops.
      with tf.control_dependencies([cond_op]):
        self.last_ops = [scatter_op, cond_op] 
Example #14
Source File: topn.py    From deep_image_model with Apache License 2.0 6 votes vote down vote up
def remove(self, ids):
    """Remove the ids (and their associated scores) from the TopN.

    All ops are created under a control dependency on `self.last_ops`, and
    `self.last_ops` is replaced with the ops built here.

    Args:
      ids: Indices into `self.id_to_score` to remove.
    """
    with tf.control_dependencies(self.last_ops):
      # Overwrite the removed ids' scores with the tf.float32.min constant.
      scatter_op = tf.scatter_update(
          self.id_to_score,
          ids,
          tf.ones_like(
              ids, dtype=tf.float32) * tf.float32.min)
      # We assume that removed ids are almost always in the shortlist,
      # so it makes no sense to hide the Op behind a tf.cond
      shortlist_ids_to_remove, new_length = self.ops.top_n_remove(self.sl_ids,
                                                                  ids)
      # Slot 0 of sl_ids receives new_length; the removed slots receive -1.
      # NOTE(review): tf.concat is called with the axis as the first argument
      # (the pre-1.0 signature) -- confirm the TensorFlow version in use.
      u1 = tf.scatter_update(
          self.sl_ids, tf.concat(0, [[0], shortlist_ids_to_remove]),
          tf.concat(0, [new_length,
                        tf.ones_like(shortlist_ids_to_remove) * -1]))
      # Removed shortlist slots get the same tf.float32.min sentinel score.
      u2 = tf.scatter_update(
          self.sl_scores,
          shortlist_ids_to_remove,
          tf.float32.min * tf.ones_like(
              shortlist_ids_to_remove, dtype=tf.float32))
      self.last_ops = [scatter_op, u1, u2] 
Example #15
Source File: in_graph_batch_env.py    From planet with Apache License 2.0 6 votes vote down vote up
def reset(self, indices=None):
    """Reset the selected environments and store their initial state.

    Args:
      indices: Batch indices of the environments to reset; all of them when
        None.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    zero_reward = tf.zeros_like(indices, tf.float32)
    zero_done = tf.zeros_like(indices, tf.int32)
    # The returned observation depends on all three buffer writes.
    writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, tf.to_int32(zero_done)),
    ]
    with tf.control_dependencies(writes):
      return tf.identity(fresh_observ)
Example #16
Source File: memory.py    From soccer-matlab with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def replace(self, episodes, length, rows=None):
    """Overwrite whole episodes in the memory buffers.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
    if rows is None:
      rows = tf.range(self._capacity)
    assert rows.shape.ndims == 1
    capacity_ok = tf.assert_less(
        rows, self._capacity, message='capacity exceeded')
    with tf.control_dependencies([capacity_ok]):
      length_ok = tf.assert_less_equal(
          length, self._max_length, message='max length exceeded')
    # Buffer writes are created under a dependency on the length check.
    with tf.control_dependencies([length_ok]):
      buffer_writes = [
          tf.scatter_update(buffer_, rows, elements)
          for buffer_, elements in zip(self._buffers, episodes)]
    # The stored lengths are updated after every buffer has been written.
    with tf.control_dependencies(buffer_writes):
      return tf.scatter_update(self._length, rows, length)
Example #17
Source File: yellowfin.py    From MobileNet with Apache License 2.0 6 votes vote down vote up
def get_mu_tensor(self):
    """Build the tensor for `mu` from the root of a cubic polynomial.

    The coefficients [-1, 3, -(3 + const_fact), 1] are handed to `np.roots`
    (which reads them highest degree first).  The root with real part in
    (0, 1) and near-zero imaginary part is selected, and `mu` is the larger
    of that root squared and ((sqrt(dr) - 1) / (sqrt(dr) + 1))**2, where
    dr = h_max / h_min.

    Returns:
      The `mu` tensor.
    """
    const_fact = self._dist_to_opt_avg**2 * self._h_min**2 / 2 / self._grad_var
    coef = tf.Variable([-1.0, 3.0, 0.0, 1.0], dtype=tf.float32, name="cubic_solver_coef")
    # Only the entry at index 2 depends on the running statistics.
    coef = tf.scatter_update(coef, tf.constant(2), -(3 + const_fact) )        
    roots = tf.py_func(np.roots, [coef], Tout=tf.complex64, stateful=False)
    
    # filter out the correct root
    root_idx = tf.logical_and(tf.logical_and(tf.greater(tf.real(roots), tf.constant(0.0) ),
      tf.less(tf.real(roots), tf.constant(1.0) ) ), tf.less(tf.abs(tf.imag(roots) ), 1e-5) )
    # in case there are two duplicated roots satisfying the above condition
    root = tf.reshape(tf.gather(tf.gather(roots, tf.where(root_idx) ), tf.constant(0) ), shape=[] )
    # NOTE(review): the op returned by tf.assert_equal is discarded here; in
    # graph mode it presumably never runs unless attached through
    # tf.control_dependencies -- confirm whether that is intended.
    tf.assert_equal(tf.size(root), tf.constant(1) )

    dr = self._h_max / self._h_min
    mu = tf.maximum(tf.real(root)**2, ( (tf.sqrt(dr) - 1)/(tf.sqrt(dr) + 1) )**2)    
    return mu 
Example #18
Source File: memory.py    From batch-ppo with Apache License 2.0 6 votes vote down vote up
def replace(self, episodes, length, rows=None):
    """Overwrite whole episodes in the memory buffers.

    Args:
      episodes: Tuple of transition quantities with batch and time dimensions.
      length: Batch of sequence lengths.
      rows: Episodes to replace, defaults to all.

    Returns:
      Operation.
    """
    if rows is None:
      rows = tf.range(self._capacity)
    assert rows.shape.ndims == 1
    capacity_ok = tf.assert_less(
        rows, self._capacity, message='capacity exceeded')
    with tf.control_dependencies([capacity_ok]):
      length_ok = tf.assert_less_equal(
          length, self._max_length, message='max length exceeded')

    def write_rows(var, val):
      # Overwrite the selected rows of one buffer variable.
      return tf.scatter_update(var, rows, val)

    # Buffer writes are created under a dependency on the length check.
    with tf.control_dependencies([length_ok]):
      buffer_writes = tools.nested.map(
          write_rows, self._buffers, episodes, flatten=True)
    # The stored lengths are updated after every buffer has been written.
    with tf.control_dependencies(buffer_writes):
      return tf.scatter_update(self._length, rows, length)
Example #19
Source File: in_graph_batch_env.py    From batch-ppo with Apache License 2.0 6 votes vote down vote up
def reset(self, indices=None):
    """Reset the selected environments and store their initial state.

    Args:
      indices: Batch indices of the environments to reset; all of them when
        None.

    Returns:
      Batch tensor of the new observations.
    """
    if indices is None:
      indices = tf.range(len(self._batch_env))
    dtype = self._parse_dtype(self._batch_env.observation_space)
    fresh_observ = tf.py_func(
        self._batch_env.reset, [indices], dtype, name='reset')
    fresh_observ = tf.check_numerics(fresh_observ, 'observ')
    zero_reward = tf.zeros_like(indices, tf.float32)
    not_done = tf.zeros_like(indices, tf.bool)
    # The returned observation depends on all three buffer writes.
    writes = [
        tf.scatter_update(self._observ, indices, fresh_observ),
        tf.scatter_update(self._reward, indices, zero_reward),
        tf.scatter_update(self._done, indices, not_done),
    ]
    with tf.control_dependencies(writes):
      return tf.identity(fresh_observ)
Example #20
Source File: utility.py    From batch-ppo with Apache License 2.0 6 votes vote down vote up
def reinit_nested_vars(variables, indices=None):
  """Reset all variables in a nested tuple to zeros.

  Args:
    variables: Nested tuple or list of variables.
    indices: Batch indices to reset, defaults to all.

  Returns:
    Operation.
  """
  if isinstance(variables, (tuple, list)):
    return tf.group(*[
        reinit_nested_vars(variable, indices) for variable in variables])
  if indices is None:
    return variables.assign(tf.zeros_like(variables))
  # One zero row per index, with the variable's trailing dimensions.  The
  # zeros are built in the variable's own dtype because tf.zeros defaults to
  # float32, which would not match non-float variables in the scatter update.
  row_shape = [tf.shape(indices)[0]] + variables.shape[1:].as_list()
  zeros = tf.zeros(row_shape, dtype=variables.dtype.base_dtype)
  return tf.scatter_update(variables, indices, zeros)
Example #21
Source File: latent_factor.py    From openrec with Apache License 2.0 6 votes vote down vote up
def censor_l2_norm_op(self, censor_id_list=None, max_norm=1):

        """Limit the L2 norm of the selected embeddings to `max_norm`.

        Rows whose norm exceeds `max_norm` are rescaled to have norm
        `max_norm`; rows already within the limit are left unchanged.

        Parameters
        ----------
        censor_id_list: list or Tensorflow tensor
            list of embeddings to censor (indexed by ids).
            NOTE(review): the default of None is passed straight to
            tf.gather -- confirm that callers always supply the ids.
        max_norm: float, optional
            Maximum norm.

        Returns
        -------
        Tensorflow operator
            An operator for post-training execution.
        """

        selected = tf.gather(self._embedding, indices=censor_id_list)
        norm = tf.sqrt(tf.reduce_sum(tf.square(selected), axis=1, keep_dims=True))
        # Scale by max_norm / max(norm, max_norm).  Dividing by the maximum
        # alone only clips correctly when max_norm == 1; multiplying by
        # max_norm first keeps the default case numerically unchanged.
        clipped = (selected * max_norm) / tf.maximum(norm, max_norm)
        return tf.scatter_update(self._embedding, indices=censor_id_list, updates=clipped)
Example #22
Source File: model.py    From professional-services with Apache License 2.0 6 votes vote down vote up
def _update_embedding_matrix(row_indices, rows, embedding_size, tft_output,
                             vocab_name):
  """Writes the given rows into the embedding lookup table used for inference.

  Args:
    row_indices: indices of rows of the lookup table to update.
    rows: the values to update the lookup table with.
    embedding_size: the size of the embedding.
    tft_output: a TFTransformOutput object.
    vocab_name: a tft vocabulary name.

  Returns:
    A num_items x embedding_size table of the latest embeddings with the given
      rows updated.
  """
  # The table itself is obtained from the sibling helper.
  table = _get_embedding_matrix(embedding_size, tft_output, vocab_name)
  updated_table = tf.scatter_update(table, row_indices, rows)
  return updated_table
Example #23
Source File: association.py    From Deep-Association-Learning with MIT License 6 votes vote down vote up
def update_cross_anchor(cross_anchors, intra_anchors_n, intra_anchors_batch_n,
                        labels_cam, start_sign):
    """Update the cross-camera anchors of every camera.

    For each camera, the batch anchors are compared against the anchors of all
    other cameras through `cyclic_ranking`.  Where the ranking is consistent,
    the batch anchor is averaged with its rank-1 match; otherwise it is used
    as is.  The result is scattered into that camera's cross anchors at the
    rows given by `labels_cam[i]`.

    Args:
        cross_anchors: Per-camera cross-anchor variables; the entries of this
            sequence are reassigned in place.
        intra_anchors_n: Per-camera anchors.
        intra_anchors_batch_n: Per-camera anchors of the current batch.
        labels_cam: Per-camera row indices into the cross anchors.
        start_sign: Passed through to `cyclic_ranking`.

    Returns:
        `cross_anchors`, with every entry replaced by its scatter-update
        output.
    """
    for i in range(FLAGS.num_cams):
        # other_anchors: all the anchors under other cameras.  Cameras are
        # compared by value (`!=`); identity checks on ints are unreliable.
        other_anchors_n = tf.concat(
            [intra_anchors_n[x] for x in range(FLAGS.num_cams) if x != i], 0)

        consistent, rank1_anchors = cyclic_ranking(
            intra_anchors_batch_n[i], intra_anchors_n[i], other_anchors_n,
            labels_cam[i], start_sign)

        # if the consistency fulfills, update by
        # merging with the best-matched rank1 anchors in another camera
        update = tf.where(consistent, (intra_anchors_batch_n[i] + rank1_anchors) / 2, intra_anchors_batch_n[i])

        # update the associate centers under each camera
        cross_anchors[i] = tf.scatter_update(cross_anchors[i], labels_cam[i], update)

    return cross_anchors
Example #24
Source File: memory.py    From hands-detection with MIT License 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #25
Source File: tensorlfowapi.py    From SSD_tensorflow_VOC with Apache License 2.0 6 votes vote down vote up
def test_scatter_nd_2():
    """Overwrite selected rows of two local variables with tf.scatter_update.

    Returns:
        The scatter-update outputs for the label variable and the
        bounding-box variable, in that order.
    """
    gt_bboxes = tf.constant([[0, 0, 1, 2], [1, 0, 3, 4], [100, 100, 105, 102.5]])
    gt_labels = tf.constant([1, 2, 6])

    label_var = tf.Variable(
        [100, 100, 100, 100],
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])
    bbox_var = tf.Variable(
        [[100, 100, 105, 105], [2, 1, 3, 3.5], [0, 0, 10, 10], [0.5, 0.5, 0.8, 1.5]],
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        dtype=tf.float32)

    # Rows 1, 0 and 3 receive the first, second and third ground-truth entry.
    target_rows = [1, 0, 3]

    updated_labels = tf.scatter_update(label_var, target_rows, gt_labels)
    updated_bboxes = tf.scatter_update(bbox_var, target_rows, gt_bboxes)
    return updated_labels, updated_bboxes
Example #26
Source File: memory.py    From object_detection_kitti with Apache License 2.0 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #27
Source File: memory.py    From object_detection_with_tensorflow with MIT License 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #28
Source File: memory.py    From HumanRecognition with MIT License 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #29
Source File: memory.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)
Example #30
Source File: memory.py    From models with Apache License 2.0 6 votes vote down vote up
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Build the grouped op that writes a batch of updates into memory.

    Every age entry is first incremented by one; the ages at `upd_idxs` are
    then set to zero, and the keys and values at `upd_idxs` are overwritten
    with `upd_keys` and `upd_vals`.  When `use_recent_idx` is true,
    `self.recent_idx` is also updated at `intended_output` with `upd_idxs`.

    Returns:
      A single op grouping all of the updates above.
    """
    age_step = tf.ones([self.memory_size], dtype=tf.float32)
    age_all = self.mem_age.assign_add(age_step)
    # Zero the ages of the written slots only after the global increment.
    with tf.control_dependencies([age_all]):
      zero_age = tf.zeros([batch_size], dtype=tf.float32)
      age_reset = tf.scatter_update(self.mem_age, upd_idxs, zero_age)

    key_write = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    val_write = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    # An empty group stands in when the recent-index table is not used.
    recent_write = (
        tf.scatter_update(self.recent_idx, intended_output, upd_idxs)
        if use_recent_idx else tf.group())

    return tf.group(age_reset, key_write, val_write, recent_write)