"""MSCOCO Semantic Segmentation pretraining for VOC."""
import os
import pickle

import numpy as np
from PIL import Image
from tqdm import trange

from .utils import try_import_pycocotools
from ..segbase import SegmentationDataset


class COCOSegmentation(SegmentationDataset):
    """COCO Semantic Segmentation Dataset for VOC Pre-training.

    Only the COCO categories listed in ``CAT_LIST`` are kept; every other
    pixel is labelled as background (class 0).

    Parameters
    ----------
    root : string
        Path to COCO folder. Default is '~/.mxnet/datasets/coco'.
        The folder is expected to contain
        ``annotations/instances_{train,val}2017.json`` and the image folders
        ``train2017`` / ``val2017``.
    split : string
        'train' selects the train2017 images and annotations; any other
        value (e.g. 'val' or 'test') selects val2017.
    mode : string, optional
        Forwarded to the base class. ``__getitem__`` supports the modes
        'train', 'val' and 'testval'.
        NOTE(review): how ``None`` is resolved is decided by the base class,
        which is not visible in this file.
    transform : callable, optional
        A function that transforms the image (it is not applied to the mask).

    Examples
    --------
    >>> from mxnet.gluon.data.vision import transforms
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    ...     transforms.ToTensor(),
    ...     transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ... ])
    >>> # Create Dataset
    >>> trainset = gluoncv.data.COCOSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = gluon.data.DataLoader(
    ...     trainset, 4, shuffle=True, last_batch='rollover',
    ...     num_workers=4)
    """

    # COCO category ids; the position of an id in this list is the class
    # index written into the segmentation mask (entry 0 is background).
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
                1, 64, 20, 63, 7, 72]
    NUM_CLASS = 21
    CLASSES = ("background", "airplane", "bicycle", "bird", "boat", "bottle",
               "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse",
               "motorcycle", "person", "potted-plant", "sheep", "sofa", "train",
               "tv")

    def __init__(self, root=os.path.expanduser('~/.mxnet/datasets/coco'),
                 split='train', mode=None, transform=None, **kwargs):
        super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
        # lazy import pycocotools
        try_import_pycocotools()
        from pycocotools.coco import COCO
        from pycocotools import mask

        # Anything that is not 'train' falls back to the validation subset.
        subset = 'train' if split == 'train' else 'val'
        print('{} set'.format(subset))
        ann_file = os.path.join(
            root, 'annotations/instances_{}2017.json'.format(subset))
        ids_file = os.path.join(root, 'annotations/{}_ids.mx'.format(subset))
        self.root = os.path.join(root, '{}2017'.format(subset))

        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            # NOTE: pickle can execute arbitrary code while loading. This
            # file is a local cache written by `_preprocess`; do not point
            # `root` at a folder whose contents are untrusted.
            with open(ids_file, 'rb') as f:
                self.ids = pickle.load(f)
        else:
            # First use of this split: filter the image ids and cache them.
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair for the sample at ``index``.

        The image is loaded from disk as RGB, the mask is generated from the
        COCO annotations, and both go through the mode-specific transform
        helpers (not defined in this file; presumably inherited from the
        base class). ``self.transform`` is then applied to the image only.
        """
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        mask = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))
        # synchronized transform
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval', \
                "unsupported mode {!r}; expected 'train', 'val' or 'testval'".format(
                    self.mode)
            img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        return img, mask

    def __len__(self):
        """Number of images kept for this split."""
        return len(self.ids)

    def _gen_seg_mask(self, target, h, w):
        """Rasterize COCO annotations into an ``(h, w)`` uint8 class mask.

        Parameters
        ----------
        target : list of dict
            COCO annotations; each one needs 'segmentation' and
            'category_id'.
        h, w : int
            Height and width of the image the annotations belong to.

        Returns
        -------
        numpy.ndarray
            ``uint8`` array of shape ``(h, w)`` holding, per pixel, the index
            into ``CAT_LIST`` of the first annotation covering it, or 0 when
            no kept annotation covers it.
        """
        mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        cat_list = self.CAT_LIST
        for instance in target:
            # Resolve the class first so that annotations of categories we
            # do not keep are skipped without paying for the decode.
            try:
                c = cat_list.index(instance['category_id'])
            except ValueError:
                continue
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            if m.ndim < 3:
                foreground = m > 0
            else:
                # Several parts for one instance: merge them into one mask.
                foreground = m.any(axis=2)
            # Only paint pixels that are still background, so that earlier
            # annotations take precedence over later ones.
            mask[(mask == 0) & foreground] = c
        return mask

    def _preprocess(self, ids, ids_file):
        """Keep the images with enough labelled pixels and cache their ids.

        Parameters
        ----------
        ids : list
            Candidate COCO image ids.
        ids_file : string
            Path the resulting id list is pickled to.

        Returns
        -------
        list
            Ids of the images whose mask has more than 1000 non-background
            pixels, in the order they appear in ``ids``.
        """
        print("Preprocessing mask, this will take a while. "
              "But don't worry, it only run once for each split.")
        total = len(ids)
        tbar = trange(total)
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                      img_metadata['width'])
            # more than 1k pixels
            if np.count_nonzero(mask) > 1000:
                new_ids.append(img_id)
            tbar.set_description(
                'Doing: {}/{}, got {} qualified images'.format(
                    i, total, len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        with open(ids_file, 'wb') as f:
            pickle.dump(new_ids, f)
        return new_ids

    @property
    def classes(self):
        """Category names."""
        return type(self).CLASSES