import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import utils

EPSILON = 1e-6


class OFT(nn.Module):
    """Orthographic feature transform.

    Projects every voxel of a 3-D grid (the ``grid`` corner positions expanded
    by a stack of vertical offsets) into the image and averages the image
    features inside the voxel's projected bounding box using an integral
    image. The per-layer voxel features of each (depth, width) column are then
    concatenated and mixed by a linear layer into an orthographic feature map.
    """

    def __init__(self, channels, cell_size, grid_height, scale=1):
        """
        Args:
            channels: number of channels of the input feature map; also the
                number of output channels.
            cell_size: vertical spacing between successive voxel boundaries.
            grid_height: vertical extent of the voxel stack. Boundaries start
                at ``-grid_height / 2`` and step by ``cell_size`` up to (but
                excluding) ``+grid_height / 2``.
            scale: the image size used for normalisation is
                ``feature-map size / scale``. Presumably this is the ratio
                between feature-map and image resolution (TODO confirm).
        """
        super().__init__()

        # Vertical voxel boundaries, stored as (Y, 1, 1, 3) vectors [0, y, 0]
        # so they can be added directly to the 3-D grid corner positions.
        y_corners = torch.arange(0, grid_height, cell_size) - grid_height / 2.
        y_corners = F.pad(y_corners.view(-1, 1, 1, 1), [1, 1])
        self.register_buffer('y_corners', y_corners)

        # Y boundaries delimit Y - 1 voxel layers. Their features are
        # concatenated and mixed down to `channels`. Despite its name this is a
        # Linear layer; renaming the attribute would change the state_dict
        # keys ('conv3d.*') and break loading of existing checkpoints.
        self.conv3d = nn.Linear((len(y_corners) - 1) * channels, channels)
        self.scale = scale

    def forward(self, features, calib, grid):
        """Pool image features into the orthographic grid.

        Args:
            features: image feature map, (batch, channels, img_height,
                img_width).
            calib: reshaped here to (-1, 3, 4) matrices and passed to
                ``utils.perspective``; presumably camera projection matrices.
            grid: 3-D corner positions of the grid, shape
                (batch, depth + 1, width + 1, 3) as implied by the slicing
                below (TODO confirm against callers).

        Returns:
            ReLU-activated orthographic feature map of shape
            (batch, channels, depth, width).
        """
        # Expand the grid in the y dimension: (batch, Y, depth+1, width+1, 3)
        corners = grid.unsqueeze(1) + self.y_corners.view(-1, 1, 1, 3)

        # Project the voxel corners into the image
        img_corners = utils.perspective(calib.view(-1, 1, 1, 1, 3, 4), corners)

        # Normalize to [-1, 1] (the F.grid_sample convention); corners that
        # fall outside the image are clamped onto its border
        img_height, img_width = features.size()[2:]
        img_size = corners.new([img_width, img_height]) / self.scale
        norm_corners = (2 * img_corners / img_size - 1).clamp(-1, 1)

        # Get top-left and bottom-right coordinates of voxel bounding boxes
        bbox_corners = torch.cat([
            torch.min(norm_corners[:, :-1, :-1, :-1],
                      norm_corners[:, :-1, 1:, :-1]),
            torch.max(norm_corners[:, 1:, 1:, 1:],
                      norm_corners[:, 1:, :-1, 1:]),
        ], dim=-1)
        batch, _, depth, width, _ = bbox_corners.size()
        bbox_corners = bbox_corners.flatten(2, 3)

        # Area of each box in feature-map pixels. A normalized span of 2 covers
        # a full side of the feature map, hence the factor 0.25 = (1/2)^2.
        area = ((bbox_corners[..., 2:] - bbox_corners[..., :2]).prod(dim=-1)
                * img_height * img_width * 0.25 + EPSILON).unsqueeze(1)
        visible = area > EPSILON
        # Invisible boxes are zeroed below. Clamping keeps their division
        # finite so that 0 * inf cannot produce NaN. Visible areas are already
        # > EPSILON and are left unchanged.
        area = area.clamp(min=EPSILON)

        # Sample the integral image at the four corners of every box.
        # align_corners=False is the current default; passing it explicitly
        # keeps results identical and silences PyTorch's warning.
        integral_img = integral_image(features)
        top_left = F.grid_sample(
            integral_img, bbox_corners[..., [0, 1]], align_corners=False)
        btm_right = F.grid_sample(
            integral_img, bbox_corners[..., [2, 3]], align_corners=False)
        top_right = F.grid_sample(
            integral_img, bbox_corners[..., [2, 1]], align_corners=False)
        btm_left = F.grid_sample(
            integral_img, bbox_corners[..., [0, 3]], align_corners=False)

        # Mean feature inside each box (ignore features which are not visible).
        # The mask is cast to the feature dtype so half-precision inputs stay
        # half; a float32 mask would promote the result to float32 and clash
        # with a half-precision Linear layer.
        vox_feats = (top_left + btm_right - top_right - btm_left) / area
        vox_feats = vox_feats * visible.to(vox_feats.dtype)

        # (batch, C, Y-1, depth*width) -> (batch*depth*width, C*(Y-1))
        vox_feats = vox_feats.permute(0, 3, 1, 2).flatten(0, 1).flatten(1, 2)

        # Collapse the vertical stack into the orthographic feature map
        ortho_feats = self.conv3d(vox_feats).view(batch, depth, width, -1)
        return F.relu(ortho_feats.permute(0, 3, 1, 2), inplace=True)


def integral_image(features):
    """Return the summed-area table over the last two dimensions.

    Entry ``[..., i, j]`` equals ``features[..., :i + 1, :j + 1].sum((-2, -1))``.
    """
    return torch.cumsum(torch.cumsum(features, dim=-1), dim=-2)