#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-08-11
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    r"""Focal Loss, as proposed in "Focal Loss for Dense Object Detection".

    For each sample with logits ``x`` and target class ``class``:

        Loss(x, class) = - \alpha_{class} (1 - softmax(x)[class])^\gamma \log(softmax(x)[class])

    The per-sample losses are then averaged (or summed) over the minibatch.

    Args:
        class_num (int): number of classes C; used to build the default
            all-ones ``alpha``.
        alpha (Tensor or sequence, optional): per-class weighting factor with
            C elements, given either as shape (C,) or (C, 1). Defaults to all
            ones.
        gamma (float): focusing parameter, ``gamma >= 0``. Larger values
            reduce the relative loss for well-classified examples (p > .5),
            putting more focus on hard, misclassified examples. With
            ``gamma == 0`` and unit ``alpha`` this equals cross-entropy.
        size_average (bool): if True (default) the losses are averaged over
            the observations in each minibatch; if False they are summed.

    Shape:
        inputs: (N, C) unnormalized scores; softmax is taken over dim 1.
        targets: N integer class indices in ``[0, C)``.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            # as_tensor returns tensors unchanged and also accepts lists/tuples.
            self.alpha = torch.as_tensor(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (N, C) against ``targets`` (N,)."""
        ids = targets.reshape(-1, 1).long()

        # log_softmax + gather is numerically stable: it stays finite even when
        # the target-class probability underflows to 0 in plain softmax.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)  # (N, 1)
        probs = log_p.exp()

        # Flatten alpha so both (C,) and (C, 1) inputs give an (N, 1) weight;
        # a (N,) weight would otherwise broadcast against (N, 1) into (N, N).
        # Moving per call (instead of mutating self.alpha) supports any device.
        alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype).reshape(-1)
        alpha = alpha[ids.reshape(-1)].reshape(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()


if __name__ == "__main__":
    # Sanity check: with gamma=0 and unit alpha, focal loss equals cross-entropy.
    N = 4
    C = 5
    FL = FocalLoss(class_num=C, gamma=0)
    CE = nn.CrossEntropyLoss()

    inputs = torch.rand(N, C)
    targets = torch.randint(0, C, (N,))

    # Separate leaf tensors so each loss gets its own populated .grad.
    inputs_fl = inputs.clone().requires_grad_()
    targets_fl = targets.clone()
    inputs_ce = inputs.clone().requires_grad_()
    targets_ce = targets.clone()

    print('----inputs----')
    print(inputs)
    print('---target-----')
    print(targets)

    fl_loss = FL(inputs_fl, targets_fl)
    ce_loss = CE(inputs_ce, targets_ce)
    print('ce = {}, fl ={}'.format(ce_loss.item(), fl_loss.item()))

    fl_loss.backward()
    ce_loss.backward()
    print(inputs_fl.grad)
    print(inputs_ce.grad)