import cv2 import numpy as np import torch from torch.nn import functional as F from utils.data.structures.boxlist_ops import boxlist_iou from rcnn.utils.matcher import Matcher from rcnn.utils.misc import cat, keep_only_positive_boxes, across_sample from rcnn.core.config import cfg def fast_hist(a, b, n): k = (a >= 0) & (a < n) return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) def cal_one_mean_iou(image_array, label_array, num_parsing): hist = fast_hist(label_array, image_array, num_parsing).astype(np.float) num_cor_pix = np.diag(hist) num_gt_pix = hist.sum(1) union = num_gt_pix + hist.sum(0) - num_cor_pix iu = num_cor_pix / (num_gt_pix + hist.sum(0) - num_cor_pix) return iu def parsing_on_boxes(parsing, rois, heatmap_size): device = rois.device rois = rois.to(torch.device("cpu")) parsing_list = [] for i in range(rois.shape[0]): parsing_ins = parsing[i].cpu().numpy() xmin, ymin, xmax, ymax = torch.round(rois[i]).int() cropped_parsing = parsing_ins[max(0, ymin):ymax, max(0, xmin):xmax] resized_parsing = cv2.resize( cropped_parsing, (heatmap_size[1], heatmap_size[0]), interpolation=cv2.INTER_NEAREST ) parsing_list.append(torch.from_numpy(resized_parsing)) if len(parsing_list) == 0: return torch.empty(0, dtype=torch.int64, device=device) return torch.stack(parsing_list, dim=0).to(device, dtype=torch.int64) def project_parsing_on_boxes(parsing, proposals, resolution): proposals = proposals.convert("xyxy") assert parsing.size == proposals.size, "{}, {}".format(parsing, proposals) return parsing_on_boxes(parsing.parsing, proposals.bbox, resolution) class ParsingRCNNLossComputation(object): def __init__(self, proposal_matcher, resolution): """ Arguments: proposal_matcher (Matcher) resolution (tuple) """ self.proposal_matcher = proposal_matcher self.resolution = resolution self.across_sample = cfg.PRCNN.ACROSS_SAMPLE self.roi_size_per_img = cfg.PRCNN.ROI_SIZE_PER_IMG def match_targets_to_proposals(self, proposal, target): match_quality_matrix = boxlist_iou(target, proposal) matched_idxs = self.proposal_matcher(match_quality_matrix) target = target.copy_with_fields(["labels", "parsing"]) matched_targets = target[matched_idxs.clamp(min=0)] matched_targets.add_field("matched_idxs", matched_idxs) return matched_targets def prepare_targets(self, proposals, targets): all_positive_proposals = [] for proposals_per_image, targets_per_image in zip(proposals, targets): matched_targets = self.match_targets_to_proposals( proposals_per_image, targets_per_image ) matched_idxs = matched_targets.get_field("matched_idxs") labels_per_image = matched_targets.get_field("labels") labels_per_image = labels_per_image.to(dtype=torch.int64) neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD labels_per_image[neg_inds] = 0 # parsing are only computed on positive samples positive_inds = torch.nonzero(labels_per_image > 0).squeeze(1) positive_proposals = proposals_per_image[positive_inds] _parsing = matched_targets.get_field("parsing")[positive_inds] parsing_per_image = project_parsing_on_boxes( _parsing, positive_proposals, self.resolution ) positive_proposals.add_field("parsing_targets", parsing_per_image) all_positive_proposals.append(positive_proposals) return all_positive_proposals def resample(self, proposals, targets): # get all positive proposals (for single image on per GPU) positive_proposals = keep_only_positive_boxes(proposals) # resample for getting targets or matching new IoU positive_proposals = self.prepare_targets(positive_proposals, targets) # apply across-sample strategy (for a batch of images on per GPU) positive_proposals = across_sample( positive_proposals, roi_size_per_img=self.roi_size_per_img, across_sample=self.across_sample ) self.positive_proposals = positive_proposals all_num_positive_proposals = 0 for positive_proposals_per_image in positive_proposals: all_num_positive_proposals += len(positive_proposals_per_image) if all_num_positive_proposals == 0: positive_proposals = [proposals[0][:1]] return positive_proposals def __call__(self, parsing_logits): parsing_targets = [proposals_per_img.get_field("parsing_targets") for proposals_per_img in self.positive_proposals] parsing_targets = cat(parsing_targets, dim=0) if parsing_targets.numel() == 0: return parsing_logits.sum() * 0 parsing_loss = F.cross_entropy( parsing_logits, parsing_targets, reduction="mean" ) parsing_loss *= cfg.PRCNN.LOSS_WEIGHT return parsing_loss def parsing_loss_evaluator(): matcher = Matcher( cfg.FAST_RCNN.FG_IOU_THRESHOLD, cfg.FAST_RCNN.BG_IOU_THRESHOLD, allow_low_quality_matches=False, ) loss_evaluator = ParsingRCNNLossComputation( matcher, cfg.PRCNN.RESOLUTION ) return loss_evaluator