Python tensorflow.tensor_scatter_nd_update() Examples

The following are 16 code examples of tensorflow.tensor_scatter_nd_update(), drawn from open-source projects. Each example notes its original project and source file.
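Before the project examples, a quick refresher: tf.tensor_scatter_nd_update(tensor, indices, updates) returns a copy of tensor with the slices at indices replaced by updates. A minimal sketch:

import tensorflow as tf

tensor = tf.zeros([4], dtype=tf.int32)
indices = tf.constant([[1], [3]])  # positions to overwrite
updates = tf.constant([5, 7])      # new values for those positions
result = tf.tensor_scatter_nd_update(tensor, indices, updates)
# result: [0, 5, 0, 7]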
Example #1
Source File: modes.py    From spektral with MIT License
def disjoint_signal_to_batch(X, I):
    """
    Converts a disjoint graph signal to batch node by zero-padding.

    :param X: Tensor, node features of shape (nodes, features).
    :param I: Tensor, graph IDs of shape `(N, )`;
    :return batch: Tensor, batched node features of shape (batch, N_max, F)
    """
    I = tf.cast(I, tf.int32)
    num_nodes = tf.math.segment_sum(tf.ones_like(I), I)
    start_index = tf.cumsum(num_nodes, exclusive=True)
    n_graphs = tf.shape(num_nodes)[0]
    max_n_nodes = tf.reduce_max(num_nodes)
    batch_n_nodes = tf.shape(I)[0]
    feature_dim = tf.shape(X)[-1]

    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(start_index, I)) + (I * max_n_nodes)
    dense = tf.zeros((n_graphs * max_n_nodes, feature_dim), dtype=X.dtype)
    dense = tf.tensor_scatter_nd_update(dense, index[..., None], X)

    batch = tf.reshape(dense, (n_graphs, max_n_nodes, feature_dim))

    return batch 
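A usage sketch for the function above (the tensors are illustrative, not taken from spektral's tests): two disjoint graphs with 2 and 1 nodes are packed into a zero-padded batch.

X = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # 3 nodes, 2 features
I = tf.constant([0, 0, 1])                       # graph ID of each node
batch = disjoint_signal_to_batch(X, I)
# batch shape (2, 2, 2); the second graph is zero-padded:
# [[[1., 2.], [3., 4.]],
#  [[5., 6.], [0., 0.]]]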
Example #2
Source File: delayedmodels.py    From spiking-net-tensorflow with GNU General Public License v3.0
def update_active_spikes(self, spikes):
    """ Given some spikes, add them to active spikes with the appropraite delays
    
    Parameters:
      spikes (array like): The spikes that have just occurred
      
    Returns:
      None
    """
    delays_some_hot = spikes * self.delays  # (100, 2)
    idxs = tf.where(tf.not_equal(delays_some_hot, 0))  # (num_spikes * num_neurons, 2) indices of the nonzero entries of delays_some_hot
    just_delays = tf.gather_nd(delays_some_hot, idxs)  # (num_spikes * num_neurons,) the corresponding delay values (floats)

    # adjust for variable step size and circular array
    delay_dim_idxs = tf.reshape(self.spike_arrival_step(just_delays), [-1, 1])  # (num_spikes * num_neurons, 1) the step at which each spike will arrive

    full_idxs = tf.concat([idxs, delay_dim_idxs], axis=1)  # append the arrival step as a third index column

    self.active_spikes = tf.tensor_scatter_nd_update(self.active_spikes, full_idxs, tf.ones(full_idxs.shape[0])) 
Example #3
Source File: embeddings.py    From mead-baseline with Apache License 2.0
def encode(self, x):
        """Build a simple Lookup Table and set as input `x` if it exists, or `self.x` otherwise.

        :param x: An optional input sub-graph to bind to this operation or use `self.x` if `None`
        :return: The sub-graph output
        """
        self.x = x
        e0 = tf.tensor_scatter_nd_update(
            self.W, tf.constant(Offsets.PAD, dtype=tf.int32, shape=[1, 1]), tf.zeros(shape=[1, self.dsz])
        )
        with tf.control_dependencies([e0]):
            # The ablation table (4) in https://arxiv.org/pdf/1708.02182.pdf shows this has a massive impact
            embedding_w_dropout = self.drop(self.W, training=TRAIN_FLAG())
            word_embeddings = tf.nn.embedding_lookup(embedding_w_dropout, self.x)

        return word_embeddings 
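In isolation, the trick behind e0 (zeroing the embedding row of the padding token) looks like the sketch below; W, dsz, and the PAD index 0 are illustrative stand-ins for the class attributes.

dsz = 4
W = tf.random.normal([10, dsz])                         # vocab of 10, embedding size 4
pad_row = tf.constant(0, dtype=tf.int32, shape=[1, 1])  # Offsets.PAD is typically 0
W_masked = tf.tensor_scatter_nd_update(W, pad_row, tf.zeros([1, dsz]))
# row 0 of W_masked is all zeros; every other row is unchanged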
Example #4
Source File: dataset.py    From tf2-yolo3 with Apache License 2.0
def transform_targets_for_output(y_true, grid_y, grid_x, anchor_idxs, classes):
    # y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor))
    N = tf.shape(y_true)[0]

    # y_true_out: (N, grid, grid, anchors, [x, y, w, h, obj, class])
    y_true_out = tf.zeros((N, grid_y, grid_x, tf.shape(anchor_idxs)[0], 6))

    anchor_idxs = tf.cast(anchor_idxs, tf.int32)

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2.

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_size = tf.cast(tf.stack([grid_x, grid_y], axis=-1), tf.float32)
                grid_xy = tf.cast(box_xy * grid_size, tf.int32)
                # grid[y][x][anchor] = (tx, ty, bw, bh, obj, class)
                indexes = indexes.write(idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1

    y_true_out = tf.tensor_scatter_nd_update(y_true_out, indexes.stack(), updates.stack())
    return y_true_out
Example #5
Source File: relgraphconv.py    From dgl with Apache License 2.0
def basis_message_func(self, edges):
        """Message function for basis regularizer"""
        if self.num_bases < self.num_rels:
            # generate all weights from bases
            weight = tf.reshape(self.weight, (self.num_bases,
                                              self.in_feat * self.out_feat))
            weight = tf.reshape(tf.matmul(self.w_comp, weight), (
                self.num_rels, self.in_feat, self.out_feat))
        else:
            weight = self.weight

        # calculate msg @ W_r before putting msg onto the edge
        # if src is tf.int64 we expect it to be an index select
        if edges.src['h'].dtype != tf.int64 and self.low_mem:
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = weight[etype]
                src = tf.boolean_mask(edges.src['h'], loc)
                sub_msg = tf.matmul(src, w)
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            msg = utils.bmm_maybe_select(
                edges.src['h'], weight, edges.data['type'])
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg} 
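The low-memory branch follows a general pattern: group rows by type, transform each group, and scatter the results back to their original row positions. A standalone sketch of that pattern, with a scalar multiply standing in for the per-type matmul:

feats = tf.constant([[1., 0.], [0., 1.], [1., 1.]])
types = tf.constant([0, 1, 0])
idx = tf.range(tf.shape(feats)[0])
out = tf.zeros_like(feats)
for t in tf.unique(types)[0]:
    loc = tf.equal(types, t)
    rows = tf.boolean_mask(feats, loc)               # gather this type's rows
    transformed = rows * tf.cast(t + 1, tf.float32)  # stand-in for a per-type matmul
    indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
    out = tf.tensor_scatter_nd_update(out, indices, transformed)
# out: [[1., 0.], [0., 2.], [1., 1.]]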
Example #6
Source File: relgraphconv.py    From dgl with Apache License 2.0
def bdd_message_func(self, edges):
        """Message function for block-diagonal-decomposition regularizer"""
        if ((edges.src['h'].dtype == tf.int64) and
                len(edges.src['h'].shape) == 1):
            raise TypeError(
                'Block decomposition does not allow integer ID feature.')

        # calculate msg @ W_r before putting msg onto the edge
        # if src is tf.int64 we expect it to be an index select
        if self.low_mem:
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = tf.reshape(self.weight[etype],
                               (self.num_bases, self.submat_in, self.submat_out))
                src = tf.reshape(tf.boolean_mask(edges.src['h'], loc),
                                 (-1, self.num_bases, self.submat_in))
                sub_msg = tf.einsum('abc,bcd->abd', src, w)
                sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            weight = tf.reshape(tf.gather(
                self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
            node = tf.reshape(edges.src['h'], (-1, 1, self.submat_in))
            msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg} 
Example #7
Source File: tensor.py    From dgl with Apache License 2.0
def scatter_row(data, row_index, value):
    row_index = tf.expand_dims(row_index, 1)
    return tf.tensor_scatter_nd_update(data, row_index, value) 
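A quick illustration of this one-liner (values chosen for the example):

data = tf.constant([[1., 1.], [2., 2.], [3., 3.]])
row_index = tf.constant([0, 2])
value = tf.constant([[9., 9.], [8., 8.]])
scatter_row(data, row_index, value)
# [[9., 9.], [2., 2.], [8., 8.]]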
Example #8
Source File: delayedmodels.py    From spiking-net-tensorflow with GNU General Public License v3.0
def clear_current_active_spikes(self):
    """ Remove any spikes that arrived at the current time step
    
    Parameters:
      None
      
    Returns:
      None
    """
    # Fill in any 1's with zeros
    spike_idxs = tf.where(tf.not_equal(self.active_spikes[:, :, self.get_active_spike_idx()], 0) )
    full_idxs = tf.concat([spike_idxs, tf.ones((spike_idxs.shape[0], 1), dtype=tf.int64) * self.get_active_spike_idx()], axis=1)
    self.active_spikes = tf.tensor_scatter_nd_update(self.active_spikes, full_idxs, tf.zeros(full_idxs.shape[0])) 
Example #9
Source File: dataset.py    From DirectML with MIT License
def transform_targets_for_output(y_true, grid_size, anchor_idxs):
    # y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor))
    N = tf.shape(y_true)[0]

    # y_true_out: (N, grid, grid, anchors, [x, y, w, h, obj, class])
    y_true_out = tf.zeros(
        (N, grid_size, grid_size, tf.shape(anchor_idxs)[0], 6))

    anchor_idxs = tf.cast(anchor_idxs, tf.int32)

    indexes = tf.TensorArray(tf.int32, 1, dynamic_size=True)
    updates = tf.TensorArray(tf.float32, 1, dynamic_size=True)
    idx = 0
    for i in tf.range(N):
        for j in tf.range(tf.shape(y_true)[1]):
            if tf.equal(y_true[i][j][2], 0):
                continue
            anchor_eq = tf.equal(
                anchor_idxs, tf.cast(y_true[i][j][5], tf.int32))

            if tf.reduce_any(anchor_eq):
                box = y_true[i][j][0:4]
                box_xy = (y_true[i][j][0:2] + y_true[i][j][2:4]) / 2

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_xy = tf.cast(box_xy // (1/grid_size), tf.int32)

                # grid[y][x][anchor] = (tx, ty, bw, bh, obj, class)
                indexes = indexes.write(
                    idx, [i, grid_xy[1], grid_xy[0], anchor_idx[0][0]])
                updates = updates.write(
                    idx, [box[0], box[1], box[2], box[3], 1, y_true[i][j][4]])
                idx += 1

    # tf.print(indexes.stack())
    # tf.print(updates.stack())

    return tf.tensor_scatter_nd_update(
        y_true_out, indexes.stack(), updates.stack()) 
Example #10
Source File: layers.py    From deepchem with MIT License
def call(self, inputs, training=True):
    """
    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms
    """
    atom_features = inputs[0]
    # each atom corresponds to a graph, represented by a `max_atoms*max_atoms` int32 matrix of indices
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    parents = tf.cast(inputs[1], dtype=tf.int32)
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = inputs[2]
    calculation_masks = inputs[3]

    n_atoms = tf.squeeze(inputs[4])
    graph_features = tf.zeros((self.max_atoms * self.batch_size,
                               self.max_atoms + 1, self.n_graph_feat))

    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      mask = calculation_masks[:, count]
      current_round = tf.boolean_mask(calculation_orders[:, count], mask)
      batch_atom_features = tf.gather(atom_features, current_round)

      # generating index for graph features used in the inputs
      stack1 = tf.reshape(
          tf.stack(
              [tf.boolean_mask(tf.range(n_atoms), mask)] * (self.max_atoms - 1),
              axis=1), [-1])
      stack2 = tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
      index = tf.stack([stack1, stack2], axis=1)
      # extracting graph features for parents of the target atoms, then flatten
      # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, index),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])

      # concat into the input tensor: (batch_size*max_atoms) * n_inputs
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # DAGgraph_step maps from batch_inputs to a batch of graph_features
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = _DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                     self.activation_fn, self.dropouts,
                                     training)

      # index for target atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      graph_features = tf.tensor_scatter_nd_update(graph_features, target_index,
                                                   batch_outputs)
    return batch_outputs 
Example #11
Source File: ops.py    From spektral with MIT License
def segment_top_k(x, I, ratio, top_k_var):
    """
    Returns indices to get the top K values in x segment-wise, according to
    the segments defined in I. K is not fixed, but it is defined as a ratio of
    the number of elements in each segment.
    :param x: a rank 1 Tensor;
    :param I: a rank 1 Tensor with segment IDs for x;
    :param ratio: float, ratio of elements to keep for each segment;
    :param top_k_var: a tf.Variable created without shape validation (i.e.,
    `tf.Variable(0.0, validate_shape=False)`);
    :return: a rank 1 Tensor containing the indices to get the top K values of
    each segment in x.
    """
    I = tf.cast(I, tf.int32)
    num_nodes = tf.math.segment_sum(tf.ones_like(I), I)  # Number of nodes in each graph
    cumsum = tf.cumsum(num_nodes)  # Cumulative number of nodes (A, A+B, A+B+C)
    cumsum_start = cumsum - num_nodes  # Start index of each graph
    n_graphs = tf.shape(num_nodes)[0]  # Number of graphs in batch
    max_n_nodes = tf.reduce_max(num_nodes)  # Order of biggest graph in batch
    batch_n_nodes = tf.shape(I)[0]  # Number of overall nodes in batch
    to_keep = tf.math.ceil(ratio * tf.cast(num_nodes, tf.float32))
    to_keep = tf.cast(to_keep, I.dtype)  # Nodes to keep in each graph

    index = tf.range(batch_n_nodes)
    index = (index - tf.gather(cumsum_start, I)) + (I * max_n_nodes)

    y_min = tf.reduce_min(x)
    dense_y = tf.ones((n_graphs * max_n_nodes,))
    # subtract 1 to ensure that filler values do not get picked
    dense_y = dense_y * tf.cast(y_min - 1, dense_y.dtype)
    dense_y = tf.cast(dense_y, top_k_var.dtype)
    # top_k_var is a variable with unknown shape defined elsewhere
    top_k_var.assign(dense_y)
    dense_y = tf.tensor_scatter_nd_update(top_k_var, index[..., None], tf.cast(x, top_k_var.dtype))
    dense_y = tf.reshape(dense_y, (n_graphs, max_n_nodes))

    perm = tf.argsort(dense_y, direction='DESCENDING')
    perm = perm + cumsum_start[:, None]
    perm = tf.reshape(perm, (-1,))

    to_rep = tf.tile(tf.constant([1., 0.]), (n_graphs,))
    rep_times = tf.reshape(tf.concat((to_keep[:, None], (max_n_nodes - to_keep)[:, None]), -1), (-1,))
    mask = repeat(to_rep, rep_times)

    perm = tf.boolean_mask(perm, mask)

    return perm 
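A hedged usage sketch (it assumes spektral's repeat helper is in scope alongside segment_top_k, and uses a TF2-style variable with unspecified shape; the values are illustrative):

x = tf.constant([0.1, 0.9, 0.5, 0.3, 0.7])  # node scores
I = tf.constant([0, 0, 0, 1, 1])            # two segments, of sizes 3 and 2
top_k_var = tf.Variable(0.0, shape=tf.TensorShape(None))
perm = segment_top_k(x, I, ratio=0.5, top_k_var=top_k_var)
# perm: [1, 2, 4] -- ceil(0.5 * 3) = 2 nodes from segment 0 (scores 0.9, 0.5)
# and ceil(0.5 * 2) = 1 node from segment 1 (score 0.7)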
Example #12
Source File: scatter_elements.py    From onnx-tensorflow with Apache License 2.0
def version_11(cls, node, **kwargs):
    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]
    updates = kwargs["tensor_dict"][node.inputs[2]]

    # process negative axis
    axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'ScatterElements indices are out of bounds, please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)

      # Calculate shape of the tensorflow version of indices tensor.
      sparsified_dense_idx_shape = tf_shape(updates)

      # Move on to convert ONNX indices to tensorflow indices in 2 steps:
      #
      # Step 1:
      #   What would the index tensors look like if updates are all
      #   dense? In other words, produce a coordinate tensor for updates:
      #
      #   coordinate[i, j, k ...] = [i, j, k ...]
      #   where the shape of "coordinate" tensor is same as that of updates.
      #
      # Step 2:
      #   But the coordinate tensor needs some correction because coord
      #   vector at position axis is wrong (since we assumed update is dense,
      #   but it is not at the axis specified).
      #   So we update coordinate vector tensor elements at position=axis with
      #   the sparse coordinate indices.

      idx_tensors_per_axis = tf.meshgrid(
          *list(
              map(lambda x: tf.range(x, dtype=tf.dtypes.int64),
                  sparsified_dense_idx_shape)),
          indexing='ij')
      idx_tensors_per_axis[axis] = indices
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      coordinate = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      # Now the coordinate tensor is in the shape
      # [updates.shape, updates.rank]
      # we need it to flattened into the shape:
      # [product(updates.shape), updates.rank]
      indices = tf.reshape(coordinate, [-1, tf.rank(data)])
      updates = tf.reshape(updates, [-1])

      return [tf.tensor_scatter_nd_update(data, indices, updates)] 
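The meshgrid-plus-overwrite trick generalizes beyond ONNX conversion; below is a standalone sketch of scattering along axis 1 of a 2-D tensor (scatter_along_axis1 is a hypothetical helper, not part of onnx-tensorflow):

def scatter_along_axis1(data, indices, updates):
    # Build full (row, col) coordinates: rows come from meshgrid, cols from `indices`.
    rows, _ = tf.meshgrid(tf.range(tf.shape(indices)[0]),
                          tf.range(tf.shape(indices)[1]),
                          indexing='ij')
    coords = tf.stack([rows, indices], axis=-1)
    return tf.tensor_scatter_nd_update(
        data, tf.reshape(coords, [-1, 2]), tf.reshape(updates, [-1]))

data = tf.zeros([2, 4])
indices = tf.constant([[1, 3], [0, 2]])
updates = tf.constant([[5., 6.], [7., 8.]])
scatter_along_axis1(data, indices, updates)
# [[0., 5., 0., 6.],
#  [7., 0., 8., 0.]]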
Example #13
Source File: gather_elements.py    From onnx-tensorflow with Apache License 2.0
def version_11(cls, node, **kwargs):
    # GatherElements takes two inputs data and indices of the same rank r >= 1 and an optional attribute axis that identifies
    # an axis of data (by default, the outer-most axis, that is axis 0). It is an indexing operation that produces its output by
    # indexing into the input data tensor at index positions determined by elements of the indices tensor. Its output shape is the
    # same as the shape of indices and consists of one value (gathered from the data) for each element in indices.

    axis = node.attrs.get("axis", 0)
    data = kwargs["tensor_dict"][node.inputs[0]]
    indices = kwargs["tensor_dict"][node.inputs[1]]

    # process negative axis
    axis = axis if axis >= 0 else tf.add(tf.rank(data), axis)

    # check whether any indices are out of bounds
    result = cls.chk_idx_out_of_bounds_along_axis(data, axis, indices)
    msg = 'GatherElements indices are out of bounds,'\
      ' please double check the indices and retry.'
    with tf.control_dependencies(
        [tf.compat.v1.assert_equal(result, True, message=msg)]):
      # process negative indices
      indices = cls.process_neg_idx_along_axis(data, axis, indices)

      # adapted from reference implementation in onnx/onnx/backend/test/case/node/gatherelements.py
      if axis == 0:
        axis_perm = tf.range(tf.rank(data))
        data_swaped = data
        index_swaped = indices
      else:
        axis_perm = tf.tensor_scatter_nd_update(tf.range(tf.rank(data)),
                                                tf.constant([[0], [axis]]),
                                                tf.constant([axis, 0]))
        data_swaped = tf.transpose(data, perm=axis_perm)
        index_swaped = tf.transpose(indices, perm=axis_perm)

      idx_tensors_per_axis = tf.meshgrid(*list(
          map(lambda x: tf.range(x, dtype=index_swaped.dtype),
              index_swaped.shape.as_list())),
                                        indexing='ij')
      idx_tensors_per_axis[0] = index_swaped
      dim_expanded_idx_tensors_per_axis = list(
          map(lambda x: tf.expand_dims(x, axis=-1), idx_tensors_per_axis))
      index_expanded = tf.concat(dim_expanded_idx_tensors_per_axis, axis=-1)

      gathered = tf.gather_nd(data_swaped, index_expanded)
      y = tf.transpose(gathered, perm=axis_perm)

      return [y] 
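The tensor_scatter_nd_update call above merely swaps positions 0 and axis in a range tensor to build a transpose permutation; for example, with rank 3 and axis = 2:

axis = 2
axis_perm = tf.tensor_scatter_nd_update(tf.range(3),
                                        tf.constant([[0], [axis]]),
                                        tf.constant([axis, 0]))
# axis_perm: [2, 1, 0]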
Example #14
Source File: metric_utils.py    From ULTRA with Apache License 2.0
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
    """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

    For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1, then
    the returned 2-D tensor is [[1, 3], [2, -1]]. The output_shape is inferred
    when None is provided. In this case, the shape will be dynamic and may not be
    compatible with TPU. For the TPU use case, please provide the `output_shape`
    explicitly.

    Args:
      tensor: A 1-D numeric `Tensor`.
      segments: A 1-D int `Tensor` which is the idx output from tf.unique like [0,
        0, 1, 0, 2]. See tf.unique. The segments may or may not be sorted.
      pad_value: A numeric value to pad the output `Tensor`.
      output_shape: A `Tensor` of size 2 telling the desired shape of the output
        tensor. If None, the output_shape will be inferred and not fixed at
        compilation time. When output_shape is smaller than needed, truncation will
        be applied.

    Returns:
      A 2-D Tensor.
    """
    with tf.compat.v1.name_scope(name='scatter_to_2d'):
        tensor = tf.convert_to_tensor(value=tensor)
        segments = tf.convert_to_tensor(value=segments)
        tensor.get_shape().assert_has_rank(1)
        segments.get_shape().assert_has_rank(1)
        tensor.get_shape().assert_is_compatible_with(segments.get_shape())

        # Say segments = [0, 0, 0, 1, 2, 2]. We would like to build the 2nd dim so
        # that we can use scatter_nd to distribute the value in `tensor` to 2-D. The
        # needed 2nd dim for this case is [0, 1, 2, 0, 0, 1], which is the
        # in-segment indices.
        index_2nd_dim = _in_segment_indices(segments)

        # Compute the output_shape.
        if output_shape is None:
            # Set output_shape to the inferred one.
            output_shape = [
                tf.reduce_max(input_tensor=segments) + 1,
                tf.reduce_max(input_tensor=index_2nd_dim) + 1
            ]
        else:
            # The output_shape may be smaller. We collapse the out-of-range ones into
            # indices [output_shape[0], 0] and then use tf.slice to remove extra row
            # and column after scatter.
            valid_segments = tf.less(segments, output_shape[0])
            valid_2nd_dim = tf.less(index_2nd_dim, output_shape[1])
            mask = tf.logical_and(valid_segments, valid_2nd_dim)
            segments = tf.compat.v1.where(mask, segments,
                                          output_shape[0] * tf.ones_like(segments))
            index_2nd_dim = tf.compat.v1.where(mask, index_2nd_dim,
                                               tf.zeros_like(index_2nd_dim))
        # Create the 2D Tensor. For padding, we add one extra row and column and
        # then slice them to fit the output_shape.
        nd_indices = tf.stack([segments, index_2nd_dim], axis=1)
        padding = pad_value * tf.ones(
            shape=(output_shape + tf.ones_like(output_shape)), dtype=tensor.dtype)
        tensor = tf.tensor_scatter_nd_update(padding, nd_indices, tensor)
        tensor = tf.slice(tensor, begin=[0, 0], size=output_shape)
        return tensor 
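The docstring example can be checked with plain TF ops; a minimal sketch of the same scatter-with-padding idea, skipping the extra padding row/column and the truncation handling (index_2nd_dim is hard-coded here instead of calling _in_segment_indices):

tensor = tf.constant([1, 2, 3])
segments = tf.constant([0, 1, 0])
index_2nd_dim = tf.constant([0, 0, 1])  # in-segment position of each element
nd_indices = tf.stack([segments, index_2nd_dim], axis=1)
padding = -1 * tf.ones([2, 2], dtype=tensor.dtype)
tf.tensor_scatter_nd_update(padding, nd_indices, tensor)
# [[ 1,  3],
#  [ 2, -1]]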
Example #15
Source File: utils.py    From ranking with Apache License 2.0
def scatter_to_2d(tensor, segments, pad_value, output_shape=None):
  """Scatters a flattened 1-D `tensor` to 2-D with padding based on `segments`.

  For example: tensor = [1, 2, 3], segments = [0, 1, 0] and pad_value = -1, then
  the returned 2-D tensor is [[1, 3], [2, -1]]. The output_shape is inferred
  when None is provided. In this case, the shape will be dynamic and may not be
  compatible with TPU. For the TPU use case, please provide the `output_shape`
  explicitly.

  Args:
    tensor: A 1-D numeric `Tensor`.
    segments: A 1-D int `Tensor` which is the idx output from tf.unique like [0,
      0, 1, 0, 2]. See tf.unique. The segments may or may not be sorted.
    pad_value: A numeric value to pad the output `Tensor`.
    output_shape: A `Tensor` of size 2 telling the desired shape of the output
      tensor. If None, the output_shape will be inferred and not fixed at
      compilation time. When output_shape is smaller than needed, truncation will
      be applied.

  Returns:
    A 2-D Tensor.
  """
  with tf.compat.v1.name_scope(name='scatter_to_2d'):
    tensor = tf.convert_to_tensor(value=tensor)
    segments = tf.convert_to_tensor(value=segments)
    tensor.get_shape().assert_has_rank(1)
    segments.get_shape().assert_has_rank(1)
    tensor.get_shape().assert_is_compatible_with(segments.get_shape())

    # Say segments = [0, 0, 0, 1, 2, 2]. We would like to build the 2nd dim so
    # that we can use scatter_nd to distribute the value in `tensor` to 2-D. The
    # needed 2nd dim for this case is [0, 1, 2, 0, 0, 1], which is the
    # in-segment indices.
    index_2nd_dim = _in_segment_indices(segments)

    # Compute the output_shape.
    if output_shape is None:
      # Set output_shape to the inferred one.
      output_shape = [
          tf.reduce_max(input_tensor=segments) + 1,
          tf.reduce_max(input_tensor=index_2nd_dim) + 1
      ]
    else:
      # The output_shape may be smaller. We collapse the out-of-range ones into
      # indices [output_shape[0], 0] and then use tf.slice to remove extra row
      # and column after scatter.
      valid_segments = tf.less(segments, output_shape[0])
      valid_2nd_dim = tf.less(index_2nd_dim, output_shape[1])
      mask = tf.logical_and(valid_segments, valid_2nd_dim)
      segments = tf.compat.v1.where(mask, segments,
                                    output_shape[0] * tf.ones_like(segments))
      index_2nd_dim = tf.compat.v1.where(mask, index_2nd_dim,
                                         tf.zeros_like(index_2nd_dim))
    # Create the 2D Tensor. For padding, we add one extra row and column and
    # then slice them to fit the output_shape.
    nd_indices = tf.stack([segments, index_2nd_dim], axis=1)
    padding = pad_value * tf.ones(
        shape=(output_shape + tf.ones_like(output_shape)), dtype=tensor.dtype)
    tensor = tf.tensor_scatter_nd_update(padding, nd_indices, tensor)
    tensor = tf.slice(tensor, begin=[0, 0], size=output_shape)
    return tensor 
Example #16
Source File: box_utils.py    From ssd-tf2 with MIT License
def compute_target(default_boxes, gt_boxes, gt_labels, iou_threshold=0.5):
    """ Compute regression and classification targets
    Args:
        default_boxes: tensor (num_default, 4)
                       of format (cx, cy, w, h)
        gt_boxes: tensor (num_gt, 4)
                  of format (xmin, ymin, xmax, ymax)
        gt_labels: tensor (num_gt,)
    Returns:
        gt_confs: classification targets, tensor (num_default,)
        gt_locs: regression targets, tensor (num_default, 4)
    """
    # Convert default boxes to format (xmin, ymin, xmax, ymax)
    # in order to compute overlap with gt boxes
    transformed_default_boxes = transform_center_to_corner(default_boxes)
    iou = compute_iou(transformed_default_boxes, gt_boxes)

    best_gt_iou = tf.math.reduce_max(iou, 1)
    best_gt_idx = tf.math.argmax(iou, 1)

    best_default_iou = tf.math.reduce_max(iou, 0)
    best_default_idx = tf.math.argmax(iou, 0)

    best_gt_idx = tf.tensor_scatter_nd_update(
        best_gt_idx,
        tf.expand_dims(best_default_idx, 1),
        tf.range(best_default_idx.shape[0], dtype=tf.int64))

    # Normal way: use a for loop
    # for gt_idx, default_idx in enumerate(best_default_idx):
    #     best_gt_idx = tf.tensor_scatter_nd_update(
    #         best_gt_idx,
    #         tf.expand_dims([default_idx], 1),
    #         [gt_idx])

    best_gt_iou = tf.tensor_scatter_nd_update(
        best_gt_iou,
        tf.expand_dims(best_default_idx, 1),
        tf.ones_like(best_default_idx, dtype=tf.float32))

    gt_confs = tf.gather(gt_labels, best_gt_idx)
    gt_confs = tf.where(
        tf.less(best_gt_iou, iou_threshold),
        tf.zeros_like(gt_confs),
        gt_confs)

    gt_boxes = tf.gather(gt_boxes, best_gt_idx)
    gt_locs = encode(default_boxes, gt_boxes)

    return gt_confs, gt_locs
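The first tensor_scatter_nd_update call above is the key matching trick: every ground-truth box forcibly claims its best default box, overriding the per-default IoU argmax. A tiny illustration with made-up indices:

best_gt_idx = tf.constant([0, 0, 1, 0], dtype=tf.int64)  # best gt per default box
best_default_idx = tf.constant([2, 3], dtype=tf.int64)   # best default per gt box
forced = tf.tensor_scatter_nd_update(
    best_gt_idx,
    tf.expand_dims(best_default_idx, 1),
    tf.range(2, dtype=tf.int64))
# forced: [0, 0, 0, 1] -- default 2 now matches gt 0, default 3 matches gt 1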