Python torch.nn.functional.grid_sample() Examples

The following are 30 code examples of torch.nn.functional.grid_sample(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn.functional, or try the search function.
Example #1
Source File: module_util.py    From BasicSR with Apache License 2.0 11 votes vote down vote up
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    The flow is added to the integer pixel grid of ``x``, the result is
    normalised to the [-1, 1] range expected by ``F.grid_sample`` and ``x`` is
    resampled at those positions.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels; channel 0 is
            the x (width) offset, channel 1 the y (height) offset
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to ``F.grid_sample``; the grid
            normalisation follows the same convention so that a zero flow is
            an identity warp either way. Default: True.

    Returns:
        Tensor: warped image or feature map, size (N, C, H, W)
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same spatial size'
    _, _, H, W = x.size()
    # Base pixel grid, built directly on x's device/dtype via broadcasting
    # (no CPU round trip, no dependence on torch.meshgrid's indexing default).
    xs = torch.arange(W, device=x.device).to(x.dtype).view(1, W).expand(H, W)
    ys = torch.arange(H, device=x.device).to(x.dtype).view(H, 1).expand(H, W)
    grid = torch.stack((xs, ys), dim=2)  # (H, W, 2), last dim ordered (x, y)
    vgrid = grid + flow
    # scale grid to [-1, 1]
    if align_corners:
        # -1 / +1 address the centres of the first / last pixel
        vgrid_x = 2.0 * vgrid[..., 0] / max(W - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[..., 1] / max(H - 1, 1) - 1.0
    else:
        # -1 / +1 address the outer edges of the first / last pixel
        vgrid_x = (2.0 * vgrid[..., 0] + 1.0) / W - 1.0
        vgrid_y = (2.0 * vgrid[..., 1] + 1.0) / H - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode,
                         padding_mode=padding_mode, align_corners=align_corners)
Example #2
Source File: DDPAE_utils.py    From DDPAE-video-prediction with MIT License 7 votes vote down vote up
def image_to_object(images, pose, object_size):
  '''
  Inverse pose, crop and transform image patches.
  param images: (... x C x H x W) tensor, viewed as (N x C x H x W)
  param pose: (N x 3) or (N x 6) tensor
  param object_size: side length of the square output patch
  return: (N x C x object_size x object_size) tensor of sampled patches
  raises ValueError: if pose has neither 3 nor 6 columns
  '''
  N, pose_size = pose.size()
  if pose_size not in (3, 6):
    # Previously this fell through and failed later with UnboundLocalError.
    raise ValueError('pose must have 3 or 6 columns, got %d' % pose_size)
  n_channels, H, W = images.size()[-3:]
  images = images.view(N, n_channels, H, W)
  if pose_size == 3:
    transformer_inv = expand_pose(pose_inv(pose))
  else:
    transformer_inv = pose_inv_full(pose)

  grid = F.affine_grid(transformer_inv,
                       torch.Size((N, n_channels, object_size, object_size)))
  return F.grid_sample(images, grid)
Example #3
Source File: DDPAE_utils.py    From DDPAE-video-prediction with MIT License 7 votes vote down vote up
def object_to_image(objects, pose, image_size):
  '''
  Place object patches onto a square canvas according to pose.
  param objects: (N x C x H x W) tensor
  param pose: (N x 3) or (N x 6) tensor; 6 columns are used directly as a
              (2 x 3) affine matrix per sample
  param image_size: side length of the square output canvas
  return: (N x C x image_size x image_size) tensor
  raises ValueError: if pose has neither 3 nor 6 columns
  '''
  N, pose_size = pose.size()
  _, n_channels, _, _ = objects.size()
  if pose_size == 3:
    transformer = expand_pose(pose)
  elif pose_size == 6:
    transformer = pose.view(N, 2, 3)
  else:
    # Previously this fell through and failed later with UnboundLocalError.
    raise ValueError('pose must have 3 or 6 columns, got %d' % pose_size)

  grid = F.affine_grid(transformer,
                       torch.Size((N, n_channels, image_size, image_size)))
  return F.grid_sample(objects, grid)
Example #4
Source File: misc.py    From pytorch-semantic-segmentation with MIT License 6 votes vote down vote up
def forward(self, x):
        """Resample every channel of ``x`` at learned offsets, then apply ``regular_filter``.

        ``offset_filter`` predicts two offset maps per input channel. They are
        added to a cached regular grid in [-1, 1], used to resample each
        channel independently, and the result is passed to ``regular_filter``.

        The regular grid is rebuilt (and stored on ``self`` as ``nn.Parameter``
        objects) whenever the input shape differs from the cached one.

        NOTE(review): the sampling grid is stacked as (h, w) while
        ``F.grid_sample`` reads the last dimension as (x, y) - confirm this
        ordering is intended.
        """
        in_shape = x.size()  # (b, c, h, w)
        chans, rows, cols = int(in_shape[1]), int(in_shape[2]), int(in_shape[3])

        raw_offset = self.offset_filter(x)  # (b, 2*c, h, w)
        shift_w, shift_h = torch.split(raw_offset, self.regular_filter.in_channels, 1)
        shift_w = shift_w.contiguous().view(-1, rows, cols)  # (b*c, h, w)
        shift_h = shift_h.contiguous().view(-1, rows, cols)  # (b*c, h, w)

        cache_valid = bool(self.input_shape) and self.input_shape == in_shape
        if not cache_valid:
            self.input_shape = in_shape
            base_w, base_h = np.meshgrid(np.linspace(-1, 1, in_shape[3]),
                                         np.linspace(-1, 1, in_shape[2]))  # (h, w)
            base_w = torch.Tensor(base_w)
            base_h = torch.Tensor(base_h)
            if self.cuda:
                base_w = base_w.cuda()
                base_h = base_h.cuda()
            self.grid_w = nn.Parameter(base_w)
            self.grid_h = nn.Parameter(base_h)

        pos_w = shift_w + self.grid_w  # (b*c, h, w)
        pos_h = shift_h + self.grid_h  # (b*c, h, w)

        # every channel becomes its own single-channel sample
        planes = x.contiguous().view(-1, rows, cols).unsqueeze(1)  # (b*c, 1, h, w)
        sampled = F.grid_sample(planes, torch.stack((pos_h, pos_w), 3))
        sampled = sampled.contiguous().view(-1, chans, rows, cols)  # (b, c, h, w)
        return self.regular_filter(sampled)
Example #5
Source File: utils.py    From PLARD with MIT License 6 votes vote down vote up
def interp(input, output_size, mode='bilinear'):
    """Resize ``input`` to ``output_size`` by sampling a regular grid.

    The grid spans [-1, 1] inclusive along both axes, i.e. the first and last
    output samples sit on the centres of the first and last input pixels
    (``align_corners=True``), so resizing to the input's own size returns the
    input unchanged.

    Args:
        input (Tensor): size (N, C, H, W).
        output_size (tuple[int, int]): target (height, width).
        mode (str): interpolation mode passed to ``F.grid_sample``.

    Returns:
        Tensor: size (N, C, output_size[0], output_size[1]).
    """
    n, _, _, _ = input.shape
    oh, ow = output_size

    # torch.linspace avoids the 0/0 of `arange / (size - 1)` when a target
    # dimension is 1 (it yields -1, the first row/column) and builds the
    # coordinates directly on the input's device and dtype.
    ys = torch.linspace(-1, 1, oh, dtype=input.dtype, device=input.device)
    xs = torch.linspace(-1, 1, ow, dtype=input.dtype, device=input.device)

    grid = torch.stack((xs.view(1, ow).expand(oh, ow),
                        ys.view(oh, 1).expand(oh, ow)), dim=2)  # (oh, ow, 2) as (x, y)
    grid = grid.unsqueeze(0).expand(n, oh, ow, 2)

    return F.grid_sample(input, grid, mode=mode, align_corners=True)
Example #6
Source File: pairwise.py    From airlab with Apache License 2.0 6 votes vote down vote up
def forward(self, displacement):
        """Return the negative normalised cross-correlation of the fixed and warped moving image.

        The displacement is added to ``self._grid``, the moving image is
        resampled at those positions (and stored on
        ``self._warped_moving_image``), and the correlation is evaluated only
        on the pixels selected by the mask from ``GetCurrentMask``.
        """
        # absolute sampling positions
        sampling_grid = self._grid + displacement

        # mask for the current sampling positions
        valid = super(NCC, self).GetCurrentMask(sampling_grid)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, sampling_grid)

        moving = th.masked_select(self._warped_moving_image, valid)
        fixed = th.masked_select(self._fixed_image.image, valid)

        fixed_centered = fixed - th.mean(fixed)
        moving_centered = moving - th.mean(moving)

        covariance = th.sum(fixed_centered * moving_centered)
        # 1e-10 keeps the square root away from zero for constant images
        scale = th.sqrt(th.sum(fixed_centered ** 2) * th.sum(moving_centered ** 2) + 1e-10)

        return -1. * covariance / scale
