import os
import os.path
import sys

import cv2
import numpy as np
import torch
import torch.utils.data as data

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

WIDER_CLASSES = ('__background__', 'hand')


class AnnotationTransform(object):
    """Convert a VOC-style XML annotation into an array of boxes and labels.

    Initialized with a dictionary lookup of class names to indexes.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of
            classnames -> indexes (default: the position of each name in
            ``WIDER_CLASSES``)
        keep_difficult (bool, optional): keep instances flagged as
            difficult or not (default: True)
    """

    # Child tags of <bndbox>, in the column order of the returned array.
    _BOX_FIELDS = ('xmin', 'ymin', 'xmax', 'ymax')

    def __init__(self, class_to_ind=None, keep_difficult=True):
        self.class_to_ind = class_to_ind or dict(
            zip(WIDER_CLASSES, range(len(WIDER_CLASSES))))
        self.keep_difficult = keep_difficult

    @staticmethod
    def _is_difficult(obj):
        """Return True when *obj* carries a ``<difficult>`` tag equal to 1.

        A missing or empty tag is treated as "not difficult" instead of
        raising, since ``Element.find`` returns None for absent children.
        """
        node = obj.find('difficult')
        if node is None or node.text is None:
            return False
        return int(node.text) == 1

    def __call__(self, target):
        """Extract every annotated object from *target*.

        Arguments:
            target (annotation): the target annotation to be made usable;
                will be an ET.Element

        Returns:
            A float64 numpy array of shape (num_objects, 5) whose rows are
            ``[xmin, ymin, xmax, ymax, label_ind]``. The shape is (0, 5)
            when no object is kept.

        Raises:
            KeyError: if an object's name is not a key of ``class_to_ind``.
        """
        rows = []
        for obj in target.iter('object'):
            if not self.keep_difficult and self._is_difficult(obj):
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            row = [int(bbox.find(pt).text) for pt in self._BOX_FIELDS]
            row.append(self.class_to_ind[name])
            rows.append(row)  # [xmin, ymin, xmax, ymax, label_ind]

        if not rows:
            return np.empty((0, 5))
        # Build the array once instead of re-stacking on every object; the
        # explicit dtype matches what stacking onto np.empty((0, 5)) gave.
        return np.array(rows, dtype=np.float64)


class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    input is image, target is annotation

    Image ids are read from ``<root>/ImageSets/Main/trainval.txt``; for each
    id the annotation is ``<root>/Annotations/<id>.xml`` and the image is
    ``<root>/images/<id>.jpg``.

    Arguments:
        root (string): filepath to WIDER folder
        preproc (callable, optional): called as ``preproc(img, target)`` and
            expected to return the ``(img, target)`` pair to use
        target_transform (callable, optional): transformation to perform on
            the target `annotation` (the root ET.Element of the XML file)
    """

    def __init__(self, root, preproc=None, target_transform=None):
        self.root = root
        self.preproc = preproc
        self.target_transform = target_transform
        self._annopath = os.path.join(self.root, 'Annotations', '%s.xml')
        self._imgpath = os.path.join(self.root, 'images', '%s.jpg')
        list_path = os.path.join(self.root, 'ImageSets/Main/trainval.txt')
        with open(list_path, 'r') as f:
            # Skip blank lines: they would become empty tuples and make
            # __getitem__ fail on img_id[0].
            self.ids = [tuple(line.split()) for line in f if line.strip()]

    def __getitem__(self, index):
        """Load the image and annotation stored at position *index*.

        Returns:
            A tuple ``(image, target)`` where ``image`` is
            ``torch.from_numpy`` of the (optionally preprocessed) image and
            ``target`` is the annotation after ``target_transform`` and
            ``preproc`` have been applied, when they are set.

        Raises:
            IOError: if the image file cannot be read.
        """
        img_id = self.ids[index]
        target = ET.parse(self._annopath % img_id[0]).getroot()

        img_path = self._imgpath % img_id[0]
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('Failed to read image: %s' % img_path)

        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.preproc is not None:
            img, target = self.preproc(img, target)
        return torch.from_numpy(img), target

    def __len__(self):
        """Return the number of image ids listed in the image-set file."""
        return len(self.ids)


def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a
    different number of associated object annotations (bounding boxes).

    Every element of every sample is inspected: tensors are collected as
    images, numpy arrays are converted to float tensors and collected as
    annotations, and anything else is ignored.

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked
               on 0 dim
    """
    targets = []
    imgs = []
    for sample in batch:
        for item in sample:
            if torch.is_tensor(item):
                imgs.append(item)
            elif isinstance(item, np.ndarray):
                targets.append(torch.from_numpy(item).float())
    return (torch.stack(imgs, 0), targets)