# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn.functional as F


def GDPPLoss(phiFake, phiReal, backward=True):
    r"""
    Implementation of the GDPP loss. Can be used with any kind of GAN
    architecture.

    Args:
        phiFake (tensor) : last feature layer of the discriminator on fake data
        phiReal (tensor) : last feature layer of the discriminator on real data
        backward (bool) : should we perform the backward operation ?

    Returns:
        Loss's value. The backward operation is performed within this operator
    """
    def compute_diversity(phi):
        # Row-normalize the features, then build the similarity (kernel)
        # matrix and eigendecompose it.
        phi = F.normalize(phi, p=2, dim=1)
        SB = torch.mm(phi, phi.t())
        # torch.symeig was deprecated and then removed from PyTorch;
        # torch.linalg.eigh is the replacement and, like symeig, returns
        # eigenvalues in ascending order with eigenvectors as columns.
        eigVals, eigVecs = torch.linalg.eigh(SB)
        return eigVals, eigVecs

    def normalize_min_max(eigVals):
        minV, maxV = torch.min(eigVals), torch.max(eigVals)
        # Guard against a zero denominator when all eigenvalues coincide.
        if abs(minV - maxV) < 1e-10:
            return eigVals
        return (eigVals - minV) / (maxV - minV)

    fakeEigVals, fakeEigVecs = compute_diversity(phiFake)
    realEigVals, realEigVecs = compute_diversity(phiReal)

    # Scaling factor to make the two losses operating in comparable ranges.
    magnitudeLoss = 0.0001 * F.mse_loss(target=realEigVals, input=fakeEigVals)
    # Columnwise inner products of the two eigenbases, negated: aligned
    # eigenvectors lower the loss.
    structureLoss = -torch.sum(torch.mul(fakeEigVecs, realEigVecs), 0)
    normalizedRealEigVals = normalize_min_max(realEigVals)
    weightedStructureLoss = torch.sum(
        torch.mul(normalizedRealEigVals, structureLoss))
    gdppLoss = magnitudeLoss + weightedStructureLoss

    if backward:
        # retain_graph so the caller can still backpropagate the main GAN
        # loss through the same graph afterwards.
        gdppLoss.backward(retain_graph=True)

    return gdppLoss.item()