```python
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: chenyuru
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import pi, sqrt

# Clamp before taking the log to avoid -inf/NaN from zero or overflowing inputs.
safe_log = lambda x: torch.log(torch.clamp(x, 1e-8, 1e8))

class _BaseKLDivergence(nn.Module):
    def __init__(self):
        super(_BaseKLDivergence, self).__init__()

    def forward(self, Q, P):
        """
        Parameters
        ----------
        P: ground-truth probability distribution [batch_size, n, n]
        Q: predicted probability distribution [batch_size, n, n]

        Description
        -----------
        Computes the KL divergence between attention maps. Here P and Q denote
        pixel-level attention maps over n spatial positions.
        """
        kl_loss = P * safe_log(P / Q)
        pixel_loss = torch.sum(kl_loss, dim=-1)
        total_loss = torch.mean(pixel_loss)
        return total_loss

class AttentionLoss2d(_BaseKLDivergence):
    def __init__(self, scale=1):
        super(AttentionLoss2d, self).__init__()
        self.scale = scale

    def get_similarity(self, depth):
        # Pairwise similarity between spatial positions, based on the absolute
        # difference of log-depths, normalized with a softmax over positions.
        b, _, h, w = depth.shape
        M = depth.reshape((b, h * w, 1))
        N = depth.reshape((b, 1, h * w))
        W = F.softmax(-torch.abs(torch.log(M) - torch.log(N)), -1)
        W[torch.isnan(W)] = 0  # zero out invalid (e.g. zero-depth) positions
        return W

    def get_gt_sim_map(self, label):
        b, _, h, w = label.shape
        res_label = F.interpolate(label, size=(h//8//self.scale, w//8//self.scale), mode='nearest')
        gt_sim_map = self.get_similarity(res_label)
        return gt_sim_map

    def forward(self, sim_map, label):
        """
        Parameters
        ----------
        sim_map: [batch_size, n, n]
        label: [batch_size, 1, h, w]
        """
        gt_sim_map = self.get_gt_sim_map(label)
        return super(AttentionLoss2d, self).forward(sim_map, gt_sim_map)
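
# Hedged usage sketch (not part of the original file): `sim_map` is assumed to
# come from an attention module flattened to [b, n, n] at 1/8 of the label
# resolution, so that n == (h//8//scale) * (w//8//scale).
#
#   criterion = AttentionLoss2d(scale=1)
#   sim_map = model_attention_output            # hypothetical tensor: [b, n, n]
#   loss = criterion(sim_map, depth_label)      # depth_label: [b, 1, h, w]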

class _BaseEntropyLoss2d(nn.Module):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
        """
        Parameters
        ----------
        ignore_index : specifies a target value that is ignored
            and does not contribute to the input gradient
        reduction : specifies the reduction to apply to the output:
            'mean' | 'sum'. 'mean': element-wise mean;
            'sum': the class dim is summed and the batch dim is averaged.
        use_weights : whether to use per-class weights.
        weight : Tensor, optional
            a manual rescaling weight given to each class.
            If given, it has to be a Tensor of size "nclasses".
        """
        super(_BaseEntropyLoss2d, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.use_weights = use_weights
        if use_weights:
            print("w/ class balance")
            print(weight)
            # Keep None when no weight is given, so forward() can compute
            # online class weights instead.
            self.weight = torch.FloatTensor(weight).cuda() if weight is not None else None
        else:
            print("w/o class balance")
            self.weight = None

    def get_entropy(self, pred, label):
        """
        Return
        ------
        entropy : shape [batch_size, h, w, c]

        Description
        -----------
        Entropy-based losses must implement this method: each element denotes
        the loss at a certain spatial position and class.
        """
        raise NotImplementedError

    def forward(self, pred, label):
        """
        Parameters
        ----------
        pred: [batch_size, num_classes, h, w]
        label: [batch_size, h, w]
        """
        assert pred.dim() == 4
        assert label.dim() == 3
        assert pred.size(0) == label.size(0), "{0} vs {1} ".format(pred.size(0), label.size(0))
        assert pred.size(2) == label.size(1), "{0} vs {1} ".format(pred.size(2), label.size(1))
        assert pred.size(3) == label.size(2), "{0} vs {1} ".format(pred.size(3), label.size(2))

        n, c, h, w = pred.size()
        if self.use_weights:
            if self.weight is None:
                # Compute online class weights from the label frequencies
                # (median-frequency balancing).
                print('label size {}'.format(label.shape))
                freq = np.zeros(c)
                for k in range(c):
                    mask = (label[:, :, :] == k)
                    freq[k] = torch.sum(mask).item()
                    print('{}th frequency {}'.format(k, freq[k]))
                weight = freq / np.sum(freq) * c
                weight = np.median(weight) / weight
                self.weight = torch.FloatTensor(weight).cuda()
                print('Online class weight: {}'.format(self.weight))
        else:
            self.weight = 1
        if self.ignore_index is None:
            self.ignore_index = c + 1

        entropy = self.get_entropy(pred, label)  # [n, h, w, c]

        weighted_entropy = entropy * self.weight

        if self.reduction == 'sum':
            # Per the docstring: sum over the class dim, then average over
            # the batch and spatial dims.
            loss = torch.mean(torch.sum(weighted_entropy, dim=-1))
        elif self.reduction == 'mean':
            loss = torch.mean(weighted_entropy)
        return loss

class OrdinalRegression2d(_BaseEntropyLoss2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
        super(OrdinalRegression2d, self).__init__(ignore_index, reduction, use_weights, weight)

    def get_entropy(self, pred, label):
        n, c, h, w = pred.size()
        label = label.unsqueeze(3).long()       # [n, h, w, 1]
        pred = pred.permute(0, 2, 3, 1)         # [n, h, w, c]
        # DORN-style ordinal target: channel k acts as a binary classifier for
        # "rank > k", so every channel below the label is a positive. pred is
        # assumed to already hold per-rank probabilities (no softmax here).
        ord_label = ((torch.arange(c)).cuda() < label).float()
        entropy = ord_label * safe_log(pred) + (1 - ord_label) * safe_log(1 - pred)
        return -entropy

def NormalDist(x, sigma):
    # Gaussian density, used as the smoothing prior in CrossEntropy2d.
    f = torch.exp(-x**2/(2*sigma**2)) / sqrt(2*pi*sigma**2)
    return f
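
# Hedged illustration (not part of the original file): with, say, c = 80 depth
# bins and sigma = c / 10 = 8, NormalDist concentrates the prior mass on bins
# near the true one:
#
#   x = torch.arange(80).float() - 40    # offsets from a label at bin 40
#   prior = NormalDist(x, 80 / 10)       # peaked at the true bin, decaying outward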

class CrossEntropy2d(_BaseEntropyLoss2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None,
                 eps=0.0, priorType='uniform'):
        """
        Parameters
        ----------
        eps : label smoothing factor
        priorType : prior distribution; 'uniform' is equivalent to the
            label smoothing trick (https://arxiv.org/abs/1512.00567),
            but a Gaussian prior is arguably a better fit for depth estimation.
        """
        super(CrossEntropy2d, self).__init__(ignore_index, reduction, use_weights, weight)
        self.eps = eps
        self.priorType = priorType

    def get_entropy(self, pred, label):
        n, c, h, w = pred.size()
        label = label.unsqueeze(3).long()                      # [n, h, w, 1]
        pred = F.softmax(pred, 1).permute(0, 2, 3, 1)          # [n, h, w, c]
        one_hot_label = ((torch.arange(c)).cuda() == label).float()

        if self.eps == 0:
            prior = 0
        else:
            if self.priorType == 'gaussian':
                # Gaussian prior centered on the label bin.
                tensor = (torch.arange(c).cuda() - label).float()
                prior = NormalDist(tensor, c / 10)
            elif self.priorType == 'uniform':
                prior = 1 / (c - 1)

        smoothed_label = (1 - self.eps) * one_hot_label + self.eps * prior * (1 - one_hot_label)
        entropy = smoothed_label * safe_log(pred) + (1 - smoothed_label) * safe_log(1 - pred)
        return -entropy
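
# Hedged usage sketch (not part of the original file); tensor shapes follow the
# docstrings above, and a CUDA device is assumed because the loss moves its
# internal tensors with .cuda():
#
#   criterion = CrossEntropy2d(eps=0.1, priorType='gaussian')
#   pred = torch.randn(2, 80, 32, 32).cuda()          # [b, num_classes, h, w]
#   label = torch.randint(0, 80, (2, 32, 32)).cuda()  # [b, h, w]
#   loss = criterion(pred, label)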

class OhemCrossEntropy2d(CrossEntropy2d):
    def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None,
                 eps=0.0, priorType='uniform', thresh=0.6, min_kept=0):
        """
        Parameters
        ----------
        thresh : OHEM (online hard example mining) threshold on the
            probability of the correct class
        min_kept : minimal number of pixels kept by OHEM

        Description
        -----------
        Modified from https://github.com/PkuRainBow/OCNet.pytorch/blob/master/utils/loss.py#L68
        """
        super(OhemCrossEntropy2d, self).__init__(ignore_index, reduction, use_weights, weight,
                                                 eps, priorType)
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)

    def get_ohem_label(self, pred, label):
        n, c, h, w = pred.size()
        if self.ignore_index is None:
            self.ignore_index = c + 1

        input_label = label.data.cpu().numpy().ravel().astype(np.int32)
        # Softmax over the class axis, computed in numpy for mining only.
        x = np.rollaxis(pred.data.cpu().numpy(), 1).reshape((c, -1))
        input_prob = np.exp(x - x.max(axis=0, keepdims=True))
        input_prob /= input_prob.sum(axis=0, keepdims=True)

        valid_flag = input_label != self.ignore_index
        valid_inds = np.where(valid_flag)[0]
        valid_label = input_label[valid_flag]
        num_valid = valid_flag.sum()
        if self.min_kept >= num_valid:
            print('Labels: {}'.format(num_valid))
        elif num_valid > 0:
            # Probability assigned to the correct class at each valid pixel.
            valid_prob = input_prob[:, valid_flag]
            valid_prob = valid_prob[valid_label, np.arange(len(valid_label), dtype=np.int32)]
            threshold = self.thresh
            if self.min_kept > 0:
                # Raise the threshold if needed so at least min_kept pixels survive.
                index = valid_prob.argsort()
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if valid_prob[threshold_index] > self.thresh:
                    threshold = valid_prob[threshold_index]
            # Keep only the hard examples (low correct-class probability).
            kept_flag = valid_prob <= threshold
            valid_inds = valid_inds[kept_flag]

        self.ohem_ratio = len(valid_inds) / num_valid
        # print('Max prob: {:.4f}, hard ratio: {:.4f} = {} / {} '.format(input_prob.max(), self.ohem_ratio, len(valid_inds), num_valid))
        valid_kept_label = input_label[valid_inds].copy()
        input_label.fill(self.ignore_index)
        input_label[valid_inds] = valid_kept_label
        # valid_flag_new = input_label != self.ignore_index
        # print(np.sum(valid_flag_new))
        label = torch.from_numpy(input_label.reshape(label.size())).long().cuda()
        return label

    def get_entropy(self, pred, label):
        # Relabel everything but the mined hard examples as ignore_index,
        # then reuse the (smoothed) cross entropy from the parent class.
        label = self.get_ohem_label(pred, label)
        entropy = super(OhemCrossEntropy2d, self).get_entropy(pred, label)
        return entropy
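
# Hedged usage sketch (not part of the original file): keep at least 5000
# pixels whose correct-class probability falls below 0.6; a CUDA device is
# assumed since get_ohem_label moves the mined label back with .cuda():
#
#   criterion = OhemCrossEntropy2d(ignore_index=255, thresh=0.6, min_kept=5000)
#   pred = torch.randn(2, 19, 64, 64).cuda()          # [b, num_classes, h, w]
#   label = torch.randint(0, 19, (2, 64, 64)).cuda()  # [b, h, w]
#   loss = criterion(pred, label)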
```