import torch
import torch.nn as nn

from config import cfg
from utils.network_utils import *

#
# Disparity Loss
#


def EPE(output, target, occ_mask):
    """Return the mask-weighted mean absolute error between two maps.

    Args:
        output: Predicted tensor.
        target: Reference tensor, combined element-wise with ``output``.
        occ_mask: Mask multiplied into the absolute error; its sum is used
            as the normaliser.

    Returns:
        Scalar tensor ``sum(|output - target| * occ_mask) / sum(occ_mask)``.
        If the mask sums to zero the divisor is replaced by one, so the
        result is the numerator (zero for finite errors) instead of the
        NaN that 0/0 would produce.
    """
    valid = torch.sum(occ_mask)
    masked_error = torch.abs(output - target) * occ_mask
    # Swap a zero divisor for one without branching on the tensor value;
    # every non-zero mask sum is used unchanged.
    safe_valid = torch.where(valid != 0, valid, torch.ones_like(valid))
    return torch.sum(masked_error) / safe_valid


def multiscaleLoss(outputs, target, img, occ_mask, weights):
    """Return the weighted sum of per-scale EPE losses.

    Args:
        outputs: A single prediction tensor, or a tuple/list of 4-D
            predictions (one per scale).
        target: Reference tensor; pooled to each prediction's spatial size.
        img: Unused; kept so existing call sites keep working.
        occ_mask: Mask passed to ``EPE``; max-pooled to each prediction's
            spatial size.
        weights: One weight per entry in ``outputs``.

    Returns:
        ``sum(weight_i * EPE(output_i, pooled_target, pooled_mask))``.

    Raises:
        AssertionError: If ``weights`` and ``outputs`` differ in length.
    """

    def one_scale(output, target, occ_mask):
        # Unpacking four values also enforces that the prediction is 4-D.
        _, _, h, w = output.size()
        occ_mask = nn.functional.adaptive_max_pool2d(occ_mask, (h, w))

        # Sparse targets are max-pooled, dense targets are average-pooled,
        # as selected by cfg.DATASET.SPARSE.
        if cfg.DATASET.SPARSE:
            target_scaled = nn.functional.adaptive_max_pool2d(target, (h, w))
        else:
            target_scaled = nn.functional.adaptive_avg_pool2d(target, (h, w))
        return EPE(output, target_scaled, occ_mask)

    if not isinstance(outputs, (tuple, list)):
        outputs = [outputs]

    # AssertionError is kept (rather than a different exception type) so
    # callers observe the same failure as before.
    assert len(weights) == len(outputs), (
        "multiscaleLoss expects one weight per output "
        "(got %d weights for %d outputs)" % (len(weights), len(outputs))
    )

    loss = 0
    for output, weight in zip(outputs, weights):
        loss += weight * one_scale(output, target, occ_mask)
    return loss


def realEPE(output, target, occ_mask):
    """Return the EPE of ``output`` after resizing it to ``target``'s size.

    Args:
        output: Prediction tensor accepted by bilinear ``interpolate``.
        target: 4-D reference tensor whose last two dimensions give the
            size the prediction is resized to.
        occ_mask: Mask passed to ``EPE``.

    Returns:
        Scalar tensor from ``EPE`` on the resized prediction.
    """
    _, _, h, w = target.size()
    upsampled_output = nn.functional.interpolate(
        output, size=(h, w), mode='bilinear', align_corners=True)
    return EPE(upsampled_output, target, occ_mask)


#
# Deblurring Loss
#


def mseLoss(output, target):
    """Return the mean squared error between ``output`` and ``target``."""
    # 'mean' is the current name of the deprecated 'elementwise_mean'.
    mse_loss = nn.MSELoss(reduction='mean')
    return mse_loss(output, target)


def PSNR(output, target, max_val=1.0):
    """Return the peak signal-to-noise ratio in dB.

    Args:
        output: Predicted tensor; clamped to ``[0, max_val]`` first.
        target: Reference tensor.
        max_val: Peak signal value (default 1.0).

    Returns:
        0-dim tensor ``10 * log10(max_val**2 / mse)``. When the MSE is
        exactly zero, a 0-dim tensor holding 100.0 with the same dtype and
        device as the MSE is returned instead.
    """
    # Clamp to the declared peak rather than a hard-coded 1.0 so a
    # non-default max_val is not silently squashed into [0, 1].
    output = output.clamp(0.0, max_val)
    mse = torch.pow(target - output, 2).mean()
    if mse == 0:
        # Same shape/dtype/device as the regular return value.
        return torch.full_like(mse, 100.0)
    return 10 * torch.log10(max_val ** 2 / mse)


def perceptualLoss(fakeIm, realIm, vggnet):
    """Return the weighted MSE between feature maps of two images.

    Per the original note: use vgg19 conv1_2, conv2_2, conv3_3 features,
    before the relu layer.

    Args:
        fakeIm: Generated image batch fed to ``vggnet``.
        realIm: Reference image batch fed to ``vggnet``.
        vggnet: Callable returning an indexable sequence of feature maps.

    Returns:
        Sum over feature levels of ``weights[i] * MSE(fake_i, real_i)`` with
        weights ``[1, 0.2, 0.04]``. The reference features are detached, so
        no gradient flows through them.

    Raises:
        IndexError: If ``vggnet`` returns more feature maps than weights.
    """
    weights = [1, 0.2, 0.04]
    features_fake = vggnet(fakeIm)
    features_real = vggnet(realIm)
    features_real_no_grad = [f_real.detach() for f_real in features_real]
    # 'mean' is the current name of the deprecated 'elementwise_mean'.
    mse_loss = nn.MSELoss(reduction='mean')

    loss = 0
    for i, (f_fake, f_real) in enumerate(
            zip(features_fake, features_real_no_grad)):
        loss = loss + mse_loss(f_fake, f_real) * weights[i]
    return loss