import sys
sys.path.append("../")
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from SSIM import SSIM
from utils.imwrap import imwrap_BCHW
from utils.utils import imsplot_tensor

import logging
#logging.basicConfig(level=logging.DEBUG, format=' %(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(level=logging.INFO, format=' %(asctime)s - %(levelname)s - %(message)s')

flag_test = False
flag_imshow = False


def create_impyramid(im, levels):
    # Image pyramid: level i is the input downsampled by 2**i (via striding).
    impyramid = [im]
    for i in range(1, levels):
        impyramid.append(impyramid[-1][:, :, ::2, ::2])
    return impyramid


class loss_stereo(torch.nn.Module):

    def __init__(self):
        super(loss_stereo, self).__init__()
        self.w_ap = 1.0    # weight of the appearance (photometric) term
        self.w_ds = 0.001  # weight of the disparity-smoothness term
        self.w_lr = 0.001  # weight of the left-right consistency term
        self.w_m = 0.0001  # weight of the mean-disparity term (SsSMnet)
        self.ssim = SSIM()

    def wfun(self, similarity):
        # Map the mean SSIM similarity to a weight for the ds/lr terms.
        return max(0, similarity - 0.75)/2 + 0.001

    def diff1_dx(self, img):
        # First-order difference along x (width), zero-padded to keep the shape.
        assert len(img.shape) == 4
        diff1 = img[:, :, :, 1:] - img[:, :, :, :-1]
        return F.pad(diff1, [0, 1, 0, 0])

    def diff1_dy(self, img):
        # First-order difference along y (height), zero-padded to keep the shape.
        assert len(img.shape) == 4
        diff1 = img[:, :, 1:] - img[:, :, :-1]
        return F.pad(diff1, [0, 0, 0, 1])

    def diff2_dx(self, img):
        # Second-order difference along x, zero-padded to keep the shape.
        assert len(img.shape) == 4
        diff2 = img[:, :, :, 2:] + img[:, :, :, :-2] - 2*img[:, :, :, 1:-1]
        return F.pad(diff2, [1, 1, 0, 0])

    def diff2_dy(self, img):
        # Second-order difference along y, zero-padded to keep the shape.
        assert len(img.shape) == 4
        diff2 = img[:, :, 2:] + img[:, :, :-2] - 2*img[:, :, 1:-1]
        return F.pad(diff2, [0, 0, 1, 1])

    def diff_z_dx(self, disp):
        # Ratio-based second-order difference along x: d/d(+1) + d/d(-1) - 2.
        assert len(disp.shape) == 4
        diff_p = (disp[:, :, :, 1:-1]/disp[:, :, :, 2:]) + (disp[:, :, :, 1:-1]/disp[:, :, :, :-2]) - 2
        return F.pad(diff_p, [1, 1, 0, 0])

    def diff_z_dy(self, disp):
        # Ratio-based second-order difference along y: d/d(+1) + d/d(-1) - 2.
        assert len(disp.shape) == 4
        diff_p = (disp[:, :, 1:-1]/disp[:, :, 2:]) + (disp[:, :, 1:-1]/disp[:, :, :-2]) - 2
        return F.pad(diff_p, [0, 0, 1, 1])

    def C_imdiff1(self, img, img_wrap):
        # L1 distance between the first-order gradients of img and img_wrap.
        L1_dx = torch.abs(self.diff1_dx(img) - self.diff1_dx(img_wrap))
        L1_dy = torch.abs(self.diff1_dy(img) - self.diff1_dy(img_wrap))
        return L1_dx + L1_dy

    def C_ds1(self, img, disp):
        # Edge-aware first-order disparity smoothness.
        disp_dx = torch.abs(self.diff1_dx(disp))
        disp_dy = torch.abs(self.diff1_dy(disp))
        image_dx = torch.abs(self.diff1_dx(img))
        image_dy = torch.abs(self.diff1_dy(img))
        weights_x = torch.exp(-torch.sum(image_dx, dim=1, keepdim=True))
        weights_y = torch.exp(-torch.sum(image_dy, dim=1, keepdim=True))
        smoothness_x = disp_dx * weights_x
        smoothness_y = disp_dy * weights_y
        return smoothness_x + smoothness_y

    def C_ds2(self, img, disp):
        # Edge-aware second-order disparity smoothness.
        disp_dx = torch.abs(self.diff2_dx(disp))
        disp_dy = torch.abs(self.diff2_dy(disp))
        image_dx = torch.abs(self.diff2_dx(img))
        image_dy = torch.abs(self.diff2_dy(img))
        weights_x = torch.exp(-torch.sum(image_dx, dim=1, keepdim=True))
        weights_y = torch.exp(-torch.sum(image_dy, dim=1, keepdim=True))
        smoothness_x = disp_dx * weights_x
        smoothness_y = disp_dy * weights_y
        return smoothness_x + smoothness_y

    def C_ds3(self, img, disp):
        # Edge-aware ratio-based (scale-invariant) disparity smoothness.
        disp = torch.abs(disp) + 1
        disp_dx = torch.abs(self.diff_z_dx(disp)).clamp(0, 10)
        disp_dy = torch.abs(self.diff_z_dy(disp)).clamp(0, 10)
        image_dx = torch.abs(self.diff1_dx(img))
        image_dy = torch.abs(self.diff1_dy(img))
        mImage_dx = image_dx.mean(-1, True).mean(-2, True).mean(-3, True)
        mImage_dy = image_dy.mean(-1, True).mean(-2, True).mean(-3, True)
        weights_x = torch.exp(-torch.max(image_dx, dim=1, keepdim=True)[0]/(0.5*mImage_dx))
        weights_y = torch.exp(-torch.max(image_dy, dim=1, keepdim=True)[0]/(0.5*mImage_dy))
        smoothness_x = disp_dx * weights_x
        smoothness_y = disp_dy * weights_y
        return smoothness_x + smoothness_y
    def C_ds3t1(self, img, disp):
        # Variant of C_ds3 that uses first-order disparity differences.
        disp_dx = torch.abs(self.diff1_dx(disp))
        disp_dy = torch.abs(self.diff1_dy(disp))
        image_dx = torch.abs(self.diff1_dx(img))
        image_dy = torch.abs(self.diff1_dy(img))
        mImage_dx = image_dx.mean(-1, True).mean(-2, True).mean(-3, True)
        mImage_dy = image_dy.mean(-1, True).mean(-2, True).mean(-3, True)
        weights_x = torch.exp(-torch.max(image_dx, dim=1, keepdim=True)[0]/(0.5*mImage_dx))
        weights_y = torch.exp(-torch.max(image_dy, dim=1, keepdim=True)[0]/(0.5*mImage_dy))
        smoothness_x = disp_dx * weights_x
        smoothness_y = disp_dy * weights_y
        return smoothness_x + smoothness_y

    def C_ds3t(self, img, disp):
        # Same computation as C_ds3, kept as a separate variant.
        disp = torch.abs(disp) + 1
        disp_dx = torch.abs(self.diff_z_dx(disp)).clamp(0, 10)
        disp_dy = torch.abs(self.diff_z_dy(disp)).clamp(0, 10)
        image_dx = torch.abs(self.diff1_dx(img))
        image_dy = torch.abs(self.diff1_dy(img))
        mImage_dx = image_dx.mean(-1, True).mean(-2, True).mean(-3, True)
        mImage_dy = image_dy.mean(-1, True).mean(-2, True).mean(-3, True)
        weights_x = torch.exp(-torch.max(image_dx, dim=1, keepdim=True)[0]/(0.5*mImage_dx))
        weights_y = torch.exp(-torch.max(image_dy, dim=1, keepdim=True)[0]/(0.5*mImage_dy))
        smoothness_x = disp_dx * weights_x
        smoothness_y = disp_dy * weights_y
        return smoothness_x + smoothness_y


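# The per-pair loss functions below combine, with the weights defined in
# loss_stereo.__init__:
#   C_ap : appearance (photometric) term, 0.85*0.5*(1 - SSIM) + 0.15*L1
#   C_ds : edge-aware disparity smoothness (C_ds1/C_ds2/C_ds3, depending on the loss)
#   C_lr : left-right consistency between a quantity and its warped counterpart
#   C_m  : mean absolute disparity regularizer (SsSMnet only)
# The optional per-pixel weight ("weight_common", built in losses.weight_common)
# down-weights pixels that fail the left-right disparity check.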
class lossfun(loss_stereo):

    def __init__(self, loss_name):
        super(lossfun, self).__init__()
        self.loss_name = loss_name

    def loss_common(self, im, im_wrap, disp, disp_wrap, factor=1.0, weight_common=None):
        # ----------------set w_ds and w_lr---------------------
        mask_ap = (im_wrap[:, :1] != 0).detach()
        if(len(mask_ap[mask_ap]) < 1024):
            mask_ap[:] = 1
        img_ssim = self.ssim(im, im_wrap)
        similarity = img_ssim[mask_ap].mean().data[0]
        w = self.wfun(similarity)
        self.w_ds = w
        self.w_lr = w
        # ----------------set C_ap and C_lr---------------------
        C_ap = (0.85*0.5)*(1 - img_ssim) + 0.15*(torch.abs(im - im_wrap))  # + self.C_imdiff1(im, im_wrap))
        C_lr = torch.abs(disp - disp_wrap)
        # ---------------------set mask------------------------
        if(weight_common is not None):
            mask_im = ((disp_wrap==0) + mask_ap).detach() > 1
            mask_lr = (disp_wrap==0).detach()
            weight_im = weight_common.clone()
            weight_im[mask_im] = 1.0
            weight_lr = weight_common.clone()
            weight_lr[mask_lr] = 0
            C_ap = C_ap * weight_im
            C_lr = C_lr * weight_lr  # mask_lr.float()
            msg = "weight_im maxV: %f, minV: %f ;" % (weight_im.max().data[0], weight_im.min().data[0])
            msg += "weight_lr maxV: %f, minV: %f " % (weight_lr.max().data[0], weight_lr.min().data[0])
            logging.debug(msg)
        # ----------------------C_all----------------------------
        C_ap = C_ap.mean()
        C_ds = self.C_ds3(im, disp).mean()
        C_lr = C_lr.mean()
        C = C_ap*self.w_ap + C_ds*(self.w_ds) + C_lr*self.w_lr
        # show in screen
        if(flag_test):
            print(self.w_ap, self.w_ds, self.w_lr, similarity)
            print(C.data[0], C_ap.data[0]*self.w_ap, C_ds.data[0]*self.w_ds, C_lr.data[0]*self.w_lr)
        return C

    def loss_depthmono(self, im, im_wrap, disp, disp_wrap, factor=1.0, weight_common=None):
        # ----------------set w_ds and w_lr---------------------
        img_ssim = self.ssim(im, im_wrap)
        mask_ap = (im_wrap[:, :1] != 0).detach()
        if(len(mask_ap[mask_ap]) < 1024):
            mask_ap[:] = 1
        similarity = img_ssim[mask_ap].mean().data[0]
        w = self.wfun(similarity)
        self.w_ds = w
        self.w_lr = w
        # ----------------set C_ap and C_lr---------------------
        C_ap = (0.85*0.5)*(1 - img_ssim) + 0.15*torch.abs(im - im_wrap)
        C_lr = torch.abs(disp - disp_wrap)
        # ---------------------set mask------------------------
        if(weight_common is not None):
            mask_im = ((disp_wrap==0) + mask_ap).detach() > 1
            mask_lr = (disp_wrap==0)
            weight_im = weight_common.clone()
            weight_im[mask_im] = 1.0
            weight_lr = weight_common.clone()
            weight_lr[mask_lr] = 0
            C_ap = C_ap * weight_im
            C_lr = C_lr * weight_lr
            msg = "weight_im maxV: %f, minV: %f ;" % (weight_im.max().data[0], weight_im.min().data[0])
            msg += "weight_lr maxV: %f, minV: %f " % (weight_lr.max().data[0], weight_lr.min().data[0])
            logging.debug(msg)
        # ----------------------C_all----------------------------
        C_ap = C_ap.mean()
        C_ds = self.C_ds1(im, disp).mean()
        C_lr = C_lr.mean()
        C = C_ap*self.w_ap + C_ds*(self.w_ds) + C_lr*self.w_lr
        # show in screen
        if(flag_test):
            print(self.w_ap, self.w_ds, self.w_lr, similarity)
            print(C.data[0], C_ap.data[0]*self.w_ap, C_ds.data[0]*self.w_ds, C_lr.data[0]*self.w_lr)
        return C

    def loss_Cap_ds_lr(self, im, im_wrap, disp, disp_wrap, factor=1.0, weight_common=None):
        # ----------------set w_ds and w_lr---------------------
        img_ssim = self.ssim(im, im_wrap)
        mask_ap = (im_wrap[:, :1] != 0).detach()
        similarity = img_ssim[mask_ap].mean().data[0]
        w = self.wfun(similarity)
        self.w_ds = w
        self.w_lr = w
        # ----------------set C_ap and C_lr---------------------
        C_ap = (0.85*0.5)*(1 - img_ssim) + 0.15*torch.abs(im - im_wrap)
        C_lr = torch.abs(disp - disp_wrap)
        # ---------------------set mask------------------------
        if(weight_common is not None):
            mask_im = ((disp_wrap==0) + mask_ap).detach() > 1
            mask_lr = (disp_wrap==0)
            weight_im = weight_common.clone()
            weight_im[mask_im] = 1.0
            weight_lr = weight_common.clone()
            weight_lr[mask_lr] = 0
            C_ap = C_ap * weight_im
            C_lr = C_lr * weight_lr
            msg = "weight_im maxV: %f, minV: %f ;" % (weight_im.max().data[0], weight_im.min().data[0])
            msg += "weight_lr maxV: %f, minV: %f " % (weight_lr.max().data[0], weight_lr.min().data[0])
            logging.debug(msg)
        # ----------------------C_ap----------------------------
        C_ap = C_ap.mean()
        C = C_ap * self.w_ap
        # ----------------------C_ds----------------------------
        if("ds" in self.loss_name):
            C_ds = self.C_ds1(im, disp).mean()
            C += C_ds * (self.w_ds/factor)
        # ----------------------C_lr----------------------------
        if("lr" in self.loss_name):
            C_lr = C_lr.mean()
            C += C_lr * self.w_lr
        # show in screen (note: C_ds/C_lr are only defined if enabled above)
        if(flag_test):
            print(self.w_ap, self.w_ds, self.w_lr, similarity)
            print(C.data[0], C_ap.data[0]*self.w_ap, C_ds.data[0]*self.w_ds, C_lr.data[0]*self.w_lr)
        return C

    def loss_SsSMnet(self, im, im_wrap, im_wrap1, disp, factor=1.0, weight_common=None):
        # ----------------set w_ds and w_lr---------------------
        img_ssim = self.ssim(im, im_wrap)
        mask_ap = (im_wrap[:, :1] != 0).detach()
        similarity = img_ssim[mask_ap].mean().data[0]
        w = self.wfun(similarity)
        self.w_ds = w
        self.w_lr = w
        # ----------------set C_ap and C_lr---------------------
        C_ap = (0.85*0.5)*(1 - img_ssim) + 0.15*(torch.abs(im - im_wrap) + self.C_imdiff1(im, im_wrap))
        C_lr = torch.abs(im - im_wrap1)
        # ---------------------set mask------------------------
        if(weight_common is not None):
            mask_im = ((im_wrap1[:, :1] == 0) + mask_ap).detach() > 1
            mask_lr = (im_wrap1[:, :1] == 0)
            weight_im = weight_common.clone()
            weight_im[mask_im] = 1.0
            weight_lr = weight_common.clone()
            weight_lr[mask_lr] = 0
            C_ap = C_ap * weight_im
            C_lr = C_lr * weight_lr
            msg = "weight_im maxV: %f, minV: %f ;" % (weight_im.max().data[0], weight_im.min().data[0])
            msg += "weight_lr maxV: %f, minV: %f " % (weight_lr.max().data[0], weight_lr.min().data[0])
            logging.debug(msg)
        # ----------------------C_all----------------------------
        C_ap = C_ap.mean()
        C_ds = self.C_ds2(im, disp).mean()
        C_lr = C_lr.mean()
        C_mdh = torch.abs(disp).mean()
        C = C_ap*self.w_ap + C_ds*(self.w_ds/factor) + C_lr*self.w_lr + C_mdh*self.w_m
        # show in screen
        if(flag_test):
            print(self.w_ap, self.w_ds, self.w_lr, similarity)
            print(C.data[0], C_ap.data[0]*self.w_ap, C_ds.data[0]*self.w_ds, C_lr.data[0]*self.w_lr)
        return C

    def loss_supervised(self, disp_gt, disp, flag_smooth=False, factor=1.0):
        mask = disp_gt > 0
        if(len(mask[mask]) == 0):
            return 0
        loss = torch.abs(disp_gt - disp)[mask].mean()
        #loss = loss + 0.001*torch.abs(disp[disp_gt==0]).mean()
        if(flag_smooth):
            disp_dx = self.diff1_dx(disp)
            disp_dy = self.diff1_dy(disp)
            disp_dxdy = (torch.abs(disp_dx) + torch.abs(disp_dy))/factor
            C_smooth = disp_dxdy[mask].clamp(0, 1).mean()
            loss = loss + 0.1*C_smooth
        return loss


class losses(lossfun):

    def __init__(self, loss_name="supervised", count_levels=1, maxepoch_weight_adjust=1):
        # loss_name parse
        self.flag_mask = ("mask" in loss_name)
        loss_name = loss_name.split("-")[0].lower()
        self.loss_names = ["supervised", "depthmono", "SsSMnet".lower(), "Cap_ds_lr".lower(), "common"]
        assert loss_name in self.loss_names or "Cap".lower() in loss_name
        # set lossfun and lossesfun
        super(losses, self).__init__(loss_name)
        self.lossfun = None
        self.lossesfun = None
        self.setlossfun(loss_name)
        # weight_levels
        self.maxepoch_weight_adjust = maxepoch_weight_adjust
        self.count_levels = count_levels
        self.weight_levels = [0]*count_levels
        self.weight_levels[-1] = 1

    def setlossfun(self, loss_name):
        if(self.loss_names[0] in loss_name):
            self.lossfun = self.loss_supervised
            self.lossesfun = self.losses_pyramid0
        elif(self.loss_names[1] in loss_name):
            self.lossfun = self.loss_depthmono
            self.lossesfun = self.losses_pyramid1
        elif(self.loss_names[2] in loss_name):
            self.lossfun = self.loss_SsSMnet
            self.lossesfun = self.losses_pyramid2
        elif("Cap".lower() in loss_name):
            self.lossfun = self.loss_Cap_ds_lr
            self.lossesfun = self.losses_pyramid1
        elif(self.loss_names[4] in loss_name):
            self.lossfun = self.loss_common
            self.lossesfun = self.losses_pyramid1

    def Weight_Adjust_levels(self, epoch):
        # Shift the per-level weights from the coarsest level toward level 0
        # as epoch approaches maxepoch_weight_adjust.
        count_level = self.count_levels
        maxepoch = self.maxepoch_weight_adjust
        self.weight_levels = [0.01]*count_level
        if(count_level == 1 or epoch >= maxepoch):
            self.weight_levels[0] = 1
            return
        x = (1 - epoch/float(maxepoch))*(count_level - 1)
        idx = int(x)
        w = x - idx
        self.weight_levels[idx] = 1 - w
        if(idx < count_level-1):
            self.weight_levels[idx+1] = w

    def weight_common(self, disp, disp_wrap, factor=1.0):
        # Per-pixel confidence from the left-right disparity difference:
        # 1.0 below 1 px, linear decay to 0.01 between 1 and 3 px, 0.01 above.
        disp_delt = torch.abs(disp - disp_wrap).detach()/factor
        weight = Variable(torch.zeros(disp_delt.shape), requires_grad=False).type_as(disp_delt)
        mask1 = disp_delt < 1
        mask2 = (disp_delt < 3) - mask1
        mask3 = disp_delt >= 3
        weight[mask1] = 1.0
        weight[mask2] = 1.0 - (disp_delt[mask2] - 1)*(0.99/2)
        weight[mask3] = 0.01
        msg = "weight maxV: %f, minV: %f" % (weight.max().data[0], weight.min().data[0])
        logging.debug(msg)
        return weight

    # losses for loss_supervised
    def losses_pyramid0(self, disp_gt, disps, scale_disps, flag_smooth=False):
        count = len(scale_disps)
        _, _, h, w = disp_gt.shape
        loss = 0
        for i in range(0, count):
            level = scale_disps[i]
            weight = self.weight_levels[level]
            if(weight <= 0):
                continue
            if(level > 0):
                pred = F.upsample(disps[i], scale_factor=2**level, mode='bilinear')[:, :, :h, :w]
            else:
                pred = disps[i]
            loss = loss + self.lossfun(disp_gt, pred, flag_smooth, factor=1)*weight
        return loss

    # losses for depthmono/common/Cap_ds_lr
    def losses_pyramid1(self, imR_src, imL, dispLs, scale_dispLs, LeftTop,
                        imR1_src, imL1, dispL1s, scale_dispL1s, LeftTop1):
        count = len(scale_dispLs)  # count of outputs
        maxlevel = min(2, max(scale_dispLs))
        # record the spatial size of the prediction at maxlevel
        for i in range(0, count):
            if(scale_dispLs[i] == maxlevel):
                _, _, h, w = dispLs[i].shape
        imLs = create_impyramid(imL, maxlevel + 1)
        imL1s = create_impyramid(imL1, maxlevel + 1)
        # compute loss
        loss = 0
        for i in range(0, count):
            level = scale_dispLs[i]
            weight = self.weight_levels[level]
            if(weight <= 0):
                continue
            if(level > maxlevel):
                scale_factor = 2**maxlevel
                dispL = F.upsample(dispLs[i], scale_factor=2**(level - maxlevel), mode='bilinear')[:, :, :h, :w]
                dispL1 = F.upsample(dispL1s[i], scale_factor=2**(level - maxlevel), mode='bilinear')[:, :, :h, :w]
            else:
                scale_factor = 2**level
                dispL = dispLs[i]
                dispL1 = dispL1s[i]
            weight_common = None
            weight_common1 = None
            imL_wrap = imwrap_BCHW(imR_src, dispL, fliplr=False, LeftTop=LeftTop, scale_factor=scale_factor)
            imL1_wrap = imwrap_BCHW(imR1_src, dispL1, fliplr=False, LeftTop=LeftTop1, scale_factor=scale_factor)
            dispL_wrap = imwrap_BCHW(dispL1, dispL, fliplr=True, LeftTop=[0, 0], scale_factor=1)
            dispL1_wrap = imwrap_BCHW(dispL, dispL1, fliplr=True, LeftTop=[0, 0], scale_factor=1)
            if(self.flag_mask):
                weight_common = self.weight_common(dispL, dispL_wrap, factor=scale_factor)
                weight_common1 = self.weight_common(dispL1, dispL1_wrap, factor=scale_factor)
            tmp = self.lossfun(imLs[min(level, maxlevel)], imL_wrap, dispL, dispL_wrap, factor=(2**level), weight_common=weight_common)
            tmp1 = self.lossfun(imL1s[min(level, maxlevel)], imL1_wrap, dispL1, dispL1_wrap, factor=(2**level), weight_common=weight_common1)
            loss = loss + (tmp + tmp1)*weight  # / (4**level)
            # imshow
            if(flag_imshow and (i==count-1)):
                imsplot_tensor(imL, imL1, imL_wrap, imL1_wrap, dispLs[0], dispL1s[0], dispL_wrap, dispL1_wrap)
                import matplotlib.pyplot as plt
                plt.savefig("tmp_check.png")
        return loss

    # losses for SsSMnet
    def losses_pyramid2(self, imR_src, imL, dispLs, scale_dispLs, LeftTop,
                        imR1_src, imL1, dispL1s, scale_dispL1s, LeftTop1):
        count = len(scale_dispLs)  # count of outputs
        maxlevel = min(2, max(scale_dispLs))
        # record the spatial size of the prediction at maxlevel
        for i in range(0, count):
            if(scale_dispLs[i] == maxlevel):
                _, _, h, w = dispLs[i].shape
        imLs = create_impyramid(imL, maxlevel + 1)
        imL1s = create_impyramid(imL1, maxlevel + 1)
        # compute loss
        loss = 0
        for i in range(0, count):
            level = scale_dispLs[i]
            weight = self.weight_levels[level]
            if(weight == 0):
                continue
            if(level > maxlevel):
                scale_factor = 2**maxlevel
                dispL = F.upsample(dispLs[i], scale_factor=2**(level - maxlevel), mode='bilinear')[:, :, :h, :w]
                dispL1 = F.upsample(dispL1s[i], scale_factor=2**(level - maxlevel), mode='bilinear')[:, :, :h, :w]
            else:
                scale_factor = 2**level
                dispL = dispLs[i]
                dispL1 = dispL1s[i]
            weight_common = None
            weight_common1 = None
            imL_wrap = imwrap_BCHW(imR_src, dispL, fliplr=False, LeftTop=LeftTop, scale_factor=scale_factor)
            imL1_wrap = imwrap_BCHW(imR1_src, dispL1, fliplr=False, LeftTop=LeftTop1, scale_factor=scale_factor)
            imL_wrap1 = imwrap_BCHW(imL1_wrap, dispL, fliplr=True, LeftTop=[0, 0], scale_factor=1)
            imL1_wrap1 = imwrap_BCHW(imL_wrap, dispL1, fliplr=True, LeftTop=[0, 0], scale_factor=1)
            if(self.flag_mask):
                dispL_wrap = imwrap_BCHW(dispL1, dispL, fliplr=True, LeftTop=[0, 0], scale_factor=1)
                dispL1_wrap = imwrap_BCHW(dispL, dispL1, fliplr=True, LeftTop=[0, 0], scale_factor=1)
                weight_common = self.weight_common(dispL, dispL_wrap, factor=scale_factor)
                weight_common1 = self.weight_common(dispL1, dispL1_wrap, factor=scale_factor)
            tmp = self.lossfun(imLs[min(level, maxlevel)], imL_wrap, imL_wrap1, dispL, factor=(2**level), weight_common=weight_common)
            tmp1 = self.lossfun(imL1s[min(level, maxlevel)], imL1_wrap, imL1_wrap1, dispL1, factor=(2**level), weight_common=weight_common1)
            loss = loss + (tmp + tmp1)*weight  # / (4**level)
            # imshow
            if(flag_imshow and (i==count-1)):
                imsplot_tensor(imL, imL1, imL_wrap, imL1_wrap, imL_wrap1, imL1_wrap1, dispLs[0], dispL1s[0])
                import matplotlib.pyplot as plt
                plt.savefig("tmp_check.png")
        return loss

    def forward(self, args):
        return self.lossesfun(**args)
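

if __name__ == "__main__":
    # Minimal smoke test for the supervised path. This is an illustrative
    # sketch only: it assumes the local imports (SSIM, utils.imwrap,
    # utils.utils) resolve, and the random tensors below stand in for real
    # network outputs (their names are placeholders, not project API).
    crit = losses(loss_name="supervised", count_levels=1)
    disp_gt = torch.rand(1, 1, 64, 128) * 32    # fake ground-truth disparity
    disp_pred = torch.rand(1, 1, 64, 128) * 32  # fake predicted disparity
    loss = crit({"disp_gt": disp_gt, "disps": [disp_pred], "scale_disps": [0]})
    print(loss)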