Python torch.argmin() Examples

The following are 30 code examples of torch.argmin(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
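As a quick refresher: torch.argmin(input) returns the index of the minimum over the flattened tensor, while torch.argmin(input, dim) reduces along a single dimension. A minimal illustration:

import torch

x = torch.tensor([[4.0, 1.0, 3.0],
                  [2.0, 0.5, 6.0]])

print(torch.argmin(x))         # tensor(4): index into the flattened tensor
print(torch.argmin(x, dim=0))  # tensor([1, 1, 0]): row of each column's minimum
print(torch.argmin(x, dim=1))  # tensor([1, 1]): column of each row's minimum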
Example #1
Source File: min_norm_solver.py    From Hydra with MIT License
def forward(self, grammian):
        """Planar case solver, when Vi lies on the same plane

        Args:
          grammian: grammian matrix G[i, j] = [<Vi, Vj>], G is a nxn tensor

        Returns:
          sol: coefficients c = [c1, ... cn] that solves the min-norm problem
        """
        vivj = grammian[self.ii_triu, self.jj_triu]
        vivi = grammian[self.ii_triu, self.ii_triu]
        vjvj = grammian[self.jj_triu, self.jj_triu]

        gamma, cost = self.line_solver_vectorized(vivi, vivj, vjvj)
        offset = torch.argmin(cost)
        i_min, j_min = self.i_triu[offset], self.j_triu[offset]
        sol = torch.zeros(self.n, device=grammian.device)
        sol[i_min], sol[j_min] = gamma[offset], 1. - gamma[offset]
        return sol 
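The buffers self.ii_triu and self.jj_triu are defined elsewhere in the source; presumably they enumerate the strict upper triangle of the Gram matrix. A sketch of how such index pairs can be precomputed (the names and n are assumptions for illustration):

import torch

n = 4  # number of vectors Vi (hypothetical)
ii_triu, jj_triu = torch.triu_indices(n, n, offset=1)  # all index pairs with i < j
# grammian[ii_triu, jj_triu] then gathers <Vi, Vj> for every unordered pair, and
# torch.argmin(cost) picks the pair whose pairwise min-norm cost is smallest.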
Example #2
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        centroids = F.dropout(self.centroids, p=0.3)
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        print('Distances:', distances[:3])
        if not greedy:
            resp = - .5 * self.tau * distances - self.reduce_dim / 2 * math.log(2 * math.pi * self.tau) + torch.log(self.prior)
        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
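The distances expression above is the standard expansion ||p - c||^2 = ||p||^2 + ||c||^2 - 2<p, c>, computed for all point/centroid pairs at once. A small sanity check against torch.cdist (shapes are illustrative):

import torch

points = torch.randn(8, 5)     # (T*B, dim)
centroids = torch.randn(3, 5)  # (ne, dim)

expanded = (torch.sum(points**2, dim=1, keepdim=True) +
            torch.sum(centroids**2, dim=1, keepdim=True).t() -
            2 * torch.matmul(points, centroids.t()))

print(torch.allclose(expanded, torch.cdist(points, centroids) ** 2, atol=1e-5))  # True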
Example #3
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        # points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        print('Distances:', distances[:3])
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #4
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #5
Source File: qv.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        # points = points.data  # the only diff from 16
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #6
Source File: qv_v1.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        points = points.data
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #7
Source File: loss.py    From learnable-triangulation-pytorch with MIT License
def forward(self, coord_volumes_batch, volumes_batch_pred, keypoints_gt, keypoints_binary_validity):
        loss = 0.0
        n_losses = 0

        batch_size = volumes_batch_pred.shape[0]
        for batch_i in range(batch_size):
            coord_volume = coord_volumes_batch[batch_i]
            keypoints_gt_i = keypoints_gt[batch_i]

            coord_volume_unsq = coord_volume.unsqueeze(0)
            keypoints_gt_i_unsq = keypoints_gt_i.unsqueeze(1).unsqueeze(1).unsqueeze(1)

            dists = torch.sqrt(((coord_volume_unsq - keypoints_gt_i_unsq) ** 2).sum(-1))
            dists = dists.view(dists.shape[0], -1)

            min_indexes = torch.argmin(dists, dim=-1).detach().cpu().numpy()
            min_indexes = np.stack(np.unravel_index(min_indexes, volumes_batch_pred.shape[-3:]), axis=1)

            for joint_i, index in enumerate(min_indexes):
                validity = keypoints_binary_validity[batch_i, joint_i]
                loss += validity[0] * (-torch.log(volumes_batch_pred[batch_i, joint_i, index[0], index[1], index[2]] + 1e-6))
                n_losses += 1


        return loss / n_losses 
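Since torch.argmin only returns flat indices, the code above round-trips through NumPy's np.unravel_index to recover per-joint (x, y, z) coordinates in the volume. The pattern in isolation (shapes are illustrative):

import numpy as np
import torch

volume = torch.randn(17, 4, 4, 4)        # (n_joints, x, y, z)
flat = volume.view(volume.shape[0], -1)  # (n_joints, 64)
flat_idx = torch.argmin(flat, dim=-1).cpu().numpy()
coords = np.stack(np.unravel_index(flat_idx, volume.shape[-3:]), axis=1)
print(coords.shape)  # (17, 3): one (x, y, z) index triple per joint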
Example #8
Source File: nn.py    From Distributional-Signatures with MIT License
def forward(self, XS, YS, XQ, YQ):
        '''
            @param XS (support x): support_size x ebd_dim
            @param YS (support y): support_size
            @param XQ (query x): query_size x ebd_dim
            @param YQ (query y): query_size

            @return acc
            @return None (a placeholder for loss)
        '''
        if self.args.nn_distance == 'l2':
            dist = self._compute_l2(XS, XQ)
        elif self.args.nn_distance == 'cos':
            dist = self._compute_cos(XS, XQ)
        else:
            raise ValueError("nn_distance can only be l2 or cos.")

        # 1-NearestNeighbour
        nn_idx = torch.argmin(dist, dim=1)
        pred = YS[nn_idx]

        acc = torch.mean((pred == YQ).float()).item()

        return acc, None 
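The 1-nearest-neighbour step in isolation, with a squared-L2 distance matrix standing in for _compute_l2 (a sketch; shapes follow the docstring above):

import torch

XS = torch.randn(10, 16)            # support embeddings: support_size x ebd_dim
YS = torch.randint(0, 3, (10,))     # support labels
XQ = torch.randn(5, 16)             # query embeddings: query_size x ebd_dim

dist = torch.cdist(XQ, XS) ** 2     # (query_size, support_size)
nn_idx = torch.argmin(dist, dim=1)  # index of the nearest support example
pred = YS[nn_idx]                   # predicted label per query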
Example #9
Source File: qv_v1.py    From attn2d with MIT License
def assign(self, points, distance='euclid', greedy=False):
        # points = points.data  # the only diff from 16
        centroids = self.centroids
        if distance == 'cosine':
            # nearest neighbor among the centroids (cosine distance):
            points = F.normalize(points, dim=-1)
            centroids = F.normalize(centroids, dim=-1)

        distances = (torch.sum(points**2, dim=1, keepdim=True) +
                     torch.sum(centroids**2, dim=1, keepdim=True).t() -
                     2 * torch.matmul(points, centroids.t()))  # T*B, e
        if not greedy:
            logits = - distances
            resp = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
            # batch_counts = resp.sum(dim=0).view(-1).data

        else:
            # Greedy, non-differentiable responsibilities:
            indices = torch.argmin(distances, dim=-1)  # T*B
            resp = torch.zeros(points.size(0), self.ne).type_as(points)
            resp.scatter_(1, indices.unsqueeze(1), 1)
        return resp 
Example #10
Source File: fast_gaussian.py    From torchsupport with MIT License
def kmeans(input, n_clusters=16, tol=1e-6):
  """
  TODO: check correctness
  """
  indices = torch.from_numpy(np.random.choice(input.size(-1), n_clusters)).long()  # indexing requires an integer tensor
  values = input[:, :, indices]

  while True:
    dist = func.pairwise_distance(
      input.unsqueeze(2).expand(-1, -1, values.size(2), input.size(2)).reshape(
        input.size(0), input.size(1), input.size(2) * values.size(2)),
      values.unsqueeze(3).expand(-1, -1, values.size(2), input.size(2)).reshape(
        input.size(0), input.size(1), input.size(2) * values.size(2))
    )
    choice_cluster = torch.argmin(dist, dim=1)
    old_values = values
    values = input[choice_cluster.nonzero()]
    shift = (old_values - values).norm(dim=1)
    if shift.max() ** 2 < tol:
      break

  return values 
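The function keeps its original "TODO: check correctness" caveat: in particular, the centroid update via choice_cluster.nonzero() does not average the members of each cluster. For comparison, a minimal, correct k-means loop over a (n_points, dim) tensor (a sketch, not part of torchsupport):

import torch

def kmeans_simple(x, n_clusters=16, n_iters=50):
    # x: (n_points, dim)
    centroids = x[torch.randperm(x.size(0))[:n_clusters]].clone()
    for _ in range(n_iters):
        dist = torch.cdist(x, centroids)    # (n_points, n_clusters)
        assign = torch.argmin(dist, dim=1)  # E-step: nearest centroid per point
        for k in range(n_clusters):         # M-step: recompute centroids
            members = x[assign == k]
            if members.numel() > 0:
                centroids[k] = members.mean(dim=0)
    return centroids, assign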
Example #11
Source File: tm_util.py    From openprotein with MIT License
def calculate_partitions(partitions_count, cluster_partitions, types):
    partition_distribution = torch.ones((partitions_count,
                                         len(torch.unique(types))),
                                        dtype=torch.long)
    partition_assignments = torch.zeros(cluster_partitions.shape[0],
                                        dtype=torch.long)

    for i in torch.unique(cluster_partitions):
        cluster_positions = (cluster_partitions == i).nonzero()
        cluster_types = types[cluster_positions]
        unique_types_in_cluster, type_count = torch.unique(cluster_types, return_counts=True)
        tmp_distribution = partition_distribution.clone()
        tmp_distribution[:, unique_types_in_cluster] += type_count
        relative_distribution = partition_distribution.double() / tmp_distribution.double()
        min_relative_distribution_group = torch.argmin(torch.sum(relative_distribution, dim=1))
        partition_distribution[min_relative_distribution_group,
                               unique_types_in_cluster] += type_count
        partition_assignments[cluster_positions] = min_relative_distribution_group

    write_out("Loaded data into the following partitions")
    write_out("[[  TM  SP+TM  SP Glob]")
    write_out(partition_distribution - torch.ones(partition_distribution.shape,
                                                  dtype=torch.long))
    return partition_assignments 
Example #12
Source File: operations.py    From NNEF-Tools with Apache License 2.0
def _nnef_argminmax_reduce(input, axes, argmin=False):
    # type:(torch.Tensor, List[int], bool)->torch.Tensor
    if len(axes) == 1:
        return _nnef_generic_reduce(input=input, axes=axes, f=torch.argmin if argmin else torch.argmax)
    else:
        axes = sorted(axes)
        consecutive_axes = list(range(axes[0], axes[0] + len(axes)))
        if axes == consecutive_axes:
            reshaped = nnef_reshape(input,
                                    shape=(list(input.shape)[:axes[0]]
                                           + [-1]
                                           + list(input.shape[axes[0] + len(axes):])))
            reduced = _nnef_generic_reduce(input=reshaped, axes=[axes[0]], f=torch.argmin if argmin else torch.argmax)
            reshaped = nnef_reshape(reduced, shape=list(dim if axis not in axes else 1
                                                        for axis, dim in enumerate(input.shape)))
            return reshaped
        else:
            raise utils.NNEFToolsException(
                "{} is only implemented for consecutive axes.".format("argmin_reduce" if argmin else "argmax_reduce")) 
Example #13
Source File: em.py    From attn2d with MIT License
def resolve_empty_clusters(self):
        """
        If one cluster is empty, the most populated cluster is split into
        two clusters by shifting the respective centroids. This is done
        iteratively, up to a fixed number of attempts (max_tentatives).
        """

        # empty clusters
        counts = Counter(map(lambda x: x.item(), self.assignments))
        empty_clusters = set(range(self.n_centroids)) - set(counts.keys())
        n_empty_clusters = len(empty_clusters)

        tentatives = 0
        while len(empty_clusters) > 0:
            # given an empty cluster, find most populated cluster and split it into two
            k = random.choice(list(empty_clusters))
            m = counts.most_common(1)[0][0]
            e = torch.randn_like(self.centroids[m]) * self.eps
            self.centroids[k] = self.centroids[m].clone()
            self.centroids[k] += e
            self.centroids[m] -= e

            # recompute assignments
            distances = self.compute_distances()  # (n_centroids x out_features)
            self.assignments = torch.argmin(distances, dim=0)  # (out_features)

            # check for empty clusters
            counts = Counter(map(lambda x: x.item(), self.assignments))
            empty_clusters = set(range(self.n_centroids)) - set(counts.keys())

            # increment tentatives
            if tentatives == self.max_tentatives:
                logging.info(
                    f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
                )
                raise EmptyClusterResolveError
            tentatives += 1

        return n_empty_clusters 
Example #14
Source File: match_segmentation.py    From PlanarReconstruction with MIT License
def forward(self, segmentation, prob, gt_instance, gt_plane_num):
        """
        greedy matching:
        match each predicted segmentation channel with its closest ground-truth instance
        :param segmentation: tensor with size (N, K)
        :param prob: tensor with size (N, 1)
        :param gt_instance: tensor with size (21, h, w)
        :param gt_plane_num: int
        :return: a (K, 1) long tensor indicating the closest ground-truth instance id, starting from 0
        """

        n, k = segmentation.size()
        _, h, w = gt_instance.size()
        assert (prob.size(0) == n and h*w == n)
        
        # ignore the non-planar region
        gt_instance = gt_instance[:gt_plane_num, :, :].view(1, -1, h*w)     # (1, gt_plane_num, h*w)

        segmentation = segmentation.t().view(k, 1, h*w)                     # (k, 1, h*w)

        # calculate instance wise cross entropy matrix (K, gt_plane_num)
        gt_instance = gt_instance.type(torch.float32)

        ce_loss = - (gt_instance * torch.log(segmentation + 1e-6) +
            (1-gt_instance) * torch.log(1-segmentation + 1e-6))             # (k, gt_plane_num, h*w)

        ce_loss = torch.mean(ce_loss, dim=2)                                # (k, gt_plane_num)
        
        matching = torch.argmin(ce_loss, dim=1, keepdim=True)

        return matching 
Example #15
Source File: em.py    From attn2d with MIT License
def step(self, i):
        """
        There are two standard steps for each iteration: expectation (E) and
        minimization (M). The E-step (assignment) is performed with an exhaustive
        search and the M-step (centroid computation) is performed with
        the exact solution.

        Args:
            - i: step number

        Remarks:
            - The E-step heavily uses PyTorch broadcasting to speed up computations
              and reduce the memory overhead
        """

        # assignments (E-step)
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)
        n_empty_clusters = self.resolve_empty_clusters()

        # centroids (M-step)
        for k in range(self.n_centroids):
            W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
            self.centroids[k] = W_k.mean(dim=1)  # (in_features)

        # book-keeping
        obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
        self.objective.append(obj)
        if self.verbose:
            logging.info(
                f"Iteration: {i},\t"
                f"objective: {obj:.6f},\t"
                f"resolved empty clusters: {n_empty_clusters}"
            ) 
Example #16
Source File: em.py    From attn2d with MIT License
def assign(self):
        """
        Assigns each column of W to its closest centroid, thus essentially
        performing the E-step in train().

        Remarks:
            - The function must be called after train() or after loading
              centroids using self.load(), otherwise it will return empty tensors
        """

        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features) 
Example #17
Source File: losses.py    From geoseg with MIT License
def unravel_index(tensor, cols):
        """
        args:
            tensor : 2D tensor, [nb, rows*cols]
            cols : int
        return: 2D tensor of shape (nb, 2), each row holding [rowIndex, colIndex]
        """
        index = torch.argmin(tensor, dim=1).view(-1, 1)
        rIndex = torch.div(index, cols, rounding_mode='floor')  # integer division; `/` returns floats on recent PyTorch
        cIndex = index % cols
        minRC = torch.cat([rIndex, cIndex], dim=1)
        # print("minRC", minRC.shape, minRC)
        return minRC 
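A quick check of the recovery, assuming the helper is callable as a plain function (values are illustrative):

import torch

t = torch.tensor([[3.0, 0.5, 2.0, 1.0, 4.0, 6.0],
                  [9.0, 8.0, 7.0, 6.0, 5.0, 0.1]])  # nb=2, rows*cols=6
print(unravel_index(t, cols=3))
# tensor([[0, 1],
#         [1, 2]]): per-row (rowIndex, colIndex) of the minimum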
Example #18
Source File: time_profilers.py    From ignite with BSD 3-Clause "New" or "Revised" License
def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out) 
Example #19
Source File: operations.py    From NNEF-Tools with Apache License 2.0
def nnef_argmax_reduce(input, axes):
    # type:(torch.Tensor, List[int])->torch.Tensor
    return _nnef_argminmax_reduce(input, axes, argmin=False) 
Example #20
Source File: operations.py    From NNEF-Tools with Apache License 2.0
def nnef_argmin_reduce(input, axes):
    # type:(torch.Tensor, List[int])->torch.Tensor
    return _nnef_argminmax_reduce(input, axes, argmin=True) 
Example #21
Source File: filters.py    From BCAI_kaggle_CHAMPS with MIT License
def subgraph_filter(x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, args):
    D = sqdist(x_atom_pos[:,:,:3], x_atom_pos[:,:,:3])
    x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle = \
        x_atom.clone().detach(), x_atom_pos.clone().detach(), x_bond.clone().detach(), x_bond_dist.clone().detach(), x_triplet.clone().detach(), x_triplet_angle.clone().detach()
    bsz = x_atom.shape[0]
    bonds_mask = torch.ones(bsz, x_bond.shape[1], 1).to(x_atom.device)
    for mol_id in range(bsz):
        if np.random.uniform(0,1) > args.cutout:
            continue
        assert not args.use_quad, "Quads are NOT cut out yet"
        atom_dists = D[mol_id]
        atoms = x_atom[mol_id, :, 0]
        n_valid_atoms = (atoms > 0).sum().item()
        if n_valid_atoms < 10:
            continue
        idx_to_drop = np.random.randint(n_valid_atoms-1)
        dist_row = atom_dists[idx_to_drop]
        neighbor_to_drop = torch.argmin((dist_row[dist_row>0])[:n_valid_atoms-1]).item()
        if neighbor_to_drop >= idx_to_drop: 
            neighbor_to_drop += 1
        x_atom[mol_id, idx_to_drop] = 0
        x_atom[mol_id, neighbor_to_drop] = 0
        x_atom_pos[mol_id, idx_to_drop] = 0
        x_atom_pos[mol_id, neighbor_to_drop] = 0
        bond_pos_to_drop = (x_bond[mol_id, :, 3] == idx_to_drop) | (x_bond[mol_id, :, 3] == neighbor_to_drop) \
                         | (x_bond[mol_id, :, 4] == idx_to_drop) | (x_bond[mol_id, :, 4] == neighbor_to_drop)
        trip_pos_to_drop = (x_triplet[mol_id, :, 2] == idx_to_drop) | (x_triplet[mol_id, :, 2] == neighbor_to_drop) \
                         | (x_triplet[mol_id, :, 3] == idx_to_drop) | (x_triplet[mol_id, :, 3] == neighbor_to_drop) \
                         | (x_triplet[mol_id, :, 4] == idx_to_drop) | (x_triplet[mol_id, :, 4] == neighbor_to_drop)
        x_bond[mol_id, bond_pos_to_drop] = 0
        x_bond_dist[mol_id, bond_pos_to_drop] = 0
        bonds_mask[mol_id, bond_pos_to_drop] = 0
        x_triplet[mol_id, trip_pos_to_drop] = 0
        x_triplet_angle[mol_id, trip_pos_to_drop] = 0
    return x_atom, x_atom_pos, x_bond, x_bond_dist, x_triplet, x_triplet_angle, bonds_mask 
Example #22
Source File: utils.py    From gpytorch with MIT License
def least_used_cuda_device() -> Generator:
    """Contextmanager for automatically selecting the cuda device
    with the least allocated memory"""
    mem_allocs = get_cuda_max_memory_allocations()
    least_used_device = torch.argmin(mem_allocs).item()
    with torch.cuda.device(least_used_device):
        yield 
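Since this is a generator function, the source presumably decorates it with contextlib.contextmanager; usage would then look like this (a sketch, assuming at least one visible CUDA device and that get_cuda_max_memory_allocations returns a tensor of per-device allocations):

with least_used_cuda_device():
    x = torch.randn(1024, 1024, device="cuda")  # allocated on the least-used GPU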
Example #23
Source File: elastic_quant_connect.py    From Pytorch_Quantize_impls with MIT License
def _proj_val(x, set):
    """
    Compute the elementwise projection of x onto the given set of values.

    :param x: Input PyTorch tensor.
    :param set: 1D PyTorch tensor of candidate values used for the projection.
    """
    x = x.repeat((set.size()[0],)+(1,)*len(x.size()))
    x = x.permute(*(tuple(range(len(x.size())))[1:]  +(0,) ))
    x = torch.abs(x-set)
    x = torch.argmin(x, dim=len(x.size())-1, keepdim=False)
    return set[x] 
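In a quantization setting, _proj_val snaps every element of x to its nearest value in set (which, note, shadows the Python builtin). For instance, projecting onto the ternary set {-1, 0, 1} (illustrative values):

import torch

levels = torch.tensor([-1.0, 0.0, 1.0])
x = torch.tensor([[0.2, -0.7],
                  [1.4, 0.05]])
print(_proj_val(x, levels))
# tensor([[ 0., -1.],
#         [ 1.,  0.]])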
Example #24
Source File: min_norm_solver.py    From Hydra with MIT License
def forward(self, vecs):
        """Computes grammian matrix G_{i,j} = (<v_i, v_j>)_{i,j}.
        """
        if self.n_tasks == 1:
            return vecs[0]
        if self.n_tasks == 2:
            v1v1 = torch.dot(vecs[0], vecs[0])
            v1v2 = torch.dot(vecs[0], vecs[1])
            v2v2 = torch.dot(vecs[1], vecs[1])
            gamma = self.line_solver(v1v1, v1v2, v2v2)
            return gamma * vecs[0] + (1. - gamma) * vecs[1]

        self.sol.fill_(1. / self.n)
        self.new_sol.copy_(self.sol)
        torch.mm(vecs, vecs.t(), out=self.grammian)

        for iter_count in range(self.MAX_ITER):
            gram_dot_sol = torch.mv(self.grammian, self.sol)
            t_iter = torch.argmin(gram_dot_sol)

            v1v1 = torch.dot(self.sol, gram_dot_sol)
            v1v2 = torch.dot(self.sol, self.grammian[:, t_iter])
            v2v2 = self.grammian[t_iter, t_iter]

            gamma = self.line_solver(v1v1, v1v2, v2v2)
            self.new_sol *= gamma
            self.new_sol[t_iter] += 1. - gamma

            change = self.new_sol - self.sol
            if torch.sum(torch.abs(change)) < self.STOP_CRIT:
                return self.new_sol
            self.sol.copy_(self.new_sol)
        return self.sol 
Example #25
Source File: memory.py    From adeptRL with GNU General Public License v3.0
def append(self, new_k, new_v):
        """
        :param new_k: expecting a vector of dimensionality [Num Key Chan]
        :param new_v: expecting a vector of dimensionality [Num Value Chan]
        :return:
        """
        min_idx = torch.argmin(self.weight_buff).item()
        self.keys[min_idx, :] = new_k
        self.values[min_idx, :] = new_v
        self.weight_buff[min_idx] = torch.mean(self.weight_buff) 
Example #26
Source File: memory.py    From adeptRL with GNU General Public License v3.0
def append(self, new_k, new_v):
        min_idx = torch.argmin(self.weight_buff).item()
        self.keys[min_idx, :] = new_k
        self.values[min_idx, :] = new_v
        self.weight_buff[min_idx] = torch.mean(self.weight_buff) 
Example #27
Source File: em.py    From fairseq with MIT License
def assign(self):
        """
        Assigns each column of W to its closest centroid, thus essentially
        performing the E-step in train().

        Remarks:
            - The function must be called after train() or after loading
              centroids using self.load(), otherwise it will return empty tensors
        """

        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features) 
Example #28
Source File: em.py    From fairseq with MIT License
def resolve_empty_clusters(self):
        """
        If one cluster is empty, the most populated cluster is split into
        two clusters by shifting the respective centroids. This is done
        iteratively, up to a fixed number of attempts (max_tentatives).
        """

        # empty clusters
        counts = Counter(map(lambda x: x.item(), self.assignments))
        empty_clusters = set(range(self.n_centroids)) - set(counts.keys())
        n_empty_clusters = len(empty_clusters)

        tentatives = 0
        while len(empty_clusters) > 0:
            # given an empty cluster, find most populated cluster and split it into two
            k = random.choice(list(empty_clusters))
            m = counts.most_common(1)[0][0]
            e = torch.randn_like(self.centroids[m]) * self.eps
            self.centroids[k] = self.centroids[m].clone()
            self.centroids[k] += e
            self.centroids[m] -= e

            # recompute assignments
            distances = self.compute_distances()  # (n_centroids x out_features)
            self.assignments = torch.argmin(distances, dim=0)  # (out_features)

            # check for empty clusters
            counts = Counter(map(lambda x: x.item(), self.assignments))
            empty_clusters = set(range(self.n_centroids)) - set(counts.keys())

            # increment tentatives
            if tentatives == self.max_tentatives:
                logging.info(
                    f"Could not resolve all empty clusters, {len(empty_clusters)} remaining"
                )
                raise EmptyClusterResolveError
            tentatives += 1

        return n_empty_clusters 
Example #29
Source File: em.py    From fairseq with MIT License
def step(self, i):
        """
        There are two standard steps for each iteration: expectation (E) and
        minimization (M). The E-step (assignment) is performed with an exhaustive
        search and the M-step (centroid computation) is performed with
        the exact solution.

        Args:
            - i: step number

        Remarks:
            - The E-step heavily uses PyTorch broadcasting to speed up computations
              and reduce the memory overhead
        """

        # assignments (E-step)
        distances = self.compute_distances()  # (n_centroids x out_features)
        self.assignments = torch.argmin(distances, dim=0)  # (out_features)
        n_empty_clusters = self.resolve_empty_clusters()

        # centroids (M-step)
        for k in range(self.n_centroids):
            W_k = self.W[:, self.assignments == k]  # (in_features x size_of_cluster_k)
            self.centroids[k] = W_k.mean(dim=1)  # (in_features)

        # book-keeping
        obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item()
        self.objective.append(obj)
        if self.verbose:
            logging.info(
                f"Iteration: {i},\t"
                f"objective: {obj:.6f},\t"
                f"resolved empty clusters: {n_empty_clusters}"
            ) 
Example #30
Source File: fsaf_head.py    From Feature-Selective-Anchor-Free-Module-for-Single-Shot-Object-Detection with Apache License 2.0
def get_online_pyramid_level(self, cls_scores_img, bbox_preds_img, gt_bbox_obj_xyxy, gt_label_obj):
        device = cls_scores_img[0].device
        num_levels = len(cls_scores_img)
        level_losses = torch.zeros(num_levels)
        for level in range(num_levels):
            H,W = cls_scores_img[level].shape[1:]
            b_p_xyxy = gt_bbox_obj_xyxy / self.feat_strides[level]
            b_e_xyxy = self.get_prop_xyxy(b_p_xyxy, self.eps_e, W, H)
            
            # Eqn-(1)
            N = (b_e_xyxy[3]-b_e_xyxy[1]+1) * (b_e_xyxy[2]-b_e_xyxy[0]+1)
            
            # cls loss; FL
            score = cls_scores_img[level][gt_label_obj,b_e_xyxy[1]:b_e_xyxy[3]+1,b_e_xyxy[0]:b_e_xyxy[2]+1]
            score = score.contiguous().view(-1).unsqueeze(1)
            label = torch.ones_like(score).long()
            label = label.contiguous().view(-1)
            
            loss_cls = sigmoid_focal_loss(score, label, gamma=self.FL_gamma, alpha=self.FL_alpha, reduction='mean')
            #loss_cls /= N
            
            # reg loss; IoU
            offsets = bbox_preds_img[level][:,b_e_xyxy[1]:b_e_xyxy[3]+1,b_e_xyxy[0]:b_e_xyxy[2]+1]
            offsets = offsets.contiguous().permute(1,2,0)  # (b_e_H,b_e_W,4)
            offsets = offsets.reshape(-1,4) # (#pix-e,4)
            
            # predicted bbox
            y,x = torch.meshgrid([torch.arange(b_e_xyxy[1],b_e_xyxy[3]+1), torch.arange(b_e_xyxy[0],b_e_xyxy[2]+1)])
            y = (y.float() + 0.5) * self.feat_strides[level]
            x = (x.float() + 0.5) * self.feat_strides[level]
            xy = torch.cat([x.unsqueeze(2),y.unsqueeze(2)], dim=2).float().to(device)
            xy = xy.reshape(-1,2)
            
            dist_pred = offsets * self.feat_strides[level]
            bboxes = self.dist2bbox(xy, dist_pred, self.bbox_offset_norm)
            
            loss_reg = iou_loss(bboxes, gt_bbox_obj_xyxy.unsqueeze(0).repeat(N,1), reduction='mean')
            #loss_cls /= N
            
            loss = loss_cls + loss_reg
            
            level_losses[level] = loss
        min_level_idx = torch.argmin(level_losses)
        #print(level_losses, min_level_idx)
        return min_level_idx