#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""Disparity-based image warping helpers built on ``F.grid_sample``."""
import torch
# Not used by the functions below; kept in case other code imports it from here.
from torch.autograd import Variable
import torch.nn.functional as F


def _grid_sample(im_src, grid):
    """Sample ``im_src`` at ``grid`` with corner-aligned coordinates.

    The grids built in this module map -1/+1 to the centres of the first/last
    pixel (they are normalised with ``size - 1``), which is what
    ``align_corners=True`` means.  Older PyTorch releases do not accept the
    keyword; for them the plain call is used instead.
    """
    try:
        return F.grid_sample(im_src, grid, align_corners=True)
    except TypeError:
        # grid_sample() without the align_corners keyword.
        return F.grid_sample(im_src, grid)


def _coordinate_planes(bn, h, w, x0, x1, y0, y1, like):
    """Return the x and y planes, each (bn, h, w), of a regular grid.

    x runs linearly from x0 to x1 along the width, y from y0 to y1 along the
    height.  Both planes are converted to the tensor type of ``like``.
    """
    xs = torch.linspace(x0, x1, w).type_as(like).view(1, 1, w).expand(bn, h, w)
    ys = torch.linspace(y0, y1, h).type_as(like).view(1, h, 1).expand(bn, h, w)
    return xs, ys


def imswrap(im_src, disps, scale_disps, fliplr=False, LeftTop=(0, 0)):
    '''
    Warp an average-pooled pyramid of im_src with one disparity map per entry.

    im_src is a (bn, c, h0, w0) image batch
    disps is a list of (bn, 1, h, w) disparity maps
    scale_disps[i] is the pyramid level of disps[i]; level k is im_src
        halved k times with avg_pool2d(kernel=3, stride=2, padding=1)
    fliplr is forwarded to imwrap_BCHW
    LeftTop is the (x, y) window origin in full-resolution pixels; it is
        divided by 2**level for each entry and is never modified

    Returns a list with one warped image per entry of disps (empty input
    gives an empty list).
    '''
    # NOTE(review): this check passes when only one of the two is a list; it
    # is kept this lenient so existing callers keep working.
    assert isinstance(disps, list) or isinstance(scale_disps, list)
    assert len(disps) == len(scale_disps)
    if len(scale_disps) == 0:
        return []

    # Build the source pyramid up to the coarsest level that is requested.
    ims_src = [im_src]
    for _ in range(max(scale_disps)):
        ims_src.append(F.avg_pool2d(ims_src[-1], 3, 2, 1))

    x0, y0 = LeftTop
    ims_wrap = []
    for disp, level in zip(disps, scale_disps):
        scale_factor = 2.0**level
        # Always scale the caller's origin, not the value left over from the
        # previous entry, and do it on a fresh tuple so the argument (or the
        # shared default) is left untouched.
        left_top = (x0/scale_factor, y0/scale_factor)
        ims_wrap.append(imwrap_BCHW(ims_src[level], disp, fliplr, left_top, 1))
    return ims_wrap


def imwrap_pyramid(im_src, disps_pyramid, fliplr=False, LeftTop=(0, 0)):
    '''
    Warp the full-resolution im_src once per level of a disparity pyramid.

    disps_pyramid[i] is sampled with a pixel step of 2**i in im_src, so the
    i-th result has the spatial size of disps_pyramid[i].
    fliplr and LeftTop are forwarded unchanged to imwrap_BCHW.

    Returns the list of warped images, one per level.
    '''
    assert type(disps_pyramid) is list
    return [imwrap_BCHW(im_src, disp, fliplr, LeftTop, 2**level)
            for level, disp in enumerate(disps_pyramid)]


def imwrap_BCHW(im_src, disp, fliplr=False, LeftTop=(0, 0), scale_factor=1):
    '''
    Sample im_src on a regular window shifted horizontally by disp.

    the shape of im_src should be (bn, c , h0, w0)
    the shape of disp should be (bn, 1 , h, w)
    fliplr is the flag of flip im_src horizontally
    LeftTop is the imwrap's left top position (x, y) in im_src_fliplr
    scale_factor is rate of scale between imwrap and im_src

    Returns the sampled image of shape (bn, c, h, w).  Samples that fall
    outside im_src follow grid_sample's default padding.
    '''
    _, _, h0, w0 = im_src.shape
    bn, c, h, w = disp.shape
    assert c == 1 and min(h, w, h0, w0) > 1

    # ------------------------compute area(x, x1, y, y1)----------------------
    # Normalise with (size - 1) because -1/+1 are the centres of the border
    # pixels, not the image edges.
    x, y = LeftTop
    x = x*2.0/(w0 - 1) - 1
    y = y*2.0/(h0 - 1) - 1
    x1 = x + (w - 1)*scale_factor*2.0/(w0 - 1)
    y1 = y + (h - 1)*scale_factor*2.0/(h0 - 1)

    # ---------------------------create sample grid---------------------------
    grid_x, grid_y = _coordinate_planes(bn, h, w, x, x1, y, y1, im_src)
    # disp is in pixels of im_src; convert it to normalised units.
    shift = disp.squeeze(1).type_as(im_src)*2.0/(w0 - 1)
    k = -1.0 if fliplr else 1
    # Negating x mirrors the sampling positions about the image centre.
    grid_x = k*(grid_x - shift)
    # Stack instead of writing into a preallocated grid so that no in-place
    # operation takes part in the autograd graph.
    grid = torch.stack((grid_x, grid_y), dim=3)

    # ---------------------------sample image by grid-------------------------
    # A small random positive offset (between 1e-5 and 1.1e-4) is added to the
    # image before sampling -- presumably to keep sampled values away from
    # exact zero; TODO confirm the intent.
    delt = 1e-4*(torch.rand(1)[0] + 0.1)
    return _grid_sample(im_src + delt, grid)


def imwrap_BCHW0(im_src, disp):
    '''
    Full-frame variant of imwrap_BCHW without offset, scale or flip.

    the shape of im_src should be (bn, c, h, w)
    the shape of disp should be (bn, 1, h, w)

    Returns im_src sampled at x - disp for every pixel.  im_src is sampled
    as-is: no clamping is applied to it.
    '''
    bn, c, h, w = im_src.shape
    grid_x, grid_y = _coordinate_planes(bn, h, w, -1, 1, -1, 1, im_src)
    # NOTE(review): the disparity is scaled by 2/w here while the grid spans
    # pixel centres (imwrap_BCHW uses 2/(w0 - 1)); left unchanged -- confirm
    # which convention callers of this variant expect.
    grid_x = grid_x - disp.squeeze(1).type_as(im_src)*2/w
    grid = torch.stack((grid_x, grid_y), dim=3)
    return _grid_sample(im_src, grid)