#
# focal_loss.py, doom-net
#
# Created by Andrey Kolishchak on 03/11/18.
#
import torch
import torch.nn as nn
import torch.nn.functional as F
from device import device


class FocalLoss(nn.Module):
    def __init__(self, alfa=1, gamma=2):
        super().__init__()
        self.alfa = alfa
        self.gamma = gamma
        self.epsilon = 0 if int(gamma) == gamma and gamma != 0 else 1e-5
        pass

    def forward(self, input, target):
        target = target.view(target.shape[0], 1, *target.shape[1:])
        target_one_hot = torch.zeros(*input.shape, device=device)
        target_one_hot = target_one_hot.scatter_(1, target, 1.0)
        logp = input * target_one_hot
        p = logp.exp()+self.epsilon
        output = -(1-p).pow(self.gamma) * logp
        output = output.sum(dim=1).mean()
        return output


def test():
    loss_nll = nn.NLLLoss2d()
    loss_focal = FocalLoss(gamma=0)
    target = torch.Tensor(2, 1, 5).random_(3).long()

    data = torch.rand(2, 3, 1, 5)
    input1 = torch.Tensor(data, requires_grad=True)
    loss1 = loss_nll(F.log_softmax(input1), target)
    loss1.backward()
    print(loss1)
    print(input1.grad)

    input2 = torch.Tensor(data, requires_grad=True)
    loss2 = loss_focal(F.log_softmax(input2), target)
    loss2.backward()
    print(loss2)
    print(input2.grad)


#test()