"""Pose helpers for spatial-transformer style cropping and pasting.

A pose is either a compact ``[s, x, y]`` vector (isotropic scale plus
translation) or a flattened 2x3 affine matrix with 6 entries.
"""
import numpy as np
from PIL import Image, ImageDraw
import torch
# No longer used below; kept so code importing `Variable` from this module
# keeps working.
from torch.autograd import Variable  # noqa: F401
import torch.nn.functional as F


def expand_pose(pose):
    """Expand compact poses into 2x3 affine matrices.

    ``[s, x, y] -> [[s, 0, x], [0, s, y]]``

    Args:
        pose: (N x 3) tensor of ``[s, x, y]`` rows.

    Returns:
        (N x 2 x 3) tensor on the same device and with the same dtype as
        ``pose``.
    """
    n = pose.size(0)
    # Prepend a zero column so that index 0 can be gathered wherever the
    # affine matrix needs a literal 0: columns become [0, s, x, y].
    zeros = pose.new_zeros((n, 1))
    padded = torch.cat([zeros, pose], dim=1)
    expansion_indices = torch.tensor(
        [1, 0, 2, 0, 1, 3], dtype=torch.long, device=pose.device
    )
    return torch.index_select(padded, 1, expansion_indices).view(n, 2, 3)


def pose_inv(pose):
    """Invert compact poses.

    ``[s, x, y] -> [1/s, -x/s, -y/s]``

    Args:
        pose: (N x 3) tensor of ``[s, x, y]`` rows.

    Returns:
        (N x 3) tensor holding the inverse poses.
    """
    N, _ = pose.size()
    ones = pose.new_ones((N, 1))
    numerators = torch.cat([ones, -pose[:, 1:]], dim=1)  # [1, -x, -y]
    return numerators / pose[:, 0:1]


def pose_inv_full(pose):
    """Invert flattened 2x3 affine transforms.

    Each row ``[a, b, tx, c, d, ty]`` is read as ``A = [[a, b], [c, d]]`` and
    ``t = [tx, ty]``; the result is ``[A^-1 | -A^-1 t]``.

    Args:
        pose: (N x 6) tensor.

    Returns:
        (N x 2 x 3) tensor of inverse transforms.
    """
    N, _ = pose.size()
    translation = pose.view(N, 2, 3)[:, :, 2:]  # N x 2 x 1

    # The 1e-8 offset keeps the division finite for singular matrices.
    determinant = (
        pose[:, 0] * pose[:, 4] - pose[:, 1] * pose[:, 3] + 1e-8
    ).view(N, 1)

    # Adjugate of A: gather [d, b, c, a] and flip the sign of b and c.
    indices = torch.tensor([4, 1, 3, 0], dtype=torch.long, device=pose.device)
    signs = torch.tensor(
        [1.0, -1.0, -1.0, 1.0], dtype=pose.dtype, device=pose.device
    )
    A_inv = torch.index_select(pose, 1, indices) * signs / determinant
    A_inv = A_inv.view(N, 2, 2)

    # t' = -A^{-1} t
    t_inv = -A_inv.matmul(translation).view(N, 2, 1)
    return torch.cat([A_inv, t_inv], dim=2)


def image_to_object(images, pose, object_size):
    """Crop object patches out of images using the inverse of ``pose``.

    Args:
        images: (... x C x H x W) tensor; it is viewed as (N x C x H x W).
        pose: (N x 3) or (N x 6) tensor.
        object_size: side length of the square output patches.

    Returns:
        (N x C x object_size x object_size) tensor of sampled patches.

    Raises:
        ValueError: if the second dimension of ``pose`` is neither 3 nor 6.
    """
    N, pose_size = pose.size()
    if pose_size == 3:
        transformer_inv = expand_pose(pose_inv(pose))
    elif pose_size == 6:
        transformer_inv = pose_inv_full(pose)
    else:
        raise ValueError(
            'pose must have 3 or 6 columns, got {}'.format(pose_size)
        )

    n_channels, H, W = images.size()[-3:]
    images = images.view(N, n_channels, H, W)

    # NOTE(review): `align_corners` is left at the library default, which
    # differs between PyTorch releases -- confirm the intended convention.
    grid = F.affine_grid(
        transformer_inv,
        torch.Size((N, n_channels, object_size, object_size)),
    )
    return F.grid_sample(images, grid)


