import torch import torch.nn as nn class MatchSegmentation(nn.Module): def __init__(self): super(MatchSegmentation, self).__init__() def forward(self, segmentation, prob, gt_instance, gt_plane_num): """ greedy matching match segmentation with ground truth instance :param segmentation: tensor with size (N, K) :param prob: tensor with size (N, 1) :param gt_instance: tensor with size (21, h, w) :param gt_plane_num: int :return: a (K, 1) long tensor indicate closest ground truth instance id, start from 0 """ n, k = segmentation.size() _, h, w = gt_instance.size() assert (prob.size(0) == n and h*w == n) # ingnore non planar region gt_instance = gt_instance[:gt_plane_num, :, :].view(1, -1, h*w) # (1, gt_plane_num, h*w) segmentation = segmentation.t().view(k, 1, h*w) # (k, 1, h*w) # calculate instance wise cross entropy matrix (K, gt_plane_num) gt_instance = gt_instance.type(torch.float32) ce_loss = - (gt_instance * torch.log(segmentation + 1e-6) + (1-gt_instance) * torch.log(1-segmentation + 1e-6)) # (k, gt_plane_num, k*w) ce_loss = torch.mean(ce_loss, dim=2) # (k, gt_plane_num) matching = torch.argmin(ce_loss, dim=1, keepdim=True) return matching