import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def get_gan_criterion(mode):
    """Build the adversarial criterion for a given GAN formulation.

    Args:
        mode: ``'dcgan'`` (BCE on logits), ``'lsgan'`` (least squares) or
            ``'hinge'``.

    Returns:
        A :class:`GANLoss` configured for ``mode``.

    Raises:
        NotImplementedError: if ``mode`` is not one of the names above.
    """
    if mode == 'dcgan':
        criterion = GANLoss(dis_loss=nn.BCEWithLogitsLoss(), gen_loss=nn.BCEWithLogitsLoss())
    elif mode == 'lsgan':
        criterion = GANLoss(dis_loss=nn.MSELoss(), gen_loss=nn.MSELoss())
    elif mode == 'hinge':
        def hinge_dis(pre, margin):
            '''margin should not be 0'''
            # The target's sign selects the side of the hinge: +1 (real label)
            # gives relu(margin - pre), -1 (fake label) gives relu(-margin + pre)
            # == relu(|margin| + pre). A zero margin would zero the loss.
            logict = torch.sign(margin)
            return torch.mean(F.relu((margin - pre) * logict))

        def hinge_gen(pre, margin):
            # Generator hinge loss does not use the target: it pushes the
            # discriminator score of generated samples up.
            return -torch.mean(pre)

        criterion = GANLoss(real_label=1, fake_label=-1, dis_loss=hinge_dis, gen_loss=hinge_gen)
    else:
        raise NotImplementedError('{} is not implementation'.format(mode))
    return criterion


def get_rec_loss(mode):
    """Return the reconstruction criterion named by ``mode``.

    Args:
        mode: ``'l1'`` for :class:`torch.nn.L1Loss` or ``'mse'`` for
            :class:`torch.nn.MSELoss`.

    Raises:
        NotImplementedError: if ``mode`` is not one of the names above.
    """
    if mode == 'l1':
        criterion = nn.L1Loss()
    elif mode == 'mse':
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError('{} is not implementation'.format(mode))
    return criterion


def get_kl_loss():
    """Return ``kl_loss(mu, logvar)``, the closed-form Gaussian KL term.

    The returned function computes
    ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))`` summed over every
    element of the inputs (no averaging over the batch).
    """
    def kl_loss(mu, logvar):
        kld_element = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * torch.sum(kld_element)
    return kl_loss