def object_to_image(objects, pose, image_size):
    """Place object patches onto an image-sized canvas using ``pose``.

    Args:
        objects: (N x C x H x W) tensor of object patches.
        pose: (N x 3) or (N x 6) tensor.
        image_size: side length of the square output canvas.

    Returns:
        (N x C x image_size x image_size) tensor.

    Raises:
        ValueError: if the second dimension of ``pose`` is neither 3 nor 6.
    """
    N, pose_size = pose.size()
    _, n_channels, _, _ = objects.size()
    if pose_size == 3:
        transformer = expand_pose(pose)
    elif pose_size == 6:
        transformer = pose.view(N, 2, 3)
    else:
        raise ValueError(
            'pose must have 3 or 6 columns, got {}'.format(pose_size)
        )

    # NOTE(review): `align_corners` is left at the library default, which
    # differs between PyTorch releases -- confirm the intended convention.
    grid = F.affine_grid(
        transformer,
        torch.Size((N, n_channels, image_size, image_size)),
    )
    return F.grid_sample(objects, grid)


def calculate_positions(pose):
    """Get the center x, y of the spatial transformer.

    Computes ``((-x / s + 1) / 2, (-y / s + 1) / 2)`` for every row.

    Args:
        pose: (N x 3) tensor of ``[s, x, y]`` rows.

    Returns:
        (N x 2) tensor of ``[x, y]`` positions.
    """
    N, pose_size = pose.size()
    # Kept as an assert so callers relying on AssertionError are unaffected.
    assert pose_size == 3, 'Only implemented pose_size == 3'
    # s, x, y
    s = pose[:, 0]
    xt = pose[:, 1]
    yt = pose[:, 2]
    x = (- xt / s + 1) / 2
    y = (- yt / s + 1) / 2
    return torch.stack([x, y], dim=1)


def bounding_box(z_where, x_size):
    """Compute the box described by a compact pose on a square image.

    This doesn't take into account interpolation, but it's close enough to
    be usable.

    Args:
        z_where: sequence unpacked as ``(s, x, y)``.
        x_size: side length of the square image.

    Returns:
        ``((x, y), w, h)`` where ``(x, y)`` is the top-left corner (origin at
        the top left of the image) and ``w``, ``h`` are the box extents.
    """
    s, x, y = z_where
    w = x_size / s
    h = x_size / s
    xtrans = -x / s * x_size / 2.
    ytrans = -y / s * x_size / 2.
    x = (x_size - w) / 2 + xtrans  # origin is top left
    y = (x_size - h) / 2 + ytrans
    return (x, y), w, h


def draw_components(images, pose):
    """Draw the bounding box of each pose onto its image.

    A white line is also painted on the first and last pixel row of every
    image.

    Args:
        images: (N x C x H x W) tensor with values in [0, 1].
        pose: (N x 3) tensor of ``[s, x, y]`` rows.

    Returns:
        (N x C x H x W) float tensor on the CPU with values in [0, 1].
    """
    # detach() so tensors that are part of an autograd graph can be converted.
    images = (images.detach().cpu().numpy() * 255).astype(np.uint8)  # [0, 255]
    pose = pose.detach().cpu().numpy()
    N, C, H, W = images.shape
    for i, z_where in enumerate(pose):
        if C == 1:
            img = images[i][0]
        else:
            img = images[i].transpose((1, 2, 0))  # CHW -> HWC for PIL
        img = Image.fromarray(img)
        draw = ImageDraw.Draw(img)

        (x, y), w, h = bounding_box(z_where, H)
        # Pillow expects x1 >= x0 and y1 >= y0; a negative scale would
        # otherwise hand it the corners in descending order.
        x0, x1 = sorted((float(x), float(x + w)))
        y0, y1 = sorted((float(y), float(y + h)))
        draw.rectangle([x0, y0, x1, y1], outline=128)

        new_img = np.array(img)
        new_img[0, ...] = 255  # Add line
        new_img[-1, ...] = 255  # Add line
        if C == 1:
            new_img = new_img[np.newaxis, :, :]
        else:
            new_img = new_img.transpose((2, 0, 1))  # HWC -> CHW
        images[i] = new_img

    # Back to torch tensor
    return torch.from_numpy((images / 255).astype(np.float32))