import math import torch import torch.nn.functional as F from .. import utils class ObjectEncoder(object): def __init__(self, classnames=['Car'], pos_std=[.5, .36, .5], log_dim_mean=[[0.42, 0.48, 1.35]], log_dim_std=[[.085, .067, .115]], sigma=1., nms_thresh=0.05): self.classnames = classnames self.nclass = len(classnames) self.pos_std = torch.tensor(pos_std) self.log_dim_mean = torch.tensor(log_dim_mean) self.log_dim_std = torch.tensor(log_dim_std) self.sigma = sigma self.nms_thresh = nms_thresh def encode_batch(self, objects, grids): # Encode batch element by element batch_encoded = [self.encode(objs, grid) for objs, grid in zip(objects, grids)] # Transpose batch return [torch.stack(t) for t in zip(*batch_encoded)] def encode(self, objects, grid): # Filter objects by class name objects = [obj for obj in objects if obj.classname in self.classnames] # Skip empty examples if len(objects) == 0: return self._encode_empty(grid) # Construct tensor representation of objects classids = torch.tensor([self.classnames.index(obj.classname) for obj in objects], device=grid.device) positions = grid.new([obj.position for obj in objects]) dimensions = grid.new([obj.dimensions for obj in objects]) angles = grid.new([obj.angle for obj in objects]) # Assign objects to locations on the grid _, indices = self._assign_to_grid( classids, positions, dimensions, angles, grid) mask = indices >= 0 # Encode object heatmaps heatmaps = self._encode_heatmaps(classids, positions, grid) # Encode positions, dimensions and angles pos_offsets = self._encode_positions(positions, indices, grid) dim_offsets = self._encode_dimensions(classids, dimensions, indices) ang_offsets = self._encode_angles(angles, indices) return heatmaps, pos_offsets, dim_offsets, ang_offsets, mask def _assign_to_grid(self, classids, positions, dimensions, angles, grid): # Compute grid centers centers = (grid[1:, 1:, :] + grid[:-1, :-1, :]) / 2. # Transform grid into object coordinate systems local_grid = utils.rotate(centers - positions.view(-1, 1, 1, 3), -angles.view(-1, 1, 1)) / dimensions.view(-1, 1, 1, 3) # Find all grid cells which lie within each object inside = (local_grid[..., [0, 2]].abs() <= 0.5).all(dim=-1) # Expand the mask in the class dimension NxDxW * NxC => NxCxDxW class_mask = classids.view(-1, 1) == torch.arange( len(self.classnames)).type_as(classids) class_inside = inside.unsqueeze(1) & class_mask[:, :, None, None] # Return positive locations and the id of the corresponding instance labels, indices = torch.max(class_inside, dim=0) return labels, indices def _encode_heatmaps(self, classids, positions, grid): centers = (grid[1:, 1:, [0, 2]] + grid[:-1, :-1, [0, 2]]) / 2. positions = positions.view(-1, 1, 1, 3)[..., [0, 2]] # Compute per-object heatmaps sqr_dists = (positions - centers).pow(2).sum(dim=-1) obj_heatmaps = torch.exp(-0.5 * sqr_dists / self.sigma ** 2) heatmaps = obj_heatmaps.new_zeros(self.nclass, *obj_heatmaps.size()[1:]) for i in range(self.nclass): mask = classids == i if mask.any(): heatmaps[i] = torch.max(obj_heatmaps[mask], dim=0)[0] return heatmaps def _encode_positions(self, positions, indices, grid): # Compute the center of each grid cell centers = (grid[1:, 1:] + grid[:-1, :-1]) / 2. # Encode positions into the grid C, D, W = indices.size() positions = positions.index_select(0, indices.view(-1)).view(C, D, W, 3) # Compute relative offsets and normalize pos_offsets = (positions - centers) / self.pos_std.to(positions) return pos_offsets.permute(0, 3, 1, 2) def _encode_dimensions(self, classids, dimensions, indices): # Convert mean and std to tensors log_dim_mean = self.log_dim_mean.to(dimensions)[classids] log_dim_std = self.log_dim_std.to(dimensions)[classids] # Compute normalized log scale offset dim_offsets = (torch.log(dimensions) - log_dim_mean) / log_dim_std # Encode dimensions to grid C, D, W = indices.size() dim_offsets = dim_offsets.index_select(0, indices.view(-1)) return dim_offsets.view(C, D, W, 3).permute(0, 3, 1, 2) def _encode_angles(self, angles, indices): # Compute rotation vector sin = torch.sin(angles)[indices] cos = torch.cos(angles)[indices] return torch.stack([cos, sin], dim=1) def _encode_empty(self, grid): depth, width, _ = grid.size() # Generate empty tensors heatmaps = grid.new_zeros((self.nclass, depth-1, width-1)) pos_offsets = grid.new_zeros((self.nclass, 3, depth-1, width-1)) dim_offsets = grid.new_zeros((self.nclass, 3, depth-1, width-1)) ang_offsets = grid.new_zeros((self.nclass, 2, depth-1, width-1)) mask = grid.new_zeros((self.nclass, depth-1, width-1)).bool() return heatmaps, pos_offsets, dim_offsets, ang_offsets, mask def decode(self, heatmaps, pos_offsets, dim_offsets, ang_offsets, grid): # Apply NMS to find positive heatmap locations peaks, scores, classids = self._decode_heatmaps(heatmaps) # Decode positions, dimensions and rotations positions = self._decode_positions(pos_offsets, peaks, grid) dimensions = self._decode_dimensions(dim_offsets, peaks) angles = self._decode_angles(ang_offsets, peaks) objects = list() for score, cid, pos, dim, ang in zip(scores, classids, positions, dimensions, angles): objects.append(utils.ObjectData( self.classnames[cid], pos, dim, ang, score)) return objects def _decode_heatmaps(self, heatmaps): peaks = non_maximum_suppression(heatmaps, self.sigma) scores = heatmaps[peaks] classids = torch.nonzero(peaks)[:, 0] return peaks, scores, classids def _decode_positions(self, pos_offsets, peaks, grid): # Compute the center of each grid cell centers = (grid[1:, 1:] + grid[:-1, :-1]) / 2. # Un-normalize grid offsets positions = pos_offsets.permute(0, 2, 3, 1) * self.pos_std + centers return positions[peaks] def _decode_dimensions(self, dim_offsets, peaks): dim_offsets = dim_offsets.permute(0, 2, 3, 1) dimensions = torch.exp( dim_offsets * self.log_dim_std + self.log_dim_mean) return dimensions[peaks] def _decode_angles(self, angle_offsets, peaks): cos, sin = torch.unbind(angle_offsets, 1) return torch.atan2(sin, cos)[peaks] def non_maximum_suppression(heatmaps, sigma=1.0, thresh=0.05, max_peaks=50): # Smooth with a Gaussian kernel num_class = heatmaps.size(0) kernel = utils.gaussian_kernel(sigma).to(heatmaps) kernel = kernel.expand(num_class, num_class, -1, -1) smoothed = F.conv2d( heatmaps[None], kernel, padding=int((kernel.size(2)-1)/2)) # Max pool over the heatmaps max_inds = F.max_pool2d(smoothed, 3, stride=1, padding=1, return_indices=True)[1].squeeze(0) # Find the pixels which correspond to the maximum indices _, height, width = heatmaps.size() flat_inds = torch.arange(height*width).type_as(max_inds).view(height, width) peaks = (flat_inds == max_inds) & (heatmaps > thresh) # Keep only the top N peaks if peaks.long().sum() > max_peaks: scores = heatmaps[peaks] scores, _ = torch.sort(scores, descending=True) peaks = peaks & (heatmaps > scores[max_peaks-1]) return peaks