Example #7
Source File: pairwise.py    From airlab with Apache License 2.0 6 votes vote down vote up
def forward(self, displacement):
        """Return the masked squared-error loss between the fixed and warped moving image.

        The displacement is added to ``self._grid``, the moving image is
        resampled there (and stored on ``self.warped_moving_image``), and the
        squared differences at masked positions are handed to
        ``self.return_loss``.
        """
        # absolute sampling positions
        sampling_grid = self._grid + displacement

        # mask for the current sampling positions
        mask = super(MSE, self).GetCurrentMask(sampling_grid)

        # warp the moving image with the displacement field
        self.warped_moving_image = F.grid_sample(self._moving_image.image, sampling_grid)

        squared_error = (self.warped_moving_image - self._fixed_image.image).pow(2)

        # keep only the masked values
        return self.return_loss(th.masked_select(squared_error, mask))
Example #8
Source File: pointrend.py    From SegmenTron with Apache License 2.0 6 votes vote down vote up
def point_sample(input, point_coords, **kwargs):
    """
    Sample `input` at points given in [0, 1] x [0, 1] coordinates.

    Thin wrapper around :function:`torch.nn.functional.grid_sample` that also accepts a
    flat list of points per sample. Coordinates are mapped from [0, 1] to the [-1, 1]
    range used by `grid_sample` before sampling.
    Args:
        input (Tensor): A tensor of shape (N, C, H, W) holding a feature map.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) with
            point coordinates normalized to [0, 1] x [0, 1].
        **kwargs: forwarded unchanged to `grid_sample`.
    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) with the
            features sampled at `point_coords`.
    """
    if point_coords.dim() != 3:
        return F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    # (N, P, 2) -> (N, P, 1, 2) so grid_sample sees a 2D grid; drop the dummy axis afterwards
    column_of_points = point_coords.unsqueeze(2)
    sampled = F.grid_sample(input, 2.0 * column_of_points - 1.0, **kwargs)
    return sampled.squeeze(3)
