import torch
from torch import nn
from torch.autograd import Variable
from torch.autograd import grad

from src import utils


class DiscriminatorLoss(nn.Module):
    """Adversarial loss evaluated on the outputs of a discriminator.

    The criterion is selected by ``opt.dis_adv_loss_type``:

    * ``'gan'``   -> ``nn.BCEWithLogitsLoss``
    * ``'lsgan'`` -> ``nn.MSELoss``

    One pair of constant target tensors (ones for "real", zeros for "fake")
    is pre-allocated on the GPU ``opt.gpu_ids[0]`` for every entry of
    ``opt.dis_output_sizes``, each of shape
    ``(opt.batch_size, 1, size, size)``.
    """

    def __init__(self, opt):
        """Build the criterion and the cached target tensors.

        Args:
            opt: options object; the attributes read here are ``gpu_ids``,
                ``dis_adv_loss_type``, ``dis_output_sizes`` and
                ``batch_size``.

        Raises:
            ValueError: if ``opt.dis_adv_loss_type`` is neither ``'gan'``
                nor ``'lsgan'``.
        """
        super(DiscriminatorLoss, self).__init__()

        self.gpu_id = opt.gpu_ids[0]

        # Adversarial criterion applied to the predictions
        if opt.dis_adv_loss_type == 'gan':
            self.crit = nn.BCEWithLogitsLoss()
        elif opt.dis_adv_loss_type == 'lsgan':
            self.crit = nn.MSELoss()
        else:
            # Fail fast: without this, ``self.crit`` would be left unset and
            # the problem would only show up as an AttributeError on the
            # first call.
            raise ValueError(
                'Unsupported dis_adv_loss_type: {!r} '
                "(expected 'gan' or 'lsgan')".format(opt.dis_adv_loss_type))

        # Constant targets for the criterion, one pair per discriminator
        # output. Plain tensors are used; wrapping them in ``Variable`` is
        # no longer needed.
        self.labels_real = []
        self.labels_fake = []

        for size in opt.dis_output_sizes:
            shape = (opt.batch_size, 1, size, size)

            self.labels_real.append(torch.ones(shape).cuda(self.gpu_id))
            self.labels_fake.append(torch.zeros(shape).cuda(self.gpu_id))

    @staticmethod
    def _target_like(cached, pred, value):
        """Return a target tensor whose shape matches ``pred``.

        The pre-allocated ``cached`` tensor is reused when its shape already
        matches. Otherwise (e.g. a batch whose size differs from
        ``opt.batch_size``) a new tensor filled with ``value`` is created
        with the shape, dtype and device of ``pred``.
        """
        if cached.shape == pred.shape:
            return cached

        return torch.full_like(pred, value)

    def __call__(self, dis, img_real_dst, img_fake_dst=None,
                 aux_real_dst=None, img_real_src=None, enc=None):
        """Compute the adversarial loss summed over all discriminator outputs.

        Args:
            dis: callable invoked as ``dis(img, img_real_src, enc)``; it must
                return an indexable sequence of prediction tensors.
            img_real_dst: first batch passed to ``dis``; its predictions are
                always scored against the "real" (ones) targets. Per the
                original author's note this is real data during the
                discriminator update, or generated data during the
                generator update.
            img_fake_dst: optional second batch. When given, its predictions
                are scored against the "fake" (zeros) targets and each
                per-output loss becomes the mean of the two terms.
            aux_real_dst: accepted for interface compatibility; not used.
            img_real_src: forwarded unchanged to ``dis``.
            enc: forwarded unchanged to ``dis``.

        Returns:
            tuple: ``(loss, losses_adv)`` where ``loss`` is the sum of the
            per-output losses (the int ``0`` if ``dis`` returns no outputs)
            and ``losses_adv`` is a one-element list holding that same total
            as a Python number, intended for logging.
        """
        outputs_real = dis(img_real_dst, img_real_src, enc)

        has_fake = img_fake_dst is not None
        if has_fake:
            outputs_fake = dis(img_fake_dst, img_real_src, enc)

        loss = 0
        values = []  # scalar value of the loss for each discriminator output

        for i, pred_real in enumerate(outputs_real):
            target_real = self._target_like(
                self.labels_real[i], pred_real, 1.0)
            loss_i = self.crit(pred_real, target_real)

            if has_fake:
                pred_fake = outputs_fake[i]
                target_fake = self._target_like(
                    self.labels_fake[i], pred_fake, 0.0)
                # Average the real and fake terms
                loss_i = (loss_i + self.crit(pred_fake, target_fake)) * 0.5

            loss = loss + loss_i
            values.append(loss_i.item())

        return loss, [sum(values)]