class GANLoss(nn.Module):
    """Adversarial loss shared by the discriminator and generator updates.

    Calling the object with at least one fake prediction computes the
    discriminator loss with ``dis_loss``. Calling it with only ``real``
    computes the generator loss with ``gen_loss``. Target tensors filled with
    ``real_label`` / ``fake_label`` are cached and rebuilt whenever the layout
    of the predictions (list length, shapes, devices) changes.

    Args:
        real_label: value written into the "real" target tensor(s).
        fake_label: value written into the "fake" target tensor(s).
        dis_loss: callable ``(prediction, target) -> loss`` for the discriminator.
        gen_loss: callable ``(prediction, target) -> loss`` for the generator.
    """

    def __init__(self, real_label=1., fake_label=0., dis_loss=None, gen_loss=None):
        super(GANLoss, self).__init__()
        self.real_label = real_label
        self.fake_label = fake_label
        self.dis_loss = dis_loss
        self.gen_loss = gen_loss
        # Lazily built target caches; a single tensor or a list of tensors
        # mirroring the structure of the last prediction passed in.
        self.real_label_tensor = None
        self.fake_label_tensor = None

    @staticmethod
    def _filled_like(pred, value):
        """Return a float32 tensor shaped like ``pred`` on its device, filled with ``value``."""
        # float32 regardless of pred's dtype, matching the torch.FloatTensor
        # labels this class has always produced. Allocated directly on the
        # target device instead of CPU + copy.
        return torch.full(pred.size(), value, dtype=torch.float32, device=pred.device)

    def _labels_stale(self, input):
        """Return True when the cached targets do not match ``input``'s layout."""
        cached = self.real_label_tensor
        if cached is None:
            return True
        if isinstance(input, list):
            # A changed number of outputs (e.g. multi-scale discriminators)
            # must invalidate the cache; otherwise zip() would silently drop
            # the extra predictions.
            if not isinstance(cached, list) or len(cached) != len(input):
                return True
            return any(pre.shape != tar.shape or pre.device != tar.device
                       for pre, tar in zip(input, cached))
        if isinstance(cached, list):
            return True
        return input.shape != cached.shape or input.device != cached.device

    def get_target_tensor(self, input):
        """Return ``(real_targets, fake_targets)`` matching the structure of ``input``.

        Args:
            input: a prediction tensor or a list of prediction tensors.

        Returns:
            Two tensors (or two lists of tensors) shaped like ``input`` and
            filled with ``real_label`` and ``fake_label`` respectively. The
            same objects are returned on later calls while the layout of
            ``input`` is unchanged.
        """
        if self._labels_stale(input):
            if isinstance(input, list):
                self.real_label_tensor = [self._filled_like(pre, self.real_label) for pre in input]
                self.fake_label_tensor = [self._filled_like(pre, self.fake_label) for pre in input]
            else:
                self.real_label_tensor = self._filled_like(input, self.real_label)
                self.fake_label_tensor = self._filled_like(input, self.fake_label)
        return self.real_label_tensor, self.fake_label_tensor

    def __call__(self, real=None, fake1=None, fake2=None,
                 weight_real=1, weight_fake1=1, weight_fake2=1):
        """Compute the discriminator or generator adversarial loss.

        Args:
            real: prediction(s) for real samples, or, in generator mode,
                the prediction(s) the generator wants judged as real. A tensor
                or a list of tensors (one per discriminator output).
            fake1: optional prediction(s) for a first set of fake samples.
            fake2: optional prediction(s) for a second set of fake samples.
            weight_real: scale applied to ``real`` and its targets.
            weight_fake1: scale applied to ``fake1`` and its targets.
            weight_fake2: scale applied to ``fake2`` and its targets.

        Returns:
            The loss summed over every prediction in the list(s).
            - Both fakes given: each fake term is halved.
            - Exactly one fake given: the discriminator loss on ``real`` plus
              that fake term, scaled by its own weight.
            - No fake given: the generator loss on ``real``.
        """
        err = 0.0
        if not isinstance(real, list):
            real = [real]
            fake1 = [fake1] if fake1 is not None else fake1
            fake2 = [fake2] if fake2 is not None else fake2
        real_label_tensor, fake_label_tensor = self.get_target_tensor(real)
        if fake1 is not None and fake2 is not None:
            for r, f1, f2, r_label, f_label in zip(real, fake1, fake2,
                                                    real_label_tensor, fake_label_tensor):
                err += self.dis_loss(r * weight_real, r_label * weight_real) + \
                    self.dis_loss(f1 * weight_fake1, f_label * weight_fake1) * 0.5 + \
                    self.dis_loss(f2 * weight_fake2, f_label * weight_fake2) * 0.5
        elif fake1 is not None or fake2 is not None:
            # Use the weight that belongs to whichever fake was supplied;
            # previously a lone fake2 was always scaled by weight_fake1.
            if fake1 is not None:
                fake, weight_fake = fake1, weight_fake1
            else:
                fake, weight_fake = fake2, weight_fake2
            for r, f, r_label, f_label in zip(real, fake, real_label_tensor, fake_label_tensor):
                err += self.dis_loss(r * weight_real, r_label * weight_real) + \
                    self.dis_loss(f * weight_fake, f_label * weight_fake)
        else:
            for r, r_label in zip(real, real_label_tensor):
                err += self.gen_loss(r * weight_real, r_label * weight_real)
        return err