import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from itertools import filterfalse class MixLoss(nn.Module): def __init__(self, bce_w=1.0, dice_w=0.0, focal_w=0.0, lovasz_w=0.0, bce_kwargs={}, dice_kwargs={}, focal_kwargs={}, lovasz_kwargs={}): super(MixLoss, self).__init__() self.bce_w = bce_w self.dice_w = dice_w self.focal_w = focal_w self.lovasz_w = lovasz_w self.bce_loss = nn.BCEWithLogitsLoss(**bce_kwargs) self.dice_loss = DiceLoss(**dice_kwargs) self.focal_loss = FocalLoss(**focal_kwargs) self.lovasz_loss = LovaszHinge(**lovasz_kwargs) def forward(self, output, target): loss = 0.0 if self.bce_w: loss += self.bce_w * self.bce_loss(output, target) if self.dice_w: loss += self.dice_w * self.dice_loss(output, target) if self.focal_w: loss += self.focal_w * self.focal_loss(output, target) if self.lovasz_w: loss += self.lovasz_w * self.lovasz_loss(output, target) return loss class DiceLoss(nn.Module): def __init__(self, smooth=1.0, eps=1e-7): super(DiceLoss, self).__init__() self.smooth = smooth self.eps = eps def forward(self, output, target): output = torch.sigmoid(output) if torch.sum(target) == 0: output = 1.0 - output target = 1.0 - target return 1.0 - (2 * torch.sum(output * target) + self.smooth) / ( torch.sum(output) + torch.sum(target) + self.smooth + self.eps) class SoftIoULoss(nn.Module): def __init__(self, n_classes=19): super(SoftIoULoss, self).__init__() self.n_classes = n_classes @staticmethod def to_one_hot(tensor, n_classes): n, h, w = tensor.size() one_hot = torch.zeros(n, n_classes, h, w).scatter_(1, tensor.view(n, 1, h, w), 1) return one_hot def forward(self, logit, target): # logit => N x Classes x H x W # target => N x H x W N = len(logit) pred = F.softmax(logit, dim=1) target_onehot = self.to_one_hot(target, self.n_classes) # Numerator Product inter = pred * target_onehot # Sum over all pixels N x C x H x W => N x C inter = inter.view(N, self.n_classes, -1).sum(2) # Denominator union = pred + target_onehot - (pred * target_onehot) # Sum over all pixels N x C x H x W => N x C union = union.view(N, self.n_classes, -1).sum(2) loss = inter / (union + 1e-16) # Return average loss over classes and batch return -loss.mean() class FocalLoss(nn.Module): def __init__(self, gamma=2, eps=1e-7): super(FocalLoss, self).__init__() self.gamma = gamma self.eps = eps def forward(self, logit, target): prob = torch.sigmoid(logit) prob = prob.clamp(self.eps, 1. - self.eps) loss = -1 * target * torch.log(prob) loss = loss * (1 - logit) ** self.gamma return loss.sum() def lovasz_grad(gt_sorted): """ Computes gradient of the Lovasz extension w.r.t sorted errors See Alg. 1 in paper """ p = len(gt_sorted) gts = gt_sorted.sum() intersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) jaccard = 1. - intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] return jaccard def isnan(x): return x != x def mean(l, ignore_nan=True, empty=0): """ nanmean compatible with generators. """ l = iter(l) if ignore_nan: l = filterfalse(isnan, l) try: n = 1 acc = next(l) except StopIteration: if empty == 'raise': raise ValueError('Empty mean') return empty for n, v in enumerate(l, 2): acc += v if n == 1: return acc return acc / n def flatten_binary_scores(scores, labels, ignore=None): """ Flattens predictions in the batch (binary case) Remove labels equal to 'ignore' """ scores = scores.view(-1) labels = labels.view(-1) if ignore is None: return scores, labels valid = (labels != ignore) vscores = scores[valid] vlabels = labels[valid] return vscores, vlabels class LovaszHinge(nn.Module): def __init__(self, activation=lambda x: F.elu(x, inplace=True) + 1.0, per_image=True, ignore=None): super(LovaszHinge, self).__init__() self.activation = activation self.per_image = per_image self.ignore = ignore def lovasz_hinge_flat(self, logits, labels): """ Binary Lovasz hinge loss logits: [P] Variable, logits at each prediction (between -\infty and +\infty) labels: [P] Tensor, binary ground truth labels (0 or 1) ignore: label to ignore """ if len(labels) == 0: # only void pixels, the gradients should be 0 return logits.sum() * 0. signs = 2. * labels.float() - 1. errors = (1. - logits * signs) errors_sorted, perm = torch.sort(errors, dim=0, descending=True) perm = perm.data gt_sorted = labels[perm] grad = lovasz_grad(gt_sorted) loss = torch.dot(self.activation(errors_sorted), grad) return loss def forward(self, logits, labels): if self.per_image: loss = mean(self.lovasz_hinge_flat( *flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), self.ignore) ) for log, lab in zip(logits, labels)) else: loss = self.lovasz_hinge_flat( *flatten_binary_scores(logits, labels, self.ignore)) return loss def flatten_probas(probas, labels, ignore=None): """ Flattens predictions in the batch """ B, C, H, W = probas.size() probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C labels = labels.view(-1) if ignore is None: return probas, labels valid = (labels != ignore) vprobas = probas[valid.nonzero().squeeze()] vlabels = labels[valid] return vprobas, vlabels def lovasz_softmax_flat(probas, labels, only_present=False): """ Multi-class Lovasz-Softmax loss probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) labels: [P] Tensor, ground truth labels (between 0 and C - 1) only_present: average only on classes present in ground truth """ if len(probas) == 0: return np.nan C = probas.size(1) losses = [] for c in range(C): fg = (labels == c).float() # foreground for class c if only_present and fg.sum() == 0: continue errors = (fg - probas[:, c]).abs() errors_sorted, perm = torch.sort(errors, 0, descending=True) perm = perm.data fg_sorted = fg[perm] losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted))) return mean(losses) class LovaszSoftmax(nn.Module): """ Multi-class Lovasz-Softmax loss logits: [B, C, H, W] class logits at each prediction (between 0 and 1) labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) only_present: average only on classes present in ground truth per_image: compute the loss per image instead of per batch ignore: void class labels """ def __init__(self, only_present=False, per_image=True, ignore=None): super(LovaszSoftmax, self).__init__() self.only_present = only_present self.per_image = per_image self.ignore = ignore def forward(self, logits, labels): probas = F.softmax(logits, dim=1) if self.per_image: loss = mean(lovasz_softmax_flat(*flatten_probas( prob.unsqueeze(0), lab.unsqueeze(0), self.ignore), only_present=self.only_present) for prob, lab in zip(probas, labels)) else: loss = lovasz_softmax_flat(*flatten_probas( probas, labels, self.ignore), only_present=self.only_present) return loss # Adapted from OCNet Repository (https://github.com/PkuRainBow/OCNet) class OhemCrossEntropy2d(nn.Module): def __init__(self, ignore_label=255, thresh=0.6, min_kept=0, use_weight=True): super(OhemCrossEntropy2d, self).__init__() self.ignore_label = ignore_label self.thresh = float(thresh) self.min_kept = int(min_kept) if use_weight: print("w/ class balance") weight = torch.FloatTensor( [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label) else: print("w/o class balance") self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label) def forward(self, predict, target, weight=None): """ Args: predict:(n, c, h, w) target:(n, h, w) weight (Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size "nclasses" """ assert not target.requires_grad assert predict.dim() == 4 assert target.dim() == 3 assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3)) n, c, h, w = predict.size() input_label = target.data.cpu().numpy().ravel().astype(np.int32) x = np.rollaxis(predict.data.cpu().numpy(), 1).reshape((c, -1)) input_prob = np.exp(x - x.max(axis=0).reshape((1, -1))) input_prob /= input_prob.sum(axis=0).reshape((1, -1)) valid_flag = input_label != self.ignore_label valid_inds = np.where(valid_flag)[0] label = input_label[valid_flag] num_valid = valid_flag.sum() if self.min_kept >= num_valid: print('Labels: {}'.format(num_valid)) elif num_valid > 0: prob = input_prob[:, valid_flag] pred = prob[label, np.arange(len(label), dtype=np.int32)] threshold = self.thresh if self.min_kept > 0: index = pred.argsort() threshold_index = index[min(len(index), self.min_kept) - 1] if pred[threshold_index] > self.thresh: threshold = pred[threshold_index] kept_flag = pred <= threshold valid_inds = valid_inds[kept_flag] print('hard ratio: {} = {} / {} '.format(round(len(valid_inds)/num_valid, 4), len(valid_inds), num_valid)) label = input_label[valid_inds].copy() input_label.fill(self.ignore_label) input_label[valid_inds] = label print(np.sum(input_label != self.ignore_label)) target = torch.from_numpy(input_label.reshape(target.size())).long().cuda() return self.criterion(predict, target) class CriterionCrossEntropy(nn.Module): def __init__(self, ignore_index=255, weight='lightnet'): super(CriterionCrossEntropy, self).__init__() self.ignore_index = ignore_index if weight == 'lightnet': # https://github.com/ansleliu/LightNet/blob/master/datasets/calculate_class_weight.py self.weight = torch.FloatTensor( [0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437, 5.58540548, 3.56563995, 0.12704978, 1., 0.46783719, 1.34551528, 5.29974114, 0.28342531, 0.9396095, 0.81551811, 0.42679146, 3.6399074, 2.78376194]) else: self.weight = torch.FloatTensor( [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) self.criterion = torch.nn.CrossEntropyLoss(weight=self.weight, ignore_index=ignore_index) def forward(self, preds, target): h, w = target.size(1), target.size(2) scale_pred = F.interpolate(input=preds, size=(h, w), mode='bilinear', align_corners=True) loss = self.criterion(scale_pred, target) return loss class CriterionDSN(nn.Module): def __init__(self, ignore_index=255, use_weight=True, loss_balance_coefs=(0.4, 1.0)): super(CriterionDSN, self).__init__() self.ignore_index = ignore_index self.loss_balance_coefs = loss_balance_coefs weight = torch.FloatTensor( [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) if use_weight: self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) else: self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_index) def forward(self, preds, target): h, w = target.size(1), target.size(2) assert len(preds) == len(self.loss_balance_coefs) losses = [] for pred, coef in zip(preds, self.loss_balance_coefs): scale_pred = F.interpolate(input=pred, size=(h, w), mode='bilinear', align_corners=True) losses.append(self.criterion(scale_pred, target) * coef) return sum(losses) class CriterionOhemDSN(nn.Module): """ DSN + OHEM : We need to consider two supervision for the model. """ def __init__(self, ignore_index=255, thres=0.7, min_kept=100000, dsn_weight=0.4, use_weight=True): super(CriterionOhemDSN, self).__init__() self.ignore_index = ignore_index self.dsn_weight = dsn_weight self.criterion = OhemCrossEntropy2d(ignore_index, thres, min_kept, use_weight=use_weight) def forward(self, preds, target): h, w = target.size(1), target.size(2) scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) loss1 = self.criterion(scale_pred, target) scale_pred = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True) loss2 = self.criterion(scale_pred, target) return self.dsn_weight * loss1 + loss2 class CriterionOhemDSN_single(nn.Module): """ DSN + OHEM : we find that use hard-mining for both supervision harms the performance. Thus we choose the original loss for the shallow supervision and the hard-mining loss for the deeper supervision """ def __init__(self, ignore_index=255, thres=0.7, min_kept=100000, dsn_weight=0.4): super(CriterionOhemDSN_single, self).__init__() self.ignore_index = ignore_index self.dsn_weight = dsn_weight weight = torch.FloatTensor( [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) self.criterion = torch.nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index) self.criterion_ohem = OhemCrossEntropy2d(ignore_index, thres, min_kept, use_weight=True) def forward(self, preds, target): h, w = target.size(1), target.size(2) scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True) loss1 = self.criterion(scale_pred, target) scale_pred = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True) loss2 = self.criterion_ohem(scale_pred, target) return self.dsn_weight * loss1 + loss2