import torch
import torch.nn as nn
import torch.nn.functional as F


##
# version 1: use torch.autograd
class FocalLossV1(nn.Module):

    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.crit = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, label):
        '''
        args:
            logits: tensor of shape (N, ...)
            label: tensor of shape (N, ...)
        '''
        # compute loss
        logits = logits.float()  # use fp32 if logits is fp16
        with torch.no_grad():
            # per-element class weight: alpha for positives, 1 - alpha for negatives
            alpha = torch.empty_like(logits).fill_(1 - self.alpha)
            alpha[label == 1] = self.alpha

        probs = torch.sigmoid(logits)
        pt = torch.where(label == 1, probs, 1 - probs)
        # targets must have the same dtype as the logits for BCEWithLogitsLoss
        ce_loss = self.crit(logits, label.float())
        loss = (alpha * torch.pow(1 - pt, self.gamma) * ce_loss)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
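

# A minimal sanity-check sketch (not part of the original file; the helper name is
# hypothetical): with gamma=0 the modulating factor (1 - pt)^gamma is 1, and with
# alpha=0.5 both classes get the same weight, so FocalLossV1 should reduce to
# 0.5 * BCEWithLogitsLoss. Call it manually if desired.
def _check_v1_reduces_to_bce():
    logits = torch.randn(4, 8)
    label = torch.randint(0, 2, (4, 8)).float()
    focal = FocalLossV1(alpha=0.5, gamma=0, reduction='mean')(logits, label)
    bce = nn.BCEWithLogitsLoss(reduction='mean')(logits, label)
    assert torch.allclose(focal, 0.5 * bce, atol=1e-6), (focal.item(), bce.item())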


##
# version 2: user-derived gradient computation
class FocalSigmoidLossFuncV2(torch.autograd.Function):
    '''
    compute backward directly for better numeric stability
    '''
    @staticmethod
    def forward(ctx, logits, label, alpha, gamma):
        logits = logits.float()
        coeff = torch.empty_like(logits).fill_(1 - alpha)
        coeff[label == 1] = alpha

        probs = torch.sigmoid(logits)
        # numerically stable log(p) and log(1 - p) via the identity
        # log(sigmoid(x)) = -softplus(-x)
        log_probs = torch.where(logits >= 0,
                                F.softplus(logits, -1, 50),
                                logits - F.softplus(logits, 1, 50))
        log_1_probs = torch.where(logits >= 0,
                                  -logits + F.softplus(logits, -1, 50),
                                  -F.softplus(logits, 1, 50))
        probs_gamma = probs ** gamma
        probs_1_gamma = (1. - probs) ** gamma

        ctx.coeff = coeff
        ctx.probs = probs
        ctx.log_probs = log_probs
        ctx.log_1_probs = log_1_probs
        ctx.probs_gamma = probs_gamma
        ctx.probs_1_gamma = probs_1_gamma
        ctx.label = label
        ctx.gamma = gamma

        term1 = probs_1_gamma * log_probs
        term2 = probs_gamma * log_1_probs
        loss = torch.where(label == 1, term1, term2).mul_(coeff).neg_()
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss
        '''
        coeff = ctx.coeff
        probs = ctx.probs
        log_probs = ctx.log_probs
        log_1_probs = ctx.log_1_probs
        probs_gamma = ctx.probs_gamma
        probs_1_gamma = ctx.probs_1_gamma
        label = ctx.label
        gamma = ctx.gamma

        # d(loss)/d(logits) for positive and negative labels respectively
        term1 = (1. - probs - gamma * probs * log_probs).mul_(probs_1_gamma).neg_()
        term2 = (probs - gamma * (1. - probs) * log_1_probs).mul_(probs_gamma)

        grads = torch.where(label == 1, term1, term2).mul_(coeff).mul_(grad_output)
        return grads, None, None, None


class FocalLossV2(nn.Module):
    '''
    This uses a hand-derived formula to compute the gradient, which has better
    numeric stability
    '''
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLossV2, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        loss = FocalSigmoidLossFuncV2.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss


##
# version 3: implement with cpp/cuda to save memory and accelerate
import focal_cpp  # import torch before importing the cpp extension
class FocalSigmoidLossFuncV3(torch.autograd.Function):
    '''
    use cpp/cuda to accelerate and shrink memory usage
    '''
    @staticmethod
    def forward(ctx, logits, labels, alpha, gamma):
        logits = logits.float()
        loss = focal_cpp.focalloss_forward(logits, labels, gamma, alpha)
        ctx.variables = logits, labels, alpha, gamma
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        '''
        compute gradient of focal loss
        '''
        logits, labels, alpha, gamma = ctx.variables
        grads = focal_cpp.focalloss_backward(grad_output, logits, labels, gamma, alpha)
        return grads, None, None, None


class FocalLossV3(nn.Module):
    '''
    Same loss as V2, but implemented with cpp/cuda kernels to save memory and
    accelerate
    '''
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLossV3, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        loss = FocalSigmoidLossFuncV3.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
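

# A small equivalence sketch (not part of the original file; the helper name is
# hypothetical): FocalLossV1 and FocalLossV2 compute the same loss, so their
# values and input gradients should agree up to floating-point error on random
# data. Call it manually if desired.
def _compare_v1_v2():
    logits1 = torch.randn(2, 3, 8, 8, requires_grad=True)
    logits2 = logits1.detach().clone().requires_grad_(True)
    label = torch.randint(0, 2, (2, 3, 8, 8)).float()
    loss1 = FocalLossV1()(logits1, label)
    loss2 = FocalLossV2()(logits2, label)
    loss1.backward()
    loss2.backward()
    assert torch.allclose(loss1, loss2, atol=1e-6)
    assert torch.allclose(logits1.grad, logits2.grad, atol=1e-6)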


if __name__ == '__main__':
    import torchvision
    import torch
    import numpy as np
    import random
    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            net = torchvision.models.resnet18(pretrained=False)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            self.out = nn.Conv2d(512, 3, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.out(feat)
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    net1 = Model()
    net2 = Model()
    net2.load_state_dict(net1.state_dict())
    criteria1 = FocalLossV1()
    criteria2 = FocalLossV2()
    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    bs = 2
    for it in range(300000):
        inten = torch.randn(bs, 3, 224, 224).cuda()
        lbs = torch.randint(0, 2, (bs, 3, 224, 224)).cuda()

        logits = net1(inten)
        loss1 = criteria1(logits, lbs)
        optim1.zero_grad()
        loss1.backward()
        optim1.step()

        logits = net2(inten)
        loss2 = criteria2(logits, lbs)
        optim2.zero_grad()
        loss2.backward()
        optim2.step()

        with torch.no_grad():
            if (it + 1) % 50 == 0:
                print('iter: {}, ================='.format(it + 1))
                print('out.weight: ', torch.mean(torch.abs(net1.out.weight - net2.out.weight)).item())
                print('conv1.weight: ', torch.mean(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())
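
# Note (not part of the original file): if the focal_cpp extension has been built,
# FocalLossV3 can be swapped in for criteria2 above to check it the same way.
# A hedged sketch of building the extension inline with torch.utils.cpp_extension,
# assuming a hypothetical source path (the actual extension sources are not shown here):
#
#     from torch.utils.cpp_extension import load
#     focal_cpp = load(name='focal_cpp',
#                      sources=['csrc/focal_cuda.cu'],  # hypothetical path
#                      verbose=True)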