Example #9
Source File: models.py    From adversarial-object-removal with MIT License 6 votes vote down vote up
def forward(self, x, extra_inp=None):
        """Upsample ``x``, filter it and optionally warp the result with learned offsets.

        Args:
            x (Tensor): input feature map.
            extra_inp (Tensor, optional): extra features concatenated to the
                upsampled map along dim 1 before ``convFilter``.

        Returns:
            tuple: ``(deformed_out, feat_out)``. With ``use_deform`` set,
            ``feat_out`` is ``(deform_grid, reg_grid, cord_offset)``;
            otherwise ``deformed_out`` is the filtered map and ``feat_out``
            is an empty list.
        """
        # First upsample the input, then filter it (with the optional extra input)
        up_out = self.upsampLayer(x)
        conv_in = up_out if extra_inp is None else torch.cat([up_out, extra_inp], dim=1)
        filt_out = self.convFilter(conv_in)

        if not self.use_deform:
            return filt_out, []

        # Compute the warp co-ordinates using a conv
        cord_offset = self.coordfilter(up_out)
        height, width = up_out.size(2), up_out.size(3)
        # Regular (x, y) grid of shape (2, H, W). x runs over the width and y
        # over the height, so non-square maps broadcast correctly; for square
        # maps this equals the previous grid. It is created on the feature
        # map's own device instead of unconditionally on the default GPU.
        reg_grid = torch.as_tensor(
            np.stack(np.meshgrid(np.linspace(-1, 1, width), np.linspace(-1, 1, height))),
            dtype=torch.float32, device=up_out.device)
        deform_grid = reg_grid + torch.tanh(cord_offset)
        # (N, 2, H, W) -> (N, H, W, 2) as required by grid_sample
        deformed_out = F.grid_sample(filt_out, deform_grid.permute(0, 2, 3, 1),
                                     mode='bilinear', padding_mode='zeros')
        return deformed_out, (deform_grid, reg_grid, cord_offset)
Example #10
Source File: point_sample.py    From mmdetection with Apache License 2.0 6 votes vote down vote up
def point_sample(input, points, align_corners=False, **kwargs):
    """Sample features at points given in [0, 1] x [0, 1] coordinates.

    Wraps :function:`torch.nn.functional.grid_sample` so that it also accepts
    a flat (N, P, 2) list of points. The points are passed through
    ``denormalize`` before sampling.

    Args:
        input (Tensor): Feature map, shape (N, C, H, W).
        points (Tensor): Normalized point coordinates in [0, 1] x [0, 1],
            shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
        align_corners (bool): Forwarded to ``grid_sample``. Default: False
        **kwargs: Forwarded unchanged to ``grid_sample``.

    Returns:
        Tensor: Features of `points` on `input`, shape (N, C, P) or
            (N, C, Hgrid, Wgrid).
    """

    flat_points = points.dim() == 3
    if flat_points:
        # (N, P, 2) -> (N, P, 1, 2) so grid_sample sees a 2D grid
        points = points.unsqueeze(2)
    sampled = F.grid_sample(
        input, denormalize(points), align_corners=align_corners, **kwargs)
    return sampled.squeeze(3) if flat_points else sampled
Example #11
Source File: imwrap.py    From DSMnet with Apache License 2.0 6 votes vote down vote up
def imwrap_BCHW0(im_src, disp):
    """Warp ``im_src`` horizontally by the disparity ``disp``.

    A regular sampling grid spanning [-1, 1] in both axes is built, its x
    coordinate is shifted by ``disp * 2 / w``, and ``im_src`` is resampled
    with ``F.grid_sample``.

    Args:
        im_src (Tensor): source image, shape (B, C, H, W).
        disp (Tensor): disparity; after ``squeeze(1)`` it must broadcast to
            (B, H, W) - presumably shape (B, 1, H, W).

    Returns:
        Tensor: warped image, shape (B, C, H, W).
    """
    bn, c, h, w = im_src.shape
    # Coordinates are created on the image's device/dtype and broadcast,
    # replacing the per-row / per-column Python fill loops.
    row = torch.linspace(-1, 1, w, dtype=im_src.dtype, device=im_src.device)
    col = torch.linspace(-1, 1, h, dtype=im_src.dtype, device=im_src.device)
    # Built out-of-place: writing into a leaf tensor that requires grad (as the
    # previous version did) raises a RuntimeError on CPU inputs. Gradients
    # still reach `disp` through this expression.
    grid_x = row.view(1, 1, w).expand(bn, h, w) - disp.squeeze(1).to(row) * 2 / w
    grid_y = col.view(1, h, 1).expand(bn, h, w)
    grid = torch.stack((grid_x, grid_y), dim=3)
    # NOTE(review): the previous version called `im_src.clamp(min=1e-6)` and
    # discarded the result, so the source was never clamped; that no-op is
    # dropped here. Confirm whether clamping was actually intended.
    return F.grid_sample(im_src, grid)
Example #12
Source File: stn.py    From LaSO with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def forward(self, x):
        """Apply the predicted affine transform to ``x``.

        ``localization`` followed by ``fc_loc`` predicts one 2x3 affine matrix
        per sample; ``x`` is then resampled on the corresponding grid and
        returned with its original size.
        """
        # predict the transform; the localization output is flattened to 32*7*7 features
        features = self.localization(x).view(-1, 32*7*7)
        theta = self.fc_loc(features).view(-1, 2, 3)

        # transform the input
        sampling_grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, sampling_grid)
Example #13
Source File: two_algo_face_rotator.py    From talking-head-anime-demo with MIT License 6 votes vote down vote up
def forward(self, image: Tensor, pose: Tensor):
        """Produce a colour-blended and a grid-resampled version of ``image`` for ``pose``.

        The pose vector is broadcast over the image plane, concatenated to the
        image and fed to ``main_body``. From its output two results are built:
        an alpha blend between the image and a predicted colour change, and the
        image resampled on an identity grid plus a predicted grid change.

        Returns:
            list: ``[color_changed, resampled, color_change, alpha_mask,
            grid_change, grid]``.
        """
        n, c, h, w = (image.size(axis) for axis in range(4))

        # tile the pose over every pixel so it can be stacked with the image channels
        pose_map = pose.unsqueeze(2).unsqueeze(3)
        pose_map = pose_map.expand(pose_map.size(0), pose_map.size(1), h, w)
        features = self.main_body(torch.cat([image, pose_map], dim=1))

        color_change = self.pumarola_color_change(features)
        alpha_mask = self.pumarola_alpha_mask(features)
        color_changed = alpha_mask * image + (1 - alpha_mask) * color_change

        # (n, 2, h, w) -> (n, h, w, 2), the layout expected by grid_sample
        flat_change = self.zhou_grid_change(features).view(n, 2, h * w)
        grid_change = torch.transpose(flat_change, 1, 2).view(n, h, w, 2)
        device = self.zhou_grid_change.weight.device
        identity = torch.Tensor([[1, 0, 0], [0, 1, 0]]).to(device).unsqueeze(0).repeat(n, 1, 1)
        base_grid = affine_grid(identity, [n, c, h, w], align_corners=self.align_corners)
        grid = base_grid + grid_change
        resampled = grid_sample(image, grid, mode='bilinear', padding_mode='border',
                                align_corners=self.align_corners)

        return [color_changed, resampled, color_change, alpha_mask, grid_change, grid]
Example #14
Source File: point_sample.py    From mmcv with Apache License 2.0 6 votes vote down vote up
def __init__(self, out_size, spatial_scale, aligned=True):
        """Set up the simple RoI align used in PointRend.

        Args:
            out_size (tuple[int]): output size as (h, w); passed through
                ``_pair`` before being stored.
            spatial_scale (float): factor by which the input boxes are scaled.
            aligned (bool): if False, the legacy MMDetection implementation is
                used (align_corners=True in F.grid_sample); if True, the
                results are aligned more precisely.
        """

        super(SimpleRoIAlign, self).__init__()
        size_pair, scale = _pair(out_size), float(spatial_scale)
        self.out_size = size_pair
        self.spatial_scale = scale
        self.aligned = aligned
        # kept False so the attribute is consistent with the other RoI ops
        self.use_torchvision = False
Example #15
Source File: offset_block.py    From openseg.pytorch with MIT License 6 votes vote down vote up
def forward(self, x, offset_map):
        """Resample ``x`` at the positions given by ``offset_map``.

        Args:
            x (Tensor): feature map, shape (N, C, H, W).
            offset_map (Tensor): per-pixel offsets, shape (N, 2, H, W);
                channel 0 is added to ``coord_map[0]`` (the h axis) and
                channel 1 to ``coord_map[1]`` (the w axis).

        Returns:
            Tensor: resampled features, shape (N, C, H, W).
        """
        n, c, h, w = x.size()

        # (re)build the cached coordinate map when the spatial size changes
        if self.coord_map is None or self.coord_map[0].size() != offset_map.size()[2:]:
            self.coord_map = self._gen_coord_map(h, w)
            # Created on the input's device instead of torch.cuda.FloatTensor,
            # so CPU and non-default GPU inputs work as well.
            self.norm_factor = torch.tensor([(w - 1) / 2, (h - 1) / 2],
                                            dtype=torch.float32, device=x.device)

        # offset to absolute coordinate
        grid_h = offset_map[:, 0] + self.coord_map[0]                               # (N, H, W)
        grid_w = offset_map[:, 1] + self.coord_map[1]                               # (N, H, W)

        # scale to [-1, 1], order of grid: [x, y] (i.e., [w, h])
        grid = torch.stack([grid_w, grid_h], dim=-1) / self.norm_factor - 1.        # (N, H, W, 2)

        # use grid to obtain output feature
        feats = F.grid_sample(x, grid, padding_mode='border')                       # (N, C, H, W)

        return feats
Example #16
Source File: segfix.py    From openseg.pytorch with MIT License 6 votes vote down vote up
def shift(x, offset):
    """
    Resample a 2D array at offset positions and return it as uint8.

    x: h x w numpy array
    offset: 2 x h x w numpy array; offset[0] is added to the first coordinate
        map (h axis), offset[1] to the second (w axis)
    returns: h x w numpy uint8 array of the bilinearly sampled, rounded values
    """
    h, w = x.shape
    image = torch.from_numpy(x).unsqueeze(0)
    offset_t = torch.from_numpy(offset).unsqueeze(0)
    coord_map = gen_coord_map(h, w)
    norm_factor = torch.FloatTensor([(w-1)/2, (h-1)/2])
    pos_h = offset_t[:, 0]+coord_map[0]
    pos_w = offset_t[:, 1]+coord_map[1]
    # grid_sample expects the last dimension ordered (x, y), i.e. (w, h)
    grid = torch.stack([pos_w, pos_h], dim=-1) / norm_factor - 1
    warped = F.grid_sample(image.unsqueeze(1).float(), grid, padding_mode='border', mode='bilinear')
    rounded = np.round(warped.squeeze().numpy())
    return rounded.astype(np.uint8)
Example #17
Source File: module_util.py    From IKC with Apache License 2.0 6 votes vote down vote up
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    The flow is added to the integer pixel grid of ``x``, the result is
    normalised to the [-1, 1] range expected by ``F.grid_sample`` and ``x`` is
    resampled at those positions.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels; channel 0 is
            the x (width) offset, channel 1 the y (height) offset
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to ``F.grid_sample``; the grid
            normalisation follows the same convention so that a zero flow is
            an identity warp either way. Default: True.

    Returns:
        Tensor: warped image or feature map, size (N, C, H, W)
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same spatial size'
    _, _, H, W = x.size()
    # Base pixel grid, built directly on x's device/dtype via broadcasting
    # (no CPU round trip, no dependence on torch.meshgrid's indexing default).
    xs = torch.arange(W, device=x.device).to(x.dtype).view(1, W).expand(H, W)
    ys = torch.arange(H, device=x.device).to(x.dtype).view(H, 1).expand(H, W)
    grid = torch.stack((xs, ys), dim=2)  # (H, W, 2), last dim ordered (x, y)
    vgrid = grid + flow
    # scale grid to [-1, 1]
    if align_corners:
        # -1 / +1 address the centres of the first / last pixel
        vgrid_x = 2.0 * vgrid[..., 0] / max(W - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[..., 1] / max(H - 1, 1) - 1.0
    else:
        # -1 / +1 address the outer edges of the first / last pixel
        vgrid_x = (2.0 * vgrid[..., 0] + 1.0) / W - 1.0
        vgrid_y = (2.0 * vgrid[..., 1] + 1.0) / H - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode,
                         padding_mode=padding_mode, align_corners=align_corners)
Example #18
Source File: superpointnet.py    From imgclsmob with MIT License 6 votes vote down vote up
def forward(self, x, pts_list):
        """Sample an L2-normalised descriptor for every point of every image.

        Args:
            x (Tensor): feature map batch; its last two dimensions are taken
                as height and width.
            pts_list (list[Tensor]): one (P, 2) tensor of point coordinates
                per batch element. The tensors are not modified.

        Returns:
            list[Tensor]: one (P, C) tensor of unit-length descriptors per
            batch element.
        """
        x_height, x_width = x.size()[-2:]

        coarse_desc_map = F.normalize(self.head(x))

        descriptors_list = []
        for i, pts in enumerate(pts_list):
            # Always work on a private float copy: `pts.float()` returns the
            # very same tensor for float input, so the in-place normalisation
            # below used to overwrite the caller's coordinates.
            pts = pts.to(dtype=torch.float32, copy=True)
            # scale coordinates to the [-1, 1] range used by grid_sample
            pts[:, 0] = pts[:, 0] / (0.5 * x_height * self.reduction) - 1.0
            pts[:, 1] = pts[:, 1] / (0.5 * x_width * self.reduction) - 1.0
            if self.transpose_descriptors:
                # swap the two coordinate columns
                pts = pts[:, [1, 0]]
            pts = pts.unsqueeze(0).unsqueeze(0)  # (1, 1, P, 2)
            descriptors = F.grid_sample(coarse_desc_map[i:(i + 1)], pts)  # (1, C, 1, P)
            descriptors = descriptors.squeeze(0).squeeze(1)  # (C, P)
            descriptors = descriptors.transpose(0, 1)  # (P, C)
            descriptors = F.normalize(descriptors)
            descriptors_list.append(descriptors)

        return descriptors_list
Example #19
Source File: model.py    From voxelmorph with GNU General Public License v3.0 6 votes vote down vote up
def forward(self, src, flow):
        """
        Push the src and flow through the spatial transform block
            :param src: the original moving image
            :param flow: the output from the U-Net, shape (N, ndim, *spatial),
                added to ``self.grid`` (voxel coordinates)
            :return: ``src`` resampled at ``self.grid + flow``
        """
        new_locs = self.grid + flow

        shape = flow.shape[2:]

        # Need to normalize grid values to [-1, 1] for resampler.
        # max(size - 1, 1) keeps a size-1 axis from dividing by zero.
        scaled = [2 * (new_locs[:, i, ...] / max(size - 1, 1) - 0.5)
                  for i, size in enumerate(shape)]

        # grid_sample expects channels last, ordered from the last spatial
        # axis to the first, i.e. (x, y) in 2D and (x, y, z) in 3D.
        new_locs = torch.stack(scaled[::-1], dim=-1)

        # The normalisation above maps voxel 0 to -1 and voxel size-1 to +1,
        # which is the align_corners=True convention; passing it explicitly
        # makes a zero flow an exact identity transform.
        return nnf.grid_sample(src, new_locs, mode=self.mode, align_corners=True)
Example #20
Source File: flow_utils.py    From swiftnet with GNU General Public License v3.0 6 votes vote down vote up
def offset_flow(img, flow):
    '''
    Resample img on a regular grid displaced by flow.

    :param img: torch.FloatTensor of shape NxCxHxW
    :param flow: torch.FloatTensor of shape NxHxWx2; channel 0 is divided by W
        and channel 1 by H before being added to the grid
    :return: tuple (warped image of shape NxCxHxW, sampling grid of shape NxHxWx2)
    '''
    N, C, H, W = img.shape
    # generate identity sampling grid on the image's device
    rows = torch.arange(H, device=img.device).float()
    cols = torch.arange(W, device=img.device).float()
    # NOTE(review): dividing by (max - 1), i.e. H - 2 / W - 2, is kept from the
    # original; it makes the grid overshoot +1 and divides by zero for a
    # size-2 axis. Confirm whether (max) was intended.
    gx = rows.div(rows.max() - 1).view(1, H, 1, 1).expand(1, H, W, 1)
    gy = cols.div(cols.max() - 1).view(1, 1, W, 1).expand(1, H, W, 1)
    grid = torch.cat([gy, gx], dim=-1).mul(2.).sub(1)
    # generate normalized flow field
    flown = flow.clone()
    flown[..., 0] /= W
    flown[..., 1] /= H
    # calculate offset field; out-of-place so the (1, H, W, 2) base grid can
    # broadcast against batches with N > 1 (in-place `+=` cannot grow it)
    grid = grid + flown
    return F.grid_sample(img, grid), grid
Example #21
Source File: graph.py    From PPGNet with MIT License 6 votes vote down vote up
def forward(self, feat, coord_st, coord_ed):
        """Sample features along the segment between every start/end point pair.

        Args:
            feat (Tensor): feature maps, shape (B, C, H, W).
            coord_st (Tensor): (num_st, 3) start points; column 0 is the batch
                index (identical for all rows), columns 1: are coordinates
                that get multiplied by ``self.scale``.
            coord_ed (Tensor): (num_ed, 3) end points, same layout and same
                batch index as ``coord_st``.

        Returns:
            Tensor: shape (num_st, num_ed, C, align_size); ``align_size``
            samples evenly spaced from each start point to each end point.
        """
        _, ch, h, w = feat.size()
        num_st = coord_st.size(0)
        num_ed = coord_ed.size(0)
        assert coord_st.size(1) == 3 and coord_ed.size(1) == 3
        assert (coord_st[:, 0] == coord_st[0, 0]).all() and (coord_ed[:, 0] == coord_st[0, 0]).all()
        bs = coord_st[0, 0].item()
        steps = self.align_size
        # the sampling positions are constants w.r.t. autograd
        with torch.no_grad():
            starts = (coord_st[:, 1:] * self.scale).unsqueeze(1).expand(num_st, num_ed, 2)
            ends = (coord_ed[:, 1:] * self.scale).unsqueeze(0).expand(num_st, num_ed, 2)
            direction = ends - starts
            # interpolation fractions 0..1 along each segment
            fractions = torch.linspace(0, 1, steps=steps).to(feat).view(1, 1, steps).expand(num_st, num_ed, steps)
            offsets = torch.einsum("ijd,ijs->ijsd", (direction, fractions))
            sample_grid = offsets + starts.view(num_st, num_ed, 1, 2).expand(num_st, num_ed, steps, 2)
            sample_grid = sample_grid.view(num_st, num_ed, steps, 2)
            # scale to the [-1, 1] range used by grid_sample
            sample_grid[..., 0] = sample_grid[..., 0] / (w - 1) * 2 - 1
            sample_grid[..., 1] = sample_grid[..., 1] / (h - 1) * 2 - 1

        # one copy of the selected feature map per start point
        source = feat[int(bs)].view(1, ch, h, w).expand(num_st, ch, h, w)
        output = F.grid_sample(source, sample_grid)
        assert output.size() == (num_st, ch, num_ed, steps)
        return output.permute(0, 2, 1, 3).contiguous()
Example #22
Source File: dense_pdd_net_v01.py    From pdd_net with Apache License 2.0 6 votes vote down vote up
def augmentAffine(img_in, seg_in, strength=0.05):
    """
    Apply one random 3D affine transform per sample to an image/segmentation mini-batch.

    The transform is the identity plus Gaussian noise scaled by `strength`. The image is
    sampled with the default interpolation and border padding, the segmentation with
    nearest-neighbour interpolation.
    :input: img_in BxCxDxHxW image batch, seg_in BxDxHxW segmentation batch
    :return: augmented BxCxDxHxW image batch, augmented BxDxHxW segmentation batch (long)
    """
    B, C, D, H, W = img_in.size()
    identity = torch.eye(3, 4).unsqueeze(0)
    jitter = torch.randn(B, 3, 4) * strength
    affine_matrix = (identity + jitter).to(img_in.device)

    # a single sampling grid is shared by image and segmentation
    sampling_grid = F.affine_grid(affine_matrix, torch.Size((B, 1, D, H, W)))

    img_out = F.grid_sample(img_in, sampling_grid, padding_mode='border')

    # labels go through grid_sample as a float channel and are cast back afterwards
    seg_channel = seg_in.float().unsqueeze(1)
    seg_out = F.grid_sample(seg_channel, sampling_grid, mode='nearest').long().squeeze(1)

    return img_out, seg_out
Example #23
Source File: model.py    From photometric-mesh-optim with MIT License 6 votes vote down vote up
def compare_valid_index(self,opt,index_s,index_V,points):
		"""Mark pixels whose four surrounding integer positions all carry the expected index.

		For every entry of `points` (last dim ordered x, y) the index map `index_s`
		is looked up at the floor/ceil combinations of its coordinates (clamped to
		the opt.H x opt.W image); a pixel is valid only if all four lookups equal
		`index_V`.
		"""
		batch_size = len(index_V)
		num_pixels = opt.H*opt.W
		index_vec = index_s.reshape(batch_size,num_pixels)
		x_corners = [points[...,0].floor(),points[...,0].ceil()]
		y_corners = [points[...,1].floor(),points[...,1].ceil()]
		# get index map from 4 integer corners
		index_synth_list = []
		for Y in y_corners:
			row_start = Y.long().clamp(min=0,max=opt.H-1)*opt.W
			for X in x_corners:
				flat_index = row_start+X.long().clamp(min=0,max=opt.W-1)
				flat_index_vec = flat_index.reshape(batch_size,num_pixels)
				gathered = torch.gather(index_vec,1,flat_index_vec)
				index_synth_list.append(gathered.reshape(batch_size,opt.H,opt.W))
		# consider only points where projected coordinates have consistent triangle indices
		valid_index = index_synth_list[0]==index_V
		for index_synth in index_synth_list[1:4]:
			valid_index = valid_index&(index_synth==index_V)
		return valid_index
Example #24
Source File: module_util.py    From real-world-sr with MIT License 6 votes vote down vote up
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    The flow is added to the integer pixel grid of ``x``, the result is
    normalised to the [-1, 1] range expected by ``F.grid_sample`` and ``x`` is
    resampled at those positions.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), displacement in pixels; channel 0 is
            the x (width) offset, channel 1 the y (height) offset
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'
        align_corners (bool): forwarded to ``F.grid_sample``; the grid
            normalisation follows the same convention so that a zero flow is
            an identity warp either way. Default: True.

    Returns:
        Tensor: warped image or feature map, size (N, C, H, W)
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same spatial size'
    _, _, H, W = x.size()
    # Base pixel grid, built directly on x's device/dtype via broadcasting
    # (no CPU round trip, no dependence on torch.meshgrid's indexing default).
    xs = torch.arange(W, device=x.device).to(x.dtype).view(1, W).expand(H, W)
    ys = torch.arange(H, device=x.device).to(x.dtype).view(H, 1).expand(H, W)
    grid = torch.stack((xs, ys), dim=2)  # (H, W, 2), last dim ordered (x, y)
    vgrid = grid + flow
    # scale grid to [-1, 1]
    if align_corners:
        # -1 / +1 address the centres of the first / last pixel
        vgrid_x = 2.0 * vgrid[..., 0] / max(W - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[..., 1] / max(H - 1, 1) - 1.0
    else:
        # -1 / +1 address the outer edges of the first / last pixel
        vgrid_x = (2.0 * vgrid[..., 0] + 1.0) / W - 1.0
        vgrid_y = (2.0 * vgrid[..., 1] + 1.0) / H - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode,
                         padding_mode=padding_mode, align_corners=align_corners)
Example #25
Source File: point_sample.py    From mmcv with Apache License 2.0 6 votes vote down vote up
def point_sample(input, points, align_corners=False, **kwargs):
    """Sample features at points given in ``[0, 1] x [0, 1]`` coordinates.

    A wrapper around :func:`torch.nn.functional.grid_sample` that also accepts
    a flat ``(N, P, 2)`` list of points; the points go through
    ``denormalize`` before sampling.

    Args:
        input (Tensor): Feature map, shape (N, C, H, W).
        points (Tensor): Normalized point coordinates in [0, 1] x [0, 1],
            shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
        align_corners (bool): Forwarded to ``grid_sample``. Default: False
        **kwargs: Forwarded unchanged to ``grid_sample``.

    Returns:
        Tensor: Features of `points` on `input`, shape (N, C, P) or
            (N, C, Hgrid, Wgrid).
    """

    if points.dim() != 3:
        return F.grid_sample(
            input, denormalize(points), align_corners=align_corners, **kwargs)
    # (N, P, 2) -> (N, P, 1, 2) for grid_sample; the dummy axis is removed again below
    expanded = points.unsqueeze(2)
    sampled = F.grid_sample(
        input, denormalize(expanded), align_corners=align_corners, **kwargs)
    return sampled.squeeze(3)
Example #26
Source File: utils.py    From airlab with Apache License 2.0 5 votes vote down vote up
def diffeomorphic_2D(displacement, grid, scaling=-1):
        """Compose a 2D displacement field with itself ``scaling`` times.

        The field is first divided by ``2 ** scaling``; each iteration then
        adds the field sampled at ``grid + field`` to itself. A negative
        ``scaling`` is replaced by ``Diffeomorphic._compute_scaling_value``.

        :param displacement: tensor with the two displacement components in
            the last dimension, i.e. (H, W, 2)
        :param grid: sampling grid added to the field, broadcastable to (1, H, W, 2)
        :param scaling: number of composition steps
        :return: the composed field in the input layout, with size-1 dimensions squeezed
        """
        if scaling < 0:
            scaling = Diffeomorphic._compute_scaling_value(displacement)

        # (H, W, 2) -> (1, 2, H, W): the layout grid_sample expects for its input
        field = (displacement / (2 ** scaling)).permute(2, 0, 1).unsqueeze(0)

        for _ in range(scaling):
            sampling_positions = field.permute(0, 2, 3, 1) + grid
            field = field + F.grid_sample(field, sampling_positions)

        return field.permute(0, 2, 3, 1).squeeze()
Example #27
Source File: utils.py    From airlab with Apache License 2.0 5 votes vote down vote up
def warp_image(image, displacement):
    """Return a new image: ``image`` resampled at grid + displacement.

    The grid comes from ``compute_grid`` for the image's size, dtype and
    device; size, spacing and origin of the input are carried over to the
    result.
    """
    size = image.size

    base_grid = compute_grid(size, dtype=image.dtype, device=image.device)

    # warp image
    warped = F.grid_sample(image.image, displacement + base_grid)

    return iutils.Image(warped, size, image.spacing, image.origin)
Example #28
Source File: net_utils.py    From Distilling-Object-Detectors with MIT License 5 votes vote down vote up
def compare_grid_sample():
    """Interactive check of F.grid_sample against RoICropFunction.

    Builds random CUDA inputs, computes the input gradient through
    F.grid_sample and through RoICropFunction, then stops in pdb. The final
    difference ``delta`` is computed but neither returned nor printed.
    Requires a CUDA device.
    """
    # do gradcheck
    # only the batch size is random; the other sizes are fixed
    N = random.randint(1, 8)
    C = 2 # random.randint(1, 8)
    H = 5 # random.randint(1, 8)
    W = 4 # random.randint(1, 8)
    input = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
    input_p = input.clone().data.contiguous()

    grid = Variable(torch.randn(N, H, W, 2).cuda(), requires_grad=True)
    grid_clone = grid.clone().contiguous()

    # reference path: official grid_sample forward plus autograd backward
    out_offcial = F.grid_sample(input, grid)
    grad_outputs = Variable(torch.rand(out_offcial.size()).cuda())
    grad_outputs_clone = grad_outputs.clone().contiguous()
    grad_inputs = torch.autograd.grad(out_offcial, (input, grid), grad_outputs.contiguous())
    grad_input_off = grad_inputs[0]


    # path under test: RoICropFunction is given the grid with its last
    # dimension swapped to (index 1, index 0)
    crf = RoICropFunction()
    grid_yx = torch.stack([grid_clone.data[:,:,:,1], grid_clone.data[:,:,:,0]], 3).contiguous().cuda()
    out_stn = crf.forward(input_p, grid_yx)
    grad_inputs = crf.backward(grad_outputs_clone.data)
    grad_input_stn = grad_inputs[0]
    # interactive breakpoint: execution stops here for manual inspection
    pdb.set_trace()

    # summed difference of the two input gradients (not returned)
    delta = (grad_input_off.data - grad_input_stn).sum()
Example #29
Source File: utils.py    From airlab with Apache License 2.0 5 votes vote down vote up
def diffeomorphic_3D(displacement, grid, scaling=-1):
        """Compose a 3D displacement field with itself ``scaling`` times.

        The field is first divided by ``2 ** scaling``; each iteration then
        adds the field sampled at ``grid + field`` to itself.

        :param displacement: tensor with the three displacement components in
            the last dimension, i.e. (D, H, W, 3)
        :param grid: sampling grid added to the field, broadcastable to (1, D, H, W, 3)
        :param scaling: number of composition steps; a negative value (the
            default) is replaced by ``Diffeomorphic._compute_scaling_value``,
            exactly as in ``diffeomorphic_2D``
        :return: the composed field in the input layout, with size-1 dimensions squeezed
        """
        # Without this, the default scaling=-1 multiplied the field by two
        # (2 ** -1 == 0.5) and skipped the loop entirely (range(-1) is empty).
        if scaling < 0:
            scaling = Diffeomorphic._compute_scaling_value(displacement)

        displacement = displacement / (2 ** scaling)

        # (D, H, W, 3) -> (1, 3, D, H, W): the layout grid_sample expects for its input
        displacement = displacement.permute(3, 0, 1, 2).unsqueeze(0)

        for _ in range(scaling):
            displacement_trans = displacement.permute(0, 2, 3, 4, 1)
            displacement = displacement + F.grid_sample(displacement, displacement_trans + grid)

        return displacement.permute(0, 2, 3, 4, 1).squeeze()
Example #30
Source File: net_utils.py    From fpn.pytorch with MIT License 5 votes vote down vote up
def compare_grid_sample():
    """Interactive check of F.grid_sample against RoICropFunction.

    Builds random CUDA inputs, computes the input gradient through
    F.grid_sample and through RoICropFunction, then stops in pdb. The final
    difference ``delta`` is computed but neither returned nor printed.
    Requires a CUDA device.
    """
    # do gradcheck
    # only the batch size is random; the other sizes are fixed
    N = random.randint(1, 8)
    C = 2 # random.randint(1, 8)
    H = 5 # random.randint(1, 8)
    W = 4 # random.randint(1, 8)
    input = Variable(torch.randn(N, C, H, W).cuda(), requires_grad=True)
    input_p = input.clone().data.contiguous()

    grid = Variable(torch.randn(N, H, W, 2).cuda(), requires_grad=True)
    grid_clone = grid.clone().contiguous()

    # reference path: official grid_sample forward plus autograd backward
    out_offcial = F.grid_sample(input, grid)
    grad_outputs = Variable(torch.rand(out_offcial.size()).cuda())
    grad_outputs_clone = grad_outputs.clone().contiguous()
    grad_inputs = torch.autograd.grad(out_offcial, (input, grid), grad_outputs.contiguous())
    grad_input_off = grad_inputs[0]


    # path under test: RoICropFunction is given the grid with its last
    # dimension swapped to (index 1, index 0)
    crf = RoICropFunction()
    grid_yx = torch.stack([grid_clone.data[:,:,:,1], grid_clone.data[:,:,:,0]], 3).contiguous().cuda()
    out_stn = crf.forward(input_p, grid_yx)
    grad_inputs = crf.backward(grad_outputs_clone.data)
    grad_input_stn = grad_inputs[0]
    # interactive breakpoint: execution stops here for manual inspection
    pdb.set_trace()

    # summed difference of the two input gradients (not returned)
    delta = (grad_input_off.data - grad_input_stn).sum()