import torch
import numpy as np


def _boxes_to_tensor(arr):
    """Wrap an ndarray of boxes as a torch tensor; pass anything else through."""
    if isinstance(arr, np.ndarray):
        # torch.from_numpy rejects negatively-strided views (e.g. arr[::-1]);
        # ascontiguousarray only copies when the layout actually requires it.
        return torch.from_numpy(np.ascontiguousarray(arr))
    return arr


def bbox_overlaps(boxes, query_boxes):
    """
    Compute the pairwise overlap ratio (intersection over union) of two box sets.

    Each box is a row ``(x1, y1, x2, y2)``: columns 0/1 hold the lower corner
    and columns 2/3 the upper corner. Widths and heights are computed as
    ``x2 - x1 + 1`` and ``y2 - y1 + 1``.

    Parameters
    ----------
    boxes: (N, 4) ndarray or tensor
    query_boxes: (K, 4) ndarray or tensor

    The two arguments may be of different kinds (one ndarray, one tensor).

    Returns
    -------
    overlaps: (N, K) overlap between boxes and query_boxes.
        An ndarray when both inputs are ndarrays, otherwise a tensor.
    """
    boxes_is_np = isinstance(boxes, np.ndarray)
    query_is_np = isinstance(query_boxes, np.ndarray)
    # Only hand back an ndarray when the caller worked purely in numpy.
    return_ndarray = boxes_is_np and query_is_np

    boxes = _boxes_to_tensor(boxes)
    query_boxes = _boxes_to_tensor(query_boxes)

    # Mixed input: place the converted array on the same device as the tensor
    # so the element-wise operations below can combine them.
    if boxes_is_np and not query_is_np:
        boxes = boxes.to(query_boxes.device)
    elif query_is_np and not boxes_is_np:
        query_boxes = query_boxes.to(boxes.device)

    box_areas = ((boxes[:, 2] - boxes[:, 0] + 1)
                 * (boxes[:, 3] - boxes[:, 1] + 1))
    query_areas = ((query_boxes[:, 2] - query_boxes[:, 0] + 1)
                   * (query_boxes[:, 3] - query_boxes[:, 1] + 1))

    # (N, 1) columns against (1, K) rows broadcast to (N, K). Extents are
    # clamped at zero so disjoint pairs contribute no intersection.
    iw = (torch.min(boxes[:, 2:3], query_boxes[:, 2:3].t())
          - torch.max(boxes[:, 0:1], query_boxes[:, 0:1].t())
          + 1).clamp(min=0)
    ih = (torch.min(boxes[:, 3:4], query_boxes[:, 3:4].t())
          - torch.max(boxes[:, 1:2], query_boxes[:, 1:2].t())
          + 1).clamp(min=0)

    intersection = iw * ih
    # Union = area(box) + area(query) - intersection.
    union = box_areas.view(-1, 1) + query_areas.view(1, -1) - intersection
    overlaps = intersection / union

    return overlaps.numpy() if return_ndarray else overlaps