# -*- coding: utf-8 -*-
# @Time    : 2018-9-21 14:36
# @Author  : xylon
import torch
from torch.nn import functional as F
from skimage import transform

from utils.math_utils import L2Norm


def clip_patch(kpts_byxc, kpts_scale, kpts_ori, im_info, images, PSIZE):
    """
    Clip a scaled (and optionally rotated) square patch around every keypoint
    using bilinear interpolation on the original image.

    :param kpts_byxc: tensor (B*topk, 4): columns are (b, y, x, 0); column 0 is
        used as an integer batch index into ``images``
    :param kpts_scale: tensor (B*topk): scale value of every keypoint
    :param kpts_ori: tensor (B*topk, 2) holding (cos, sin) per keypoint, or
        None to skip the rotation
    :param im_info: tensor (B, 2): rescale ratios; only column 0 is used, for
        both the x and the y axis
    :param images: tensor (B, 1, H, W): images before rescaling
    :param PSIZE: side length of the output patch (should be cfg.PATCH.size)
    :return: tensor (B*topk, 1, PSIZE, PSIZE): one patch per keypoint
    """
    assert kpts_byxc.size(0) == kpts_scale.size(0)

    out_width = out_height = PSIZE
    device = kpts_byxc.device
    B, _, im_height, im_width = images.size()
    num_kp = kpts_byxc.size(0)  # B*K
    max_y = int(im_height - 1)
    max_x = int(im_width - 1)

    # Regular sampling grid in [-1, 1] x [-1, 1], in homogeneous coordinates.
    y_t, x_t = torch.meshgrid(
        [
            torch.linspace(-1, 1, out_height, dtype=torch.float, device=device),
            torch.linspace(-1, 1, out_width, dtype=torch.float, device=device),
        ]
    )
    one_t = x_t.new_full(x_t.size(), fill_value=1)
    x_t = x_t.contiguous().view(-1)
    y_t = y_t.contiguous().view(-1)
    one_t = one_t.view(-1)
    grid = torch.stack((x_t, y_t, one_t))  # (3, out_height*out_width)
    # One copy of the grid per keypoint; it spans the patch from its
    # left-up corner [-1, -1] to its right-bottom corner [1, 1].
    grid = grid.unsqueeze(0).repeat(num_kp, 1, 1)  # (num_kp, 3, P)

    # Create the affine transform from scale and orientation:
    # [s, 0, 0]   [cos, -sin, 0]
    # [0, s, 0] * [sin,  cos, 0]
    # [0, 0, 1]   [0,    0,   1]
    thetas = torch.eye(2, 3, dtype=torch.float, device=device)
    thetas = thetas.unsqueeze(0).repeat(num_kp, 1, 1)  # (num_kp, 2, 3)
    im_info = im_info[:, 0].unsqueeze(-1)  # (B, 1)
    # Bring the scale back to the resolution of `images`, then halve it
    # because the grid extends one unit on each side of the keypoint.
    kpts_scale = kpts_scale.view(im_info.size(0), -1) / im_info  # (B, topk)
    kpts_scale = kpts_scale.view(-1) / 2.0  # (num_kp)
    thetas = thetas * kpts_scale[:, None, None]
    last_row = torch.tensor([[[0, 0, 1]]], dtype=torch.float, device=device)
    last_row = last_row.repeat(num_kp, 1, 1)  # (num_kp, 1, 3)
    thetas = torch.cat((thetas, last_row), 1)  # (num_kp, 3, 3)

    if kpts_ori is not None:
        cos = kpts_ori[:, 0].unsqueeze(-1)  # (num_kp, 1)
        sin = kpts_ori[:, 1].unsqueeze(-1)  # (num_kp, 1)
        zeros = cos.new_full(cos.size(), fill_value=0)
        ones = cos.new_full(cos.size(), fill_value=1)
        R = torch.cat(
            (cos, -sin, zeros, sin, cos, zeros, zeros, zeros, ones), dim=-1
        )
        R = R.view(-1, 3, 3)
        thetas = torch.matmul(thetas, R)

    # Apply the transform to the regular grid:
    # (num_kp, 3, 3) x (num_kp, 3, P) = (num_kp, 3, P)
    T_g = torch.matmul(thetas, grid)
    x = T_g[:, 0, :]  # (num_kp, P)
    y = T_g[:, 1, :]  # (num_kp, P)

    # Keypoint coordinates at the resolution of `images`.
    kp_x_ofst = kpts_byxc[:, 2].view(B, -1).float() / im_info  # (B, topk)
    kp_x_ofst = kp_x_ofst.view(-1, 1)  # (num_kp, 1)
    kp_y_ofst = kpts_byxc[:, 1].view(B, -1).float() / im_info  # (B, topk)
    kp_y_ofst = kp_y_ofst.view(-1, 1)  # (num_kp, 1)

    # Centre every grid on its keypoint.
    x = (x + kp_x_ofst).view(-1)  # (num_kp*P)
    y = (y + kp_y_ofst).view(-1)  # (num_kp*P)

    # Integer corners used for the bilinear interpolation, clamped to the
    # image so that every gather index is valid.
    x0 = x.floor().long()
    x1 = x0 + 1
    y0 = y.floor().long()
    y1 = y0 + 1

    x0 = x0.clamp(min=0, max=max_x)
    x1 = x1.clamp(min=0, max=max_x)
    y0 = y0.clamp(min=0, max=max_y)
    y1 = y1.clamp(min=0, max=max_y)

    dim2 = im_width
    dim1 = im_width * im_height

    # Offset of the first pixel of each keypoint's image in the flat buffer.
    batch_inds = kpts_byxc[:, 0].unsqueeze(-1)  # (num_kp, 1)
    base = batch_inds.repeat(1, out_height * out_width)  # (num_kp, P)
    base = base.view(-1) * dim1  # (num_kp*P)
    base_y0 = base + y0 * dim2  # start of row y0
    base_y1 = base + y1 * dim2  # start of row y1

    idx_a = base_y0 + x0  # left-up pixel
    idx_b = base_y1 + x0  # left-bottom pixel
    idx_c = base_y0 + x1  # right-up pixel
    idx_d = base_y1 + x1  # right-bottom pixel

    # reshape (rather than view) so non-contiguous images are accepted too.
    im_flat = images.reshape(-1)  # (B*H*W)
    Ia = im_flat.gather(0, idx_a)
    Ib = im_flat.gather(0, idx_b)
    Ic = im_flat.gather(0, idx_c)
    Id = im_flat.gather(0, idx_d)

    x0_f = x0.float()
    x1_f = x1.float()
    y0_f = y0.float()
    y1_f = y1.float()

    # Bilinear weights: each corner is weighted by the distance to the
    # opposite corner. NOTE: they are computed from the clamped corners, so
    # for samples that fall outside the image they do not sum to one.
    wa = (x1_f - x) * (y1_f - y)
    wb = (x1_f - x) * (y - y0_f)
    wc = (x - x0_f) * (y1_f - y)
    wd = (x - x0_f) * (y - y0_f)

    output = wa * Ia + wb * Ib + wc * Ic + wd * Id
    output = output.view(num_kp, out_height, out_width)
    return output.unsqueeze(1)


def _grid_sample_aligned(data, grid):
    """Call ``F.grid_sample`` so that -1/+1 address the corner pixel centres."""
    try:
        return F.grid_sample(data, grid, align_corners=True)
    except TypeError:
        # torch < 1.3 has no `align_corners` argument and always aligned
        # the corners, which is the behaviour requested above.
        return F.grid_sample(data, grid)


def warp(im1_data, homo21):
    """
    Warp im1 to im2. Pixel values of the output are fetched from im1, so the
    pixel grid of im2 is projected into im1, which is why homo21 is needed.

    :param im1_data: (B, H, W, C)
    :param homo21: (B, 3, 3)
    :return: out_image (B, H, W, C)
    """
    B, imH, imW, _ = im1_data.size()
    outH, outW = imH, imW

    # Homogeneous pixel coordinates (x, y, 1) of every output pixel.
    gy, gx = torch.meshgrid([torch.arange(outH), torch.arange(outW)])
    gx, gy = gx.float().unsqueeze(-1), gy.float().unsqueeze(-1)
    ones = gy.new_full(gy.size(), fill_value=1)
    grid = torch.cat((gx, gy, ones), -1)  # (H, W, 3)
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # (B, H, W, 3)
    grid = grid.view(B, -1, 3).permute(0, 2, 1)  # (B, 3, H*W)
    grid = grid.type_as(homo21).to(homo21.device)

    # (B, 3, 3) matmul (B, 3, H*W) => (B, 3, H*W)
    grid_w = torch.matmul(homo21, grid).permute(0, 2, 1)  # (B, H*W, 3)
    # Perspective division; the epsilon guards against a zero w component.
    grid_w = grid_w.div(grid_w[:, :, 2].unsqueeze(-1) + 1e-8)
    grid_w = grid_w.reshape(B, outH, outW, -1)

    # Map pixel coordinates to [-1, 1] with pixel 0 -> -1 and the last pixel
    # -> +1. This is the `align_corners=True` convention of grid_sample.
    x_norm = grid_w[:, :, :, 0].div(imW - 1) * 2 - 1
    y_norm = grid_w[:, :, :, 1].div(imH - 1) * 2 - 1
    grid_w = torch.stack((x_norm, y_norm), dim=-1)  # (B, H, W, 2)

    out_image = _grid_sample_aligned(
        im1_data.permute(0, 3, 1, 2), grid_w
    )  # (B, C, H, W)
    return out_image.permute(0, 2, 3, 1)


def filtbordmask(imscore, radius):
    """
    Build a mask that is 0 on a border of width ``radius`` and 1 elsewhere.

    :param imscore: (B, H, W, C) tensor; only its size, dtype and device are used
    :param radius: int, border width in pixels
    :return: mask (1, H, W, 1)
    """
    _, height, width, _ = imscore.size()
    mask = imscore.new_full(
        (1, height - 2 * radius, width - 2 * radius, 1), fill_value=1
    )
    # F.pad lists the padding from the last dimension backwards:
    # (C, C, W, W, H, H, B, B).
    mask = F.pad(
        input=mask,
        pad=(0, 0, radius, radius, radius, radius, 0, 0),
        mode="constant",
        value=0,
    )
    return mask


def filter_border(imscore, radius=8):
    """
    Zero out the scores lying within ``radius`` pixels of the image border.

    :param imscore: (B, H, W, C)
    :param radius: int, border width in pixels
    :return: (B, H, W, C)
    """
    imscore = imscore * filtbordmask(imscore, radius=radius)
    return imscore


def nms(input, thresh=0.0, ksize=5):
    """
    Non maximum suppression: a pixel is kept only if it is the maximum of its
    ksize*ksize neighbourhood. Values below ``thresh`` are set to zero first.

    :param input: (B, H, W, 1)
    :param thresh: float
    :param ksize: int
    :return: mask (B, H, W, 1), 1 where the pixel is a local maximum, same
        dtype as input
    """
    _, height, width, _ = input.size()
    pad = ksize // 2
    zeros = torch.zeros_like(input)
    input = torch.where(input < thresh, zeros, input)
    input_pad = F.pad(
        input=input,
        pad=(0, 0, 2 * pad, 2 * pad, 2 * pad, 2 * pad, 0, 0),
        mode="constant",
        value=0,
    )

    # All ksize*ksize shifted views of the padded map, stacked on the last
    # axis with a single concatenation.
    windows = [
        input_pad[:, i : height + 2 * pad + i, j : width + 2 * pad + j, :]
        for i in range(ksize)
        for j in range(ksize)
    ]
    slice_map = torch.cat(windows, dim=-1)

    max_slice = slice_map.max(dim=-1, keepdim=True)[0]
    # The middle entry of the stack is the unshifted (centre) view.
    center_map = slice_map[:, :, :, slice_map.size(-1) // 2].unsqueeze(-1)
    mask = torch.ge(center_map, max_slice)

    mask = mask[:, pad : height + pad, pad : width + pad, :]

    return mask.type_as(input)


def topk_map(maps, k=512):
    """
    Find the k highest values of every map in the batch.

    :param maps: (B, H, W, 1)
    :param k: int; if k exceeds the number of pixels every pixel is selected
    :return: uint8 mask (B, H, W, 1), 1 at the selected pixels
    """
    batch, height, width, _ = maps.size()
    maps_flat = maps.reshape(batch, -1)

    indices = maps_flat.sort(dim=-1, descending=True)[1][:, :k]
    topk_mask_flat = torch.zeros(
        maps_flat.size(), dtype=torch.uint8, device=maps.device
    )
    # Mark the selected positions directly on the device.
    topk_mask_flat.scatter_(1, indices, 1)

    mask = topk_mask_flat.view(batch, height, width, -1)
    return mask


def get_gauss_filter_weight(ksize, sig):
    """
    Generate an (unnormalised) gaussian kernel centred on ksize // 2.

    :param ksize: int
    :param sig: float; 0 yields a kernel that is 1 at the centre, 0 elsewhere
    :return: float tensor (ksize, ksize)
    """
    mu_x = mu_y = ksize // 2
    if sig == 0:
        psf = torch.zeros((ksize, ksize)).float()
        psf[mu_y, mu_x] = 1.0
    else:
        sig = torch.tensor(sig).float()
        x = torch.arange(ksize)[None, :].repeat(ksize, 1).float()
        y = torch.arange(ksize)[:, None].repeat(1, ksize).float()
        psf = torch.exp(
            -((x - mu_x) ** 2 / (2 * sig ** 2) + (y - mu_y) ** 2 / (2 * sig ** 2))
        )
    return psf


def soft_nms_3d(scale_logits, ksize, com_strength):
    """
    Calculate a probability for each pixel in each scale, normalised over its
    ksize*ksize spatial neighbourhood across all scales.

    :param scale_logits: (B, H, W, C)
    :param ksize: int
    :param com_strength: magnify parameter
    :return: probability for each pixel in each scale, size is (B, H, W, C)
    """
    num_scales = scale_logits.size(-1)

    max_each_scale = F.max_pool2d(
        input=scale_logits.permute(0, 3, 1, 2),
        kernel_size=ksize,
        padding=ksize // 2,
        stride=1,
    ).permute(
        0, 2, 3, 1
    )  # (B, H, W, C)
    max_all_scale, _ = max_each_scale.max(dim=-1, keepdim=True)  # (B, H, W, 1)

    # Subtracting the neighbourhood maximum keeps the exponent non-positive.
    exp_maps = torch.exp(com_strength * (scale_logits - max_all_scale))  # (B, H, W, C)
    # An all-ones kernel sums the exponentials over space and scale.
    sum_exp = F.conv2d(
        input=exp_maps.permute(0, 3, 1, 2),
        weight=exp_maps.new_full([1, num_scales, ksize, ksize], fill_value=1),
        stride=1,
        padding=ksize // 2,
    ).permute(
        0, 2, 3, 1
    )  # (B, H, W, 1)
    probs = exp_maps / (sum_exp + 1e-8)
    return probs


def soft_max_and_argmax_1d(
    input, orint_maps, scale_list, com_strength1, com_strength2, dim=-1, keepdim=True
):
    """
    input should be the pixel probability in each scale. This function computes
    the final pixel probability summarised over all scales and the scale that
    corresponds to each pixel.

    :param input: scale_probs (B, H, W, 10)
    :param orint_maps: (B, H, W, 10, 2) or None
    :param dim: axis of ``input`` that indexes the scales
    :param scale_list: scale space list
    :param keepdim: keep the reduced dimension
    :param com_strength1: magnify argument of score
    :param com_strength2: magnify argument of scale
    :return: score_map (B, H, W, 1), scale_map (B, H, W, 1) and, when
        orint_maps is given, orint_map (B, H, W, 1, 2)
    """
    # Shifted by the maximum along the scale axis before exponentiating.
    shifted = input - torch.max(input, dim=dim, keepdim=True)[0]

    inputs_exp1 = torch.exp(com_strength1 * shifted)
    input_softmax1 = inputs_exp1 / (
        inputs_exp1.sum(dim=dim, keepdim=True) + 1e-8
    )  # (B, H, W, 10)

    inputs_exp2 = torch.exp(com_strength2 * shifted)
    input_softmax2 = inputs_exp2 / (
        inputs_exp2.sum(dim=dim, keepdim=True) + 1e-8
    )  # (B, H, W, 10)

    score_map = torch.sum(input * input_softmax1, dim=dim, keepdim=keepdim)

    scale_list_shape = [1] * len(input.size())
    scale_list_shape[dim] = -1
    scale_list = scale_list.view(scale_list_shape).to(input_softmax2.device)
    scale_map = torch.sum(scale_list * input_softmax2, dim=dim, keepdim=keepdim)

    if orint_maps is None:
        return score_map, scale_map

    # orint_maps has one extra trailing axis, so the scale axis keeps the
    # same non-negative index as in `input`, whatever the sign of `dim`.
    scale_dim = dim % input.dim()
    orint_map = torch.sum(
        orint_maps * input_softmax1.unsqueeze(-1), dim=scale_dim, keepdim=keepdim
    )  # (B, H, W, 1, 2)
    orint_map = L2Norm(orint_map, dim=-1)
    return score_map, scale_map, orint_map


def im_rescale(im, output_size):
    """
    Resize an image.

    :param im: array whose first two dimensions are (height, width)
    :param output_size: int, the length of the shorter side (aspect ratio is
        kept), or a (new_h, new_w) pair
    :return: resized image, original h, original w, width ratio new_w / w,
        height ratio new_h / h
    """
    h, w = im.shape[:2]
    if isinstance(output_size, int):
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    new_h, new_w = int(new_h), int(new_w)
    img = transform.resize(im, (new_h, new_w), mode="constant")

    return img, h, w, new_w / w, new_h / h