import torch
import torch.nn as nn
import torch.nn.functional as F



class AffinityFieldLoss(nn.Module):
    '''
        loss proposed in the paper: https://arxiv.org/abs/1803.10335
        used for sigmentation tasks
    '''
    def __init__(self, kl_margin, lambda_edge=1., lambda_not_edge=1., ignore_lb=255):
        super(AffinityFieldLoss, self).__init__()
        self.kl_margin = kl_margin
        self.ignore_lb = ignore_lb
        self.lambda_edge = lambda_edge
        self.lambda_not_edge = lambda_not_edge
        self.kldiv = nn.KLDivLoss(reduction='none')

    def forward(self, logits, labels):
        ignore_mask = labels.cpu() == self.ignore_lb
        n_valid = ignore_mask.numel() - ignore_mask.sum().item()
        indices = [
                # center,               # edge
            ((1, None, None, None), (None, -1, None, None)), # up
            ((None, -1, None, None), (1, None, None, None)), # down
            ((None, None, 1, None), (None, None, None, -1)), # left
            ((None, None, None, -1), (None, None, 1, None)), # right
            ((1, None, 1, None), (None, -1, None, -1)), # up-left
            ((1, None, None, -1), (None, -1, 1, None)), # up-right
            ((None, -1, 1, None), (1, None, None, -1)), # down-left
            ((None, -1, None, -1), (1, None, 1, None)), # down-right
        ]

        losses = []
        probs = torch.softmax(logits, dim=1)
        log_probs = torch.log_softmax(logits, dim=1)
        for idx_c, idx_e in indices:
            lbcenter = labels[:, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]].detach()
            lbedge = labels[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach()
            igncenter = ignore_mask[:, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]].detach()
            ignedge = ignore_mask[:, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]].detach()
            lgp_center = probs[:, :, idx_c[0]:idx_c[1], idx_c[2]:idx_c[3]]
            lgp_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]]
            prob_edge = probs[:, :, idx_e[0]:idx_e[1], idx_e[2]:idx_e[3]]
            kldiv = (prob_edge * (lgp_edge - lgp_center)).sum(dim=1)

            kldiv[ignedge | igncenter] = 0
            loss = torch.where(
                lbcenter == lbedge,
                self.lambda_edge * kldiv,
                self.lambda_not_edge * F.relu(self.kl_margin - kldiv, inplace=True)
            ).sum() / n_valid
            losses.append(loss)

        return sum(losses) / 8



if __name__ == '__main__':
    criteria = AffinityFieldLoss(kl_margin=3.)
    criteria.cuda()
    logits = torch.randn(4, 19, 768, 768).cuda()
    logits = torch.tensor(logits, requires_grad=True)
    scores = torch.softmax(logits, dim=1)
    labels = torch.randint(0, 19, (4, 768, 768)).cuda()
    labels[0, 30:35, 40:45] = 255
    labels[1, 0:5, 40:45] = 255

    loss = criteria(scores, labels)
    print(loss)
    loss.backward()