import torch import os import random from torch.utils.data import Dataset from torch.utils.data import DataLoader from torchvision import transforms import torchvision.transforms.functional as FN from torchvision.datasets import ImageFolder, MNIST from PIL import Image import json import numpy as np import random from utils import computeIOU, computeContainment, computeUnionArea from pycocotools.coco import COCO as COCOTool from collections import defaultdict from random import shuffle from copy import copy class CocoDatasetBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='datasetBoxAnn.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1, square_resize=0, filter_by_mincooccur = -1., only_indiv_occur = 0., augmenter_mode=0): self.image_path = os.path.join('data','coco','images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','coco',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.square_resize = square_resize self.bbox_out_size = bbox_out_size self.filter_by_mincooccur = filter_by_mincooccur self.only_indiv_occur = only_indiv_occur #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = select_attrs if len(select_attrs) == 0: self.selected_attrs = [attr['name'] for attr in self.dataset['categories']] self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = use_gt_mask self.boxrotate = boxrotate self.augmenter_mode = augmenter_mode if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) if use_gt_mask == 1: self.mask_trans =transforms.Compose([transforms.Resize(out_img_size if not square_resize else [out_img_size, out_img_size] , interpolation=Image.NEAREST), transforms.CenterCrop(out_img_size)]) self.mask_provider = CocoMaskDataset(self.mask_trans, mode, select_attrs=self.selected_attrs, balance_classes=balance_classes) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] if self.max_object_size > 0.: validImgs = [] for img in self.dataset['images']: if not self.max_with_union: maxSize = max([bb['bbox'][2]*bb['bbox'][3] for bb in img['bboxAnn']]) else: boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[bb['cid']].append(bb['bbox']) unionAreas = [computeUnionArea(boxes) for cid,boxes in boxByCls.iteritems()] maxSize = max(unionAreas) if maxSize < self.max_object_size: validImgs.append(img) print ' %d of %d images left after size filtering'%(len(validImgs), len(self.dataset['images'])) self.dataset['images'] = validImgs self.valid_ids = [img['cocoid'] for img in self.dataset['images']] self.catsInImg = {} selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['bboxAnn'] = [bb for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) if not self.square_resize else self.out_img_size assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length >= 1: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 1./self.out_img_size: fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'][[self.sattr_to_idx[self.catid2attr[bb['cid']]] for bb in img['bboxAnn']]] = 1. # Convert bbox data to numpy arrays #for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): # self.dataset['images'][i]['bboxAnn'][j]['bbox'] = np.array(bb['bbox']) # Create bbox labels. if self.augmenter_mode: lab_in_img = img['label'].nonzero()[0] self.dataset['images'][i]['label_seq'] = lab_in_img n_lab_in_img = len(lab_in_img) #self.dataset['images'][i]['cls_affect'] = np.zeros((n_lab_in_img,n_lab_in_img)) idx2aidx = {l:li for li,l in enumerate(lab_in_img)} boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[idx2aidx[self.sattr_to_idx[self.catid2attr[bb['cid']]]]].append(bb['bbox']) self.dataset['images'][i]['cls_affect'] = np.array([[min([max([computeContainment(bb1, bb2)[0] for bb1 in boxByCls[li1]]) for bb2 in boxByCls[li2]]) for li2 in xrange(n_lab_in_img)] for li1 in xrange(n_lab_in_img)]) for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): #Check for IOU > 0.5 with other bbox iouAr = [computeContainment(bb['bbox'], bother['bbox'])[0] for bother in self.dataset['images'][i]['bboxAnn']] self.dataset['images'][i]['bboxAnn'][j]['box_label'] = np.zeros(len(selset)) self.dataset['images'][i]['bboxAnn'][j]['box_label'][[self.sattr_to_idx[self.catid2attr[self.dataset['images'][i]['bboxAnn'][ii]['cid']]] for ii,iv in enumerate(iouAr) if iv>self.iouThresh]] = 1. if self.filter_by_mincooccur >= 0. or self.only_indiv_occur: clsToSingleOccur = defaultdict(list) clsCounts = np.zeros(len(self.selected_attrs)) clsIndivCounts = np.zeros(len(self.selected_attrs)) for i, img in enumerate(self.dataset['images']): imgCls = set() for bb in img['bboxAnn']: imgCls.add(self.catid2attr[bb['cid']]) imgCls = list(imgCls) if len(imgCls)==1: clsIndivCounts[self.sattr_to_idx[imgCls[0]]] += 1. clsToSingleOccur[imgCls[0]].append(i) else: clsCounts[[self.sattr_to_idx[cls] for cls in imgCls]] += 1. if self.filter_by_mincooccur >= 0.: n_rem_counts = clsIndivCounts - self.filter_by_mincooccur/(1-self.filter_by_mincooccur) * clsCounts allRemIds = set() for cls in self.selected_attrs: if n_rem_counts[self.sattr_to_idx[cls]] > 0: n_rem_idx = np.arange(len(clsToSingleOccur[cls])) np.random.shuffle(n_rem_idx) n_rem_idx = n_rem_idx[:int(n_rem_counts[self.sattr_to_idx[cls]])] allRemIds.update([clsToSingleOccur[cls][ri] for ri in n_rem_idx]) self.dataset['images'] = [img for i,img in enumerate(self.dataset['images']) if i not in allRemIds] elif self.only_indiv_occur: allKeepIds = set() for cls in self.selected_attrs: allKeepIds.update(clsToSingleOccur[cls]) self.dataset['images'] = [img for i,img in enumerate(self.dataset['images']) if i in allKeepIds] self.valid_ids = [img['cocoid'] for img in self.dataset['images']] print ' %d images left after co_occurence filtering'%(len(self.valid_ids)) self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [self.catid2attr[bb['cid']] for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] if len(classesInImg): self.catsInImg[i] = classesInImg for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.catsInImg[i] = ['bg'] self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.7 maxIou = 0.3 cbboxList = self.dataset['images'][index]['bboxAnn'] if not self.onlyrandBoxes else [] n_t = 0 while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return copy(sbox['bbox']),sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[self.catid2attr[bb['cid']]]] = 1 if noOverlap and ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes==1: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) elif self.balance_classes==2: currCls = random.choice(self.catsInImg[index]) if ('person' not in self.catsInImg[index]) or (random.rand()<0.2) else 'person' else: currCls = random.choice(self.catsInImg[index]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] if not self.augmenter_mode: returnvals = self.getbyIndexAndclass(index, cid) else: returnvals = self.getbyIndexAndclassAugmentMode(index) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index, cid): image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) currCls = self.selected_attrs[cid[0]] if image.mode != 'RGB': #print image.mode image = image.convert('RGB') sampbbox, bboxLabel, cbid = self.randomBBoxSample(index, 0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.5: bsamp, _, _ = self.randomBBoxSample(index, 0.6-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==1: # Use GT masks as input gtMask = self.mask_provider.getbyIdAndclass(self.dataset['images'][index]['cocoid'], currCls, hflip=hflip) elif self.use_gt_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif self.use_gt_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def getbyIndexAndclassAugmentMode(self, index): imgData = self.dataset['images'][index] image = Image.open(os.path.join(self.image_path,imgData['filepath'], imgData['filename'])) if image.mode != 'RGB': #print image.mode image = image.convert('RGB') gtBoxes = np.zeros((len(self.selected_attrs),4)) for bbox in imgData['bboxAnn']: gtBoxes[self.sattr_to_idx[self.catid2attr[bbox['cid']]],:] = bbox['bbox'] label = imgData['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) gtBoxes[np.array(label,dtype=np.int),0] = 1.0-(gtBoxes[np.array(label,dtype=np.int),0]+gtBoxes[np.array(label,dtype=np.int),2]) #Get class effect; class_effect = np.zeros((len(self.selected_attrs),len(self.selected_attrs))) class_effect[np.meshgrid(imgData['label_seq'],imgData['label_seq'])] = imgData['cls_affect'] #Convert BBox to actual co-ordinates return self.transform[-1](image), torch.FloatTensor(label), torch.FloatTensor(gtBoxes), torch.LongTensor([imgData['cocoid']]), torch.LongTensor([hflip]), torch.FloatTensor(class_effect.T) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getfilename_bycocoid(self, cocoid): return self.dataset['images'][self.imgId2idx[cocoid]]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['cocoid'] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): what_mask = self.use_gt_mask if mask_type is None else mask_type if what_mask==1: # Use GT masks as input gtMask = self.mask_provider.getbyIdAndclass(self.dataset['images'][index]['cocoid'], cls, hflip=hflip) elif what_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif what_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. else: gtMask = None return gtMask class ADE20k(Dataset): def __init__(self, transform, split, select_attrs=[], out_img_size=128, bbox_out_size=64, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1, square_resize=0) : self.image_path = os.path.join('data','ade20k') self.transform = transform self.split = split self.n_boxes = n_boxes self.iouThresh = 0.5 datafile = 'train.odgt' if split == 'train' else 'validation.odgt' self.datafile = os.path.join('data','ade20k',datafile) self.dataset = [json.loads(x.rstrip()) for x in open(self.datafile, 'r')] self.num_data = len(self.dataset) clsData = open('data/ade20k/object150_info.csv','r').read().splitlines() self.clsidx2attr = {i:ln.split(',')[-1] for i, ln in enumerate(clsData[1:])} self.clsidx2Stuff = {i:int(ln.split(',')[-2]) for i, ln in enumerate(clsData[1:])} self.validCatIds = set([i for i in self.clsidx2Stuff if not self.clsidx2Stuff[i]]) self.maskSample = 'nonStuff' self.out_img_size = out_img_size self.square_resize = square_resize self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = ['background'] self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = use_gt_mask self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) if use_gt_mask == 1: self.mask_transform = transforms.Compose([transforms.Resize(out_img_size if not square_resize else [out_img_size, out_img_size] , interpolation=Image.NEAREST), transforms.CenterCrop(out_img_size)]) self.valid_ids = [] for i,img in enumerate(self.dataset): imid = int(os.path.basename(img['fpath_img']).split('.')[0].split('_')[-1]) self.dataset[i]['image_id'] = imid self.valid_ids.append(imid) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} def randomBBoxSample(self, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.7 maxIou = 0.3 cbboxList = [] n_t = 0 while 1: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) #if len(cbboxList): if ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly returnvals = self.getbyIndexAndclass(index) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index,cls=None): imgDb = self.dataset[index] image_id = imgDb['image_id'] image = Image.open(os.path.join(self.image_path,imgDb['fpath_img'])) if image.mode != 'RGB': #print image.mode image = image.convert('RGB') cid = [0] sampbbox, bboxLabel, cbid = self.randomBBoxSample(0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.5: bsamp, _, _ = self.randomBBoxSample(0.6-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = np.ones(max(len(self.selected_attrs),1)) # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==1: # Use GT masks as input gtMask = self.getGTMaskInp(index, hflip=hflip) #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), torch.FloatTensor(bboxLabel), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return os.path.basename(self.dataset[index]['fpath_img']) def getfilename_bycocoid(self, cocoid): return os.path.basename(self.dataset[self.imgId2idx[cocoid]]['fpath_img']) def getcocoid(self, index): return self.dataset[index]['image_id'] def getGTMaskInp(self, index, cls=None, hflip=False, mask_type=None): imgDb = self.dataset[index] segmImg = np.array(Image.open(os.path.join(self.image_path,imgDb['fpath_segm'])))-1 presentClass = np.unique(segmImg) validClass = map(lambda x: x in self.validCatIds, presentClass) chosenIdx = np.random.choice(presentClass[validClass]) if np.sum(validClass) > 0 else -10 if chosenIdx < 0: maskTotal = np.zeros((self.out_img_size,self.out_img_size)) sampbbox, bboxLabel, cbid = self.randomBBoxSample(0.5) sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] maskTotal[sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. else: maskTotal = (segmImg == chosenIdx).astype(np.float) if hflip: maskTotal = maskTotal[:,::-1] mask = torch.FloatTensor(np.asarray(self.mask_transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask class BelgaLogoBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='dataset.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1): self.image_path = os.path.join('data','belgalogos','images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','belgalogos',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = select_attrs self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = use_gt_mask self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) if use_gt_mask == 1: print ' Not Supported' assert(0) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] if self.max_object_size > 0.: validImgs = [] for img in self.dataset['images']: if not self.max_with_union: maxSize = max([bb['bbox'][2]*bb['bbox'][3] for bb in img['bboxAnn']]) else: boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[bb['cid']].append(bb['bbox']) unionAreas = [computeUnionArea(boxes) for cid,boxes in boxByCls.iteritems()] maxSize = max(unionAreas) if maxSize < self.max_object_size: validImgs.append(img) print ' %d of %d images left after size filtering'%(len(validImgs), len(self.dataset['images'])) self.dataset['images'] = validImgs self.valid_ids = [img['id'] for img in self.dataset['images']] self.catsInImg = {} selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['bboxAnn'] = [bb for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length > 5: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 0.04: fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'][[self.sattr_to_idx[self.catid2attr[bb['cid']]] for bb in img['bboxAnn']]] = 1. # Convert bbox data to numpy arrays #for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): # self.dataset['images'][i]['bboxAnn'][j]['bbox'] = np.array(bb['bbox']) # Create bbox labels. for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): #Check for IOU > 0.5 with other bbox iouAr = [computeContainment(bb['bbox'], bother['bbox'])[0] for bother in self.dataset['images'][i]['bboxAnn']] self.dataset['images'][i]['bboxAnn'][j]['box_label'] = np.zeros(len(selset)) self.dataset['images'][i]['bboxAnn'][j]['box_label'][[self.sattr_to_idx[self.catid2attr[self.dataset['images'][i]['bboxAnn'][ii]['cid']]] for ii,iv in enumerate(iouAr) if iv>self.iouThresh]] = 1. self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [self.catid2attr[bb['cid']] for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] if len(classesInImg): self.catsInImg[i] = classesInImg for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.catsInImg[i] = ['bg'] self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.7 maxIou = 0.3 cbboxList = self.dataset['images'][index]['bboxAnn'] if not self.onlyrandBoxes else [] n_t = 0 while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return copy(sbox['bbox']),sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[self.catid2attr[bb['cid']]]] = 1 if noOverlap and ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) else: currCls = random.choice(self.catsInImg[index]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index, cid): image = Image.open(os.path.join(self.image_path, self.dataset['images'][index]['filename'])) currCls = self.selected_attrs[cid[0]] if image.mode != 'RGB': #print image.mode image = image.convert('RGB') sampbbox, bboxLabel, cbid = self.randomBBoxSample(index, 0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.5: bsamp, _, _ = self.randomBBoxSample(index, 0.6-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif self.use_gt_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['id'] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): what_mask = self.use_gt_mask if mask_type is None else mask_type if what_mask==1: print 'not supported' assert(0) elif what_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif what_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. else: gtMask = None return gtMask class UnrelBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='dataset.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1): COCO_classes = ['person' , 'bicycle' , 'car' , 'motorcycle' , 'airplane' , 'bus' , 'train' , 'truck' , 'boat' , 'traffic light' , 'fire hydrant' , 'stop sign' , 'parking meter' , 'bench' , 'bird' , 'cat' , 'dog' , 'horse' , 'sheep' , 'cow' , 'elephant' , 'bear' , 'zebra' , 'giraffe' , 'backpack' , 'umbrella' , 'handbag' , 'tie' , 'suitcase' , 'frisbee' , 'skis' , 'snowboard' , 'sports ball' , 'kite' , 'baseball bat' , 'baseball glove' , 'skateboard' , 'surfboard' , 'tennis racket' , 'bottle' , 'wine glass' , 'cup' , 'fork' , 'knife' , 'spoon' , 'bowl' , 'banana' , 'apple' , 'sandwich' , 'orange' , 'broccoli' , 'carrot' , 'hot dog' , 'pizza' , 'donut' , 'cake' , 'chair' , 'couch' , 'potted plant' , 'bed' , 'dining table' , 'toilet' , 'tv' , 'laptop' , 'mouse' , 'remote' , 'keyboard' , 'cell phone' , 'microwave' , 'oven' , 'toaster' , 'sink' , 'refrigerator' , 'book' , 'clock' , 'vase' , 'scissors' , 'teddy bear' , 'hair drier' , 'toothbrush'] self.use_cococlass = 1 self.image_path = os.path.join('data','unrel','images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','unrel',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = COCO_classes if len(select_attrs) == 0 else select_attrs self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = 0 self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) #if use_gt_mask == 1: # print ' Not Supported' # assert(0) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] if self.max_object_size > 0.: validImgs = [] for img in self.dataset['images']: if not self.max_with_union: maxSize = max([bb['bbox'][2]*bb['bbox'][3] for bb in img['bboxAnn']]) else: boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[bb['cid']].append(bb['bbox']) unionAreas = [computeUnionArea(boxes) for cid,boxes in boxByCls.iteritems()] maxSize = max(unionAreas) if maxSize < self.max_object_size: validImgs.append(img) print ' %d of %d images left after size filtering'%(len(validImgs), len(self.dataset['images'])) self.dataset['images'] = validImgs self.valid_ids = [img['id'] for img in self.dataset['images']] self.catsInImg = {} selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['bboxAnn'] = [bb for bb in img['bboxAnn'] if bb['cococlass'] in selset] # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length > 5: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 0.04: fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'][[self.sattr_to_idx[bb['cococlass']] for bb in img['bboxAnn']]] = 1. # Convert bbox data to numpy arrays #for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): # self.dataset['images'][i]['bboxAnn'][j]['bbox'] = np.array(bb['bbox']) # Create bbox labels. for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): #Check for IOU > 0.5 with other bbox iouAr = [computeContainment(bb['bbox'], bother['bbox'])[0] for bother in self.dataset['images'][i]['bboxAnn']] self.dataset['images'][i]['bboxAnn'][j]['box_label'] = np.zeros(len(selset)) self.dataset['images'][i]['bboxAnn'][j]['box_label'][[self.sattr_to_idx[self.dataset['images'][i]['bboxAnn'][ii]['cococlass']] for ii,iv in enumerate(iouAr) if iv>self.iouThresh]] = 1. self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [bb['cococlass'] for bb in img['bboxAnn'] if bb['cococlass'] in selset] if len(classesInImg): self.catsInImg[i] = classesInImg for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.catsInImg[i] = ['bg'] self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.85 maxIou = 0.3 cbboxList = self.dataset['images'][index]['bboxAnn'] if not self.onlyrandBoxes else [] n_t = 0 while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return copy(sbox['bbox']),sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[bb['cococlass']]] = 1 if noOverlap and ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) else: currCls = random.choice(self.catsInImg[index]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index, cid): image = Image.open(os.path.join(self.image_path, self.dataset['images'][index]['filename'])) currCls = self.selected_attrs[cid[0]] if image.mode != 'RGB': #print image.mode image = image.convert('RGB') sampbbox, bboxLabel, cbid = self.randomBBoxSample(index, 0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.7: bsamp, _, _ = self.randomBBoxSample(index, 0.8-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif self.use_gt_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['id'] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): what_mask = self.use_gt_mask if mask_type is None else mask_type if what_mask==1: print 'not supported' assert(0) elif what_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif what_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. else: gtMask = None return gtMask class OutofContextBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='dataset.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1): COCO_classes = ['person' , 'bicycle' , 'car' , 'motorcycle' , 'airplane' , 'bus' , 'train' , 'truck' , 'boat' , 'traffic light' , 'fire hydrant' , 'stop sign' , 'parking meter' , 'bench' , 'bird' , 'cat' , 'dog' , 'horse' , 'sheep' , 'cow' , 'elephant' , 'bear' , 'zebra' , 'giraffe' , 'backpack' , 'umbrella' , 'handbag' , 'tie' , 'suitcase' , 'frisbee' , 'skis' , 'snowboard' , 'sports ball' , 'kite' , 'baseball bat' , 'baseball glove' , 'skateboard' , 'surfboard' , 'tennis racket' , 'bottle' , 'wine glass' , 'cup' , 'fork' , 'knife' , 'spoon' , 'bowl' , 'banana' , 'apple' , 'sandwich' , 'orange' , 'broccoli' , 'carrot' , 'hot dog' , 'pizza' , 'donut' , 'cake' , 'chair' , 'couch' , 'potted plant' , 'bed' , 'dining table' , 'toilet' , 'tv' , 'laptop' , 'mouse' , 'remote' , 'keyboard' , 'cell phone' , 'microwave' , 'oven' , 'toaster' , 'sink' , 'refrigerator' , 'book' , 'clock' , 'vase' , 'scissors' , 'teddy bear' , 'hair drier' , 'toothbrush'] self.use_cococlass = 1 self.image_path = os.path.join('data','outofcontext','images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','outofcontext',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = COCO_classes if len(select_attrs) == 0 else select_attrs self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = 0 self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) #if use_gt_mask == 1: # print ' Not Supported' # assert(0) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] if self.max_object_size > 0.: validImgs = [] for img in self.dataset['images']: if not self.max_with_union: maxSize = max([bb['bbox'][2]*bb['bbox'][3] for bb in img['bboxAnn']]) else: boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[bb['cid']].append(bb['bbox']) unionAreas = [computeUnionArea(boxes) for cid,boxes in boxByCls.iteritems()] maxSize = max(unionAreas) if maxSize < self.max_object_size: validImgs.append(img) print ' %d of %d images left after size filtering'%(len(validImgs), len(self.dataset['images'])) self.dataset['images'] = validImgs self.valid_ids = [img['id'] for img in self.dataset['images']] self.catsInImg = {} selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['bboxAnn'] = [bb for bb in img['bboxAnn'] if bb['cococlass'] in selset]# and bb['outofcontext'] == 1] # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length > 5: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 0.04: fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'][[self.sattr_to_idx[bb['cococlass']] for bb in img['bboxAnn']]] = 1. # Convert bbox data to numpy arrays #for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): # self.dataset['images'][i]['bboxAnn'][j]['bbox'] = np.array(bb['bbox']) # Create bbox labels. for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): #Check for IOU > 0.5 with other bbox iouAr = [computeContainment(bb['bbox'], bother['bbox'])[0] for bother in self.dataset['images'][i]['bboxAnn']] self.dataset['images'][i]['bboxAnn'][j]['box_label'] = np.zeros(len(selset)) self.dataset['images'][i]['bboxAnn'][j]['box_label'][[self.sattr_to_idx[self.dataset['images'][i]['bboxAnn'][ii]['cococlass']] for ii,iv in enumerate(iouAr) if iv>self.iouThresh]] = 1. self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [bb['cococlass'] for bb in img['bboxAnn'] if bb['cococlass'] in selset] if len(classesInImg): self.catsInImg[i] = classesInImg for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.catsInImg[i] = ['bg'] self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.85 maxIou = 0.3 cbboxList = self.dataset['images'][index]['bboxAnn'] if not self.onlyrandBoxes else [] n_t = 0 while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return copy(sbox['bbox']),sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[bb['cococlass']]] = 1 if noOverlap and ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) else: currCls = random.choice(self.catsInImg[index]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index, cid): image = Image.open(os.path.join(self.image_path, self.dataset['images'][index]['filename'])) currCls = self.selected_attrs[cid[0]] if image.mode != 'RGB': #print image.mode image = image.convert('RGB') sampbbox, bboxLabel, cbid = self.randomBBoxSample(index, 0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.7: bsamp, _, _ = self.randomBBoxSample(index, 0.8-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif self.use_gt_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['id'] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): what_mask = self.use_gt_mask if mask_type is None else mask_type if what_mask==1: print 'not supported' assert(0) elif what_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif what_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. else: gtMask = None return gtMask class FlickrLogoBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='dataset.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1): self.image_path = os.path.join('data','flickr_logos_27_dataset','flickr_logos_27_dataset_images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','flickr_logos_27_dataset',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = select_attrs self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.max_with_union= max_with_union self.use_gt_mask = use_gt_mask self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) if use_gt_mask == 1: print ' Not Supported' assert(0) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] if self.max_object_size > 0.: validImgs = [] for img in self.dataset['images']: if not self.max_with_union: maxSize = max([bb['bbox'][2]*bb['bbox'][3] for bb in img['bboxAnn']]) else: boxByCls = defaultdict(list) for bb in img['bboxAnn']: boxByCls[bb['cid']].append(bb['bbox']) unionAreas = [computeUnionArea(boxes) for cid,boxes in boxByCls.iteritems()] maxSize = max(unionAreas) if maxSize < self.max_object_size: validImgs.append(img) print ' %d of %d images left after size filtering'%(len(validImgs), len(self.dataset['images'])) self.dataset['images'] = validImgs self.valid_ids = [img['id'] for img in self.dataset['images']] self.catsInImg = {} selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['bboxAnn'] = [bb for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length > 5: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 0.04: fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'][[self.sattr_to_idx[self.catid2attr[bb['cid']]] for bb in img['bboxAnn']]] = 1. # Convert bbox data to numpy arrays #for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): # self.dataset['images'][i]['bboxAnn'][j]['bbox'] = np.array(bb['bbox']) # Create bbox labels. for j, bb in enumerate(self.dataset['images'][i]['bboxAnn']): #Check for IOU > 0.5 with other bbox iouAr = [computeContainment(bb['bbox'], bother['bbox'])[0] for bother in self.dataset['images'][i]['bboxAnn']] self.dataset['images'][i]['bboxAnn'][j]['box_label'] = np.zeros(len(selset)) self.dataset['images'][i]['bboxAnn'][j]['box_label'][[self.sattr_to_idx[self.catid2attr[self.dataset['images'][i]['bboxAnn'][ii]['cid']]] for ii,iv in enumerate(iouAr) if iv>self.iouThresh]] = 1. self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [self.catid2attr[bb['cid']] for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] if len(classesInImg): self.catsInImg[i] = classesInImg for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.catsInImg[i] = ['bg'] self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.85 maxIou = 0.3 cbboxList = self.dataset['images'][index]['bboxAnn'] if not self.onlyrandBoxes else [] n_t = 0 while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return copy(sbox['bbox']),sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[self.catid2attr[bb['cid']]]] = 1 if noOverlap and ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) else: currCls = random.choice(self.catsInImg[index]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIdAndclass(self, imgid, cls, hflip=0): index = self.imgId2idx[imgid] cid = [self.sattr_to_idx[cls]] if cls != 'bg' else [0] returnvals = self.getbyIndexAndclass(index, cid) return tuple(returnvals) def getbyIndexAndclass(self, index, cid): image = Image.open(os.path.join(self.image_path, self.dataset['images'][index]['filename'])) currCls = self.selected_attrs[cid[0]] if image.mode != 'RGB': #print image.mode image = image.convert('RGB') sampbbox, bboxLabel, cbid = self.randomBBoxSample(index, 0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.7: bsamp, _, _ = self.randomBBoxSample(index, 0.8-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) if self.use_gt_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif self.use_gt_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== currCls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] if self.use_gt_mask: mask = torch.cat([mask, gtMask], dim=0) return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['id'] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): what_mask = self.use_gt_mask if mask_type is None else mask_type if what_mask==1: print 'not supported' assert(0) elif what_mask==2: # Use GT boxes as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] gtMask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. elif what_mask==3: # Use GT centerpoints as input gtBoxes = [bbox for bbox in self.dataset['images'][index]['bboxAnn'] if self.catid2attr[bbox['cid']]== cls] gtMask = torch.zeros(1,self.out_img_size,self.out_img_size) for box in gtBoxes: bbox = copy(box['bbox']) if hflip: bbox[0] = 1.0-(bbox[0]+bbox[2]) bbox = [int(bc*self.out_img_size) for bc in bbox] cent = [bbox[0] + bbox[2]//2, bbox[1]+bbox[3]//2] # center is marked by a 3x3 square patch gtMask[0,cent[1]-1:cent[1]+2,cent[0]-1:cent[0]+2] = 1. else: gtMask = None return gtMask class Places2DatasetBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='datasetBoxAnn.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., max_with_union=True, use_gt_mask=False, boxrotate=0, n_boxes = 1): self.image_path = os.path.join('data','places2','images') self.transform = transform self.mode = mode self.n_boxes = n_boxes self.iouThresh = 0.5 self.filenames = open(os.path.join('data','places2',mode+'_files.txt'),'r').read().splitlines() self.num_data = len(self.filenames) self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = ['background'] self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.boxrotate = boxrotate if self.boxrotate: self.rotateTrans = transforms.Compose([transforms.RandomRotation(boxrotate,resample=Image.NEAREST)]) self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') print ('Finished preprocessing dataset..!') self.valid_ids = [int(fname.split('_')[-1].split('.')[0][-8:]) for fname in self.filenames] def randomBBoxSample(self, max_area = -1): # With 50% chance sample from background or foreground # Minimum size minLen = 0.1 maxLen = 0.7 maxIou = 0.3 cbboxList = [] n_t = 0 while 1: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) #if len(cbboxList): if ((max_area < 0) or ((sbox[2]*sbox[3])< max_area) or (n_t>5)): return sbox, bboxLabel, cbid n_t += 1 def __getitem__(self, index): # In this situation ignore index and sample classes uniformly image = Image.open(os.path.join(self.image_path,self.filenames[index])) if image.mode != 'RGB': #print image.mode image = image.convert('RGB') cid = [0] sampbbox, bboxLabel, cbid = self.randomBBoxSample(0.5) extra_boxes = [] if self.n_boxes > 1: # Sample random number of boxes between 1 and n_boxes c_nbox = np.random.randint(0,self.n_boxes) c_area = sampbbox[2]*sampbbox[3] for i in xrange(c_nbox): # Also stop at total area > 50% if c_area < 0.5: bsamp, _, _ = self.randomBBoxSample(0.6-c_area) # Extra 10% to make the sampling easier extra_boxes.append(bsamp) c_area += bsamp[2]*bsamp[3] else: break label = np.zeros(max(len(self.selected_attrs),1)) # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping hflip = 0 if self.randHFlip and random.random()>0.5: hflip = 1 image = FN.hflip(image) sampbbox[0] = 1.0-(sampbbox[0]+sampbbox[2]) #Convert BBox to actual co-ordinates sampbbox = [int(bc*self.out_img_size) for bc in sampbbox] #Now obtain the crop boxCrop = FN.resized_crop(image, sampbbox[1], sampbbox[0], sampbbox[3],sampbbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,sampbbox[1]:sampbbox[1]+sampbbox[3],sampbbox[0]:sampbbox[0]+sampbbox[2]] = 1. if self.n_boxes > 1 and len(extra_boxes): for box in extra_boxes: box = [int(bc*self.out_img_size) for bc in box] mask[0,box[1]:box[1]+box[3],box[0]:box[0]+box[2]] = 1. if self.boxrotate: mask = torch.FloatTensor(np.asarray(self.rotateTrans(Image.fromarray(mask.numpy()[0]))))[None,::] return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(sampbbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.filenames[index] def getcocoid(self, index): return self.valid_ids[index] def getGTMaskInp(self, index, cls, hflip=False, mask_type=None): gtMask = None return gtMask class PascalDatasetBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='dataset.json', out_img_size=128, bbox_out_size=64, balance_classes=0, onlyrandBoxes=False, max_object_size=0., n_boxes = 1, use_gt_mask=0, boxrotate=0): self.image_path = os.path.join('data','coco','images') self.transform = transform self.mode = mode self.iouThresh = 0.5 self.dataset = json.load(open(os.path.join('data','pascalVoc','dataset.json'),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size #self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.selected_attrs = select_attrs self.balance_classes = balance_classes self.onlyrandBoxes = onlyrandBoxes self.max_object_size = max_object_size self.randHFlip = 'Flip' in transform print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.num_data = len(self.dataset['images']) def preprocess(self): self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] self.valid_ids = [img['filename'].split('.')[0] for img in self.dataset['images']] selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): self.dataset['images'][i]['label'] = np.zeros(max(len(selset),1)) self.dataset['images'][i]['label'][[self.sattr_to_idx[cls] for cls in img['classes']]] = 1. if self.balance_classes: self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): if len(img['classes']): for att in img['classes']: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def randomBBoxSample(self, index): # With 50% chance sample from background or foreground # Minimum size minLen = 0.3 maxLen = 0.8 maxIou = 0.3 cbboxList = [] while 1: if len(cbboxList) and (random.random()<0.9): cbid = random.randrange(len(cbboxList)) sbox = self.dataset['images'][index]['bboxAnn'][cbid] return sbox['bbox'],sbox['box_label'], cbid else: # sample a random background box cbid = None tL_x, tL_y = random.uniform(0,1.-minLen-0.01), random.uniform(0,1.-minLen-0.01) l_x = random.uniform(minLen, min(1.-tL_x,maxLen)) l_y = random.uniform(minLen, min(1.-tL_y,maxLen)) sbox = [tL_x, tL_y, l_x, l_y] # Prepare label for this box bboxLabel = np.zeros(max(len(self.selected_attrs),1)) # Test for overlap with foreground objects noOverlap = True #if len(cbboxList): for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb['bbox']) if iou > maxIou or aInb >0.8: noOverlap = False if bIna > 0.8: bboxLabel[self.sattr_to_idx[self.catid2attr[bb['cid']]]] = 1 if noOverlap: return sbox, bboxLabel, cbid def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) cid = [self.sattr_to_idx[currCls]] if currCls != 'bg' else [0] else: cid = [0] image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) if image.mode != 'RGB': #print image.mode image = image.convert('RGB') bbox, bboxLabel, cbid = self.randomBBoxSample(index) label = self.dataset['images'][index]['label'] # Apply transforms to the image. image = self.transform[0](image) # Now do the flipping if self.randHFlip and random.random()>0.5: image = FN.hflip(image) bbox[0] = 1.0-(bbox[0]+bbox[2]) #Convert BBox to actual co-ordinates bbox = [int(bc*self.out_img_size) for bc in bbox] #print bbox, image.size, cbid #assert bbox[3]>0; #assert bbox[2]>0; #if not ((bbox[0]>=0) and (bbox[0]<128)): # print bbox; # import ipdb;ipdb.set_trace() #assert ((bbox[0]>=0) and (bbox[0]<128)); #assert ((bbox[1]>=0) and (bbox[1]<128)) #Now obtain the crop boxCrop = FN.resized_crop(image, bbox[1], bbox[0], bbox[3],bbox[2], (self.bbox_out_size, self.bbox_out_size)) # Create Mask mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. return self.transform[-1](image), torch.FloatTensor(label), self.transform[-1](boxCrop), torch.FloatTensor(bboxLabel), mask, torch.IntTensor(bbox), torch.LongTensor(cid) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.valid_ids[index] class MNISTDatasetBBoxSample(Dataset): def __init__(self, transform, mode, select_attrs=[], out_img_size=64, bbox_out_size=32, randomrotate=0, scaleRange=[0.1, 0.9], squareAspectRatio=False, use_celeb=False): self.image_path = os.path.join('data','mnist') self.mode = mode self.iouThresh = 0.5 self.maxDigits= 1 self.minDigits = 1 self.use_celeb = use_celeb self.scaleRange = scaleRange self.squareAspectRatio = squareAspectRatio self.nc = 1 if not self.use_celeb else 3 transList = [transforms.RandomHorizontalFlip(), transforms.RandomRotation(randomrotate,resample=Image.BICUBIC)]#, transforms.ColorJitter(0.5,0.5,0.5,0.3) self.digitTransforms = transforms.Compose(transList) self.dataset = MNIST(self.image_path,train=True, transform=self.digitTransforms) if not use_celeb else CelebDataset('./data/celebA/images', './data/celebA/list_attr_celeba.txt', self.digitTransforms, mode) self.num_data = len(self.dataset) self.metadata = {'images':[]} self.catid2attr = {} self.out_img_size = out_img_size self.bbox_out_size = bbox_out_size self.selected_attrs = select_attrs print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') def preprocess(self): for i in xrange(self.num_data): n_objects = np.random.randint(self.minDigits, self.maxDigits+1) c_digits = 0 cbboxList = [] maxIou = 0.1 c = 0 while (len(cbboxList) < n_objects) and (c<10): c+=1 tL_x= random.uniform(0,1.-self.scaleRange[0]-0.01) tL_y = random.uniform(0,1.-self.scaleRange[0]-0.01) l_x = random.uniform(self.scaleRange[0], min(1.-tL_x, self.scaleRange[1])) l_y = random.uniform(self.scaleRange[0], min(1.-tL_y, self.scaleRange[1])) if not self.squareAspectRatio else min(1.-tL_y, l_x) l_x = l_y if self.squareAspectRatio else l_x sbox = [tL_x, tL_y, l_x, l_y] noOverlap = True for bb in cbboxList: iou, aInb, bIna = computeIOU(sbox, bb) if iou > maxIou or aInb>0.8 or bIna>0.8: noOverlap = False break #if bIna > 0.8: # bboxLabel[self.sattr_to_idx[self.catid2attr[bb['cid']]]] = 1 if noOverlap: cbboxList.append(sbox) self.metadata['images'].append(cbboxList) def __getitem__(self, index): # Apply transforms to the image. image = torch.FloatTensor(self.nc,self.out_img_size, self.out_img_size).fill_(-1.) # Get the individual images. randbox = random.randrange(len(self.metadata['images'][index])) imglabel = np.zeros(10, dtype=np.int) boxlabel = np.zeros(10, dtype=np.int) for i,bb in enumerate(self.metadata['images'][index]): imid = random.randrange(self.num_data) bbox = [int(bc*self.out_img_size) for bc in bb] img, label = self.dataset[imid] scImg = FN.resize(img,(bbox[3],bbox[2])) image[:, bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]] = FN.normalize(FN.to_tensor(scImg), mean=(0.5,)*self.nc, std=(0.5,)*self.nc) #imglabel[label] = 1 if i == randbox: outBox = FN.normalize(FN.to_tensor(FN.resize(scImg, (self.bbox_out_size, self.bbox_out_size))), mean=(0.5,)*self.nc, std=(0.5,)*self.nc) mask = torch.zeros(1,self.out_img_size,self.out_img_size) mask[0,bbox[1]:bbox[1]+bbox[3],bbox[0]:bbox[0]+bbox[2]] = 1. outbbox = bbox #boxlabel[label]=1 #return image[[0,0,0],::], torch.FloatTensor([1]), outBox[[0,0,0],::], torch.FloatTensor([1]), mask, torch.IntTensor(outbbox) return image, torch.FloatTensor([1]), outBox, torch.FloatTensor([1]), mask, torch.IntTensor(outbbox) def __len__(self): return self.num_data def getfilename(self, index): return str(index) class CocoDataset(Dataset): def __init__(self, transform, mode, select_attrs=[], datafile='datasetBoxAnn.json', out_img_size = 128, balance_classes=0): self.image_path = os.path.join('data','coco','images') self.transform = transform self.mode = mode self.dataset = json.load(open(os.path.join('data','coco',datafile),'r')) self.num_data = len(self.dataset['images']) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.out_img_size = out_img_size self.balance_classes = balance_classes print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') self.num_data = len(self.dataset['images']) def preprocess(self): for i, attr in enumerate(self.dataset['categories']): self.attr2idx[attr['name']] = i self.idx2attr[i] = attr['name'] self.catid2attr[attr['id']] = attr['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} # First remove unwanted splits: self.dataset['images'] = [img for img in self.dataset['images'] if img['split'] == self.mode] selset = set(self.selected_attrs) for i, img in enumerate(self.dataset['images']): # Correct BBox for Resize(of smaller edge) and CenterCrop fixedbbox = [] imgSize = self.dataset['images'][i]['imgSize'] maxSide = np.argmax(imgSize) for j in xrange(len(self.dataset['images'][i]['bboxAnn'])): cbbox = self.dataset['images'][i]['bboxAnn'][j] maxSideLen = int(float(self.out_img_size * imgSize[maxSide]) / (imgSize[1-maxSide])) assert(maxSideLen >= self.out_img_size) newStartCord = round((maxSideLen - self.out_img_size)/2.) boxStart = min( max(cbbox['bbox'][maxSide]*maxSideLen - newStartCord, 0), self.out_img_size) boxEnd = min(max((cbbox['bbox'][maxSide]+cbbox['bbox'][maxSide+2])*maxSideLen - newStartCord, 0), self.out_img_size) length = boxEnd - boxStart if length > 5: cbbox['bbox'][maxSide] = float(boxStart)/self.out_img_size cbbox['bbox'][maxSide+2] = float(length)/self.out_img_size if cbbox['bbox'][1-maxSide+2] >= 0.04 and ((length*cbbox['bbox'][1-maxSide+2] * self.out_img_size)> 30.): fixedbbox.append(cbbox) if cbbox['bbox'][0]<0. or cbbox['bbox'][1] < 0. or cbbox['bbox'][0]>1.0 or cbbox['bbox'][1]> 1.0: import ipdb; ipdb.set_trace() self.dataset['images'][i]['bboxAnn'] = fixedbbox self.dataset['images'][i]['label'] = np.zeros(len(selset)) self.dataset['images'][i]['label'][[self.sattr_to_idx[self.catid2attr[bb['cid']]] for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset]] = 1. # make a list of image id for each class. if self.balance_classes: self.attToImgId = defaultdict(set) for i, img in enumerate(self.dataset['images']): classesInImg = [self.catid2attr[bb['cid']] for bb in img['bboxAnn'] if self.catid2attr[bb['cid']] in selset] if len(classesInImg): for att in classesInImg: self.attToImgId[att].add(i) else: self.attToImgId['bg'].add(i) self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} def __getitem__(self, index): # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.attToImgId.keys()) index = random.choice(self.attToImgId[currCls]) image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) if image.mode != 'RGB': #print image.mode image = image.convert('RGB') label = self.dataset['images'][index]['label'] return self.transform(image), torch.FloatTensor(label) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getcocoid(self, index): return self.dataset['images'][index]['cocoid'] class CocoMaskDataset(Dataset): def __init__(self, transform, mode, select_attrs=[], balance_classes=0, n_masks_perclass=-1): self.data_path = os.path.join('data','coco') self.transform = transform self.mode = mode filename = 'instances_train2014.json' if mode=='train' else 'instances_val2014.json' self.dataset = COCOTool(os.path.join(self.data_path, filename)) self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs valid_ids = [] for catid in self.dataset.getCatIds(self.selected_attrs): valid_ids.extend(self.dataset.getImgIds(catIds=catid)) self.valid_ids = list(set(valid_ids)) self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.valid_ids) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.nc = 1 self.balance_classes = balance_classes self.n_masks_perclass = n_masks_perclass self.preprocess() print ('Loaded Mask Data') def preprocess(self): for atid in self.dataset.cats: self.catid2attr[self.dataset.cats[atid]['id']] = self.dataset.cats[atid]['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} self.labels = {} self.catsInImg = {} self.validAnnotations = {} self.imgSizes = {} self.validCatIds = self.dataset.getCatIds(self.selected_attrs) selset = set(self.selected_attrs) for i, imgid in enumerate(self.valid_ids): self.labels[i] = np.zeros(len(selset)) self.labels[i][[self.sattr_to_idx[self.catid2attr[ann['category_id']]] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset]] = 1. self.catsInImg[i] = list(set([ann['category_id'] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset])) self.imgSizes[i] = [self.dataset.imgs[imgid]['height'], self.dataset.imgs[imgid]['width']] if self.balance_classes: self.attToImgId = defaultdict(set) for i, imgid in enumerate(self.valid_ids): if len(self.catsInImg[i]): for attid in self.catsInImg[i]: self.attToImgId[self.catid2attr[attid]].add(i) else: import ipdb; ipdb.set_trace() self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} for ann in self.attToImgId: shuffle(self.attToImgId[ann]) if self.n_masks_perclass >0: self.attToImgId = {k:v[:self.n_masks_perclass] for k,v in self.attToImgId.iteritems()} def __getitem__(self, index): #image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.selected_attrs) index = random.choice(self.attToImgId[currCls]) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) label = np.zeros(len(self.selected_attrs)) if len(self.catsInImg[index]): # Randomly sample an annotation currObjId = random.choice(self.catsInImg[index]) for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], currObjId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm label[self.sattr_to_idx[self.catid2attr[currObjId]]] = 1. mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask, torch.FloatTensor(label) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getbyIdAndclass(self, imgid, cls, hflip=0): if (imgid not in self.imgId2idx): maskTotal = np.zeros((128,128)) else: index= self.imgId2idx[imgid] catId = self.dataset.getCatIds(cls) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]) and (catId[0] in self.catsInImg[index]): # Randomly sample an annotation for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], catId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm if hflip: maskTotal = maskTotal[:,::-1] mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask def getbyClass(self, cls): allMasks = [] for c in cls: curr_obj = self.selected_attrs[c] catId = self.dataset.getCatIds(curr_obj) index = random.choice(self.attToImgId[curr_obj]) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]): # Randomly sample an annotation for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], catId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] allMasks.append(maskTotal[None,::]) return torch.cat(allMasks,dim=0) def getbyIdAndclassBatch(self, imgid, cls, hFlips = None): allMasks = [] for i,c in enumerate(cls): curr_obj = self.selected_attrs[c] catId = self.dataset.getCatIds(curr_obj) if (imgid[i] not in self.imgId2idx): maskTotal = np.zeros((128,128)) else: index = self.imgId2idx[imgid[i]] maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]) and (catId[0] in self.catsInImg[index]): # Randomly sample an annotation for ann in self.dataset.loadAnns(self.dataset.getAnnIds(imgid[i], catId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm if (hFlips is not None) and hFlips[i] == 1: maskTotal = maskTotal[:,::-1] maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] allMasks.append(maskTotal[None,::]) return torch.cat(allMasks,dim=0) class SDI_MaskDataset(Dataset): def __init__(self, transform, mode, select_attrs=[], balance_classes=0, n_masks_perclass=-1): self.data_path = os.path.join('data','coco') self.transform = transform self.mode = mode filename = 'instances_train2014.json' if mode=='train' else 'instances_val2014.json' self.dataset = COCOTool(os.path.join(self.data_path, filename)) self.SDI_filelist = open(os.path.join(self.data_path,'SDI_img_list_val.txt'),'r').read().splitlines() self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.valid_ids = [] self.maskToObject = {1:'airplane', 2:'bicycle', 3:'bird', 4:'boat', 5:'bottle', 6:'bus', 7:'car' , 8:'cat', 9:'chair', 10:'cow', 11:'dining table', 12:'dog', 13:'horse', 14:'motorcycle', 15:'person', 16:'potted plant', 17:'sheep', 18:'couch', 19:'train', 20:'tv'} self.objectToMask = {self.maskToObject[v]:v for v in self.maskToObject} self.seg_directory = os.path.join(self.data_path, 'SDI_segmentation') #remove irrelavant results self.imgId2idx = {} for i, fname in enumerate(self.SDI_filelist): cocoid = int(fname.split('.')[0].split('_')[-1]) self.valid_ids.append(cocoid) self.imgId2idx[cocoid] = i self.num_data = len(self.valid_ids) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.nc = 1 self.balance_classes = balance_classes self.n_masks_perclass = n_masks_perclass self.preprocess() print ('Loaded Mask Data') def preprocess(self): for atid in self.dataset.cats: self.catid2attr[self.dataset.cats[atid]['id']] = self.dataset.cats[atid]['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} self.labels = {} self.catsInImg = {} self.validAnnotations = {} self.imgSizes = {} self.validCatIds = self.dataset.getCatIds(self.selected_attrs) #self.imgToAnns = defaultdict(list) #self.imgToCatToAnns = defaultdict(dict) #for i,ann in enumerate(self.mRCNN_results): # self.imgToAnns[ann['image_id']].append(i) # if self.catid2attr[ann['category_id']] not in self.imgToCatToAnns[ann['image_id']]: # self.imgToCatToAnns[ann['image_id']][self.catid2attr[ann['category_id']]] = [] # self.imgToCatToAnns[ann['image_id']][self.catid2attr[ann['category_id']]].append(i) selset = set(self.selected_attrs) for i, imgid in enumerate(self.valid_ids): #self.labels[i] = np.zeros(len(selset)) #self.labels[i][[self.sattr_to_idx[self.catid2attr[ann['category_id']]] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset]] = 1. #self.catsInImg[i] = list(set([ann['category_id'] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset])) self.imgSizes[i] = [self.dataset.imgs[imgid]['width'],self.dataset.imgs[imgid]['height']] #if self.balance_classes: # self.attToImgId = defaultdict(set) # for i, imgid in enumerate(self.valid_ids): # if len(self.catsInImg[i]): # for attid in self.catsInImg[i]: # self.attToImgId[self.catid2attr[attid]].add(i) # else: # import ipdb; ipdb.set_trace() # self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} # for ann in self.attToImgId: # shuffle(self.attToImgId[ann]) # if self.n_masks_perclass >0: # self.attToImgId = {k:v[:self.n_masks_perclass] for k,v in self.attToImgId.iteritems()} def __getitem__(self, index): #image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) # In this situation ignore index and sample classes uniformly assert 0, 'This is not implemented' if self.balance_classes: currCls = random.choice(self.selected_attrs) index = random.choice(self.attToImgId[currCls]) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) label = np.zeros(len(self.selected_attrs)) if len(self.catsInImg[index]): # Randomly sample an annotation currObjId = random.choice(self.catsInImg[index]) for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], currObjId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm label[self.sattr_to_idx[self.catid2attr[currObjId]]] = 1. mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask, torch.FloatTensor(label) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getbyIdAndclass(self, imgid, cls, hflip=0): if (imgid not in self.imgId2idx) or (cls == 'bg'): maskTotal = np.zeros((128,128)) else: index= self.imgId2idx[imgid] catId = self.dataset.getCatIds(cls) segImg = Image.open(os.path.join(self.seg_directory, self.SDI_filelist[index])) oImg_sz = self.imgSizes[index] maxSide = np.argmax(oImg_sz) assert(segImg.size[0] == segImg.size[1]) crop_sizes = [0.,0.] crop_sizes[maxSide] = segImg.size[0] crop_sizes[1-maxSide] = int(oImg_sz[1-maxSide] * (float(segImg.size[0])/float(oImg_sz[maxSide]))) segImg = segImg.crop([0,0,crop_sizes[0],crop_sizes[1]]) maskTotal = (np.array(segImg) == self.objectToMask[cls]).astype(np.float) #if len(self.catsInImg[index]) and (catId[0] in self.catsInImg[index]) and (cls in self.imgToCatToAnns[imgid]): # for annIndex in self.imgToCatToAnns[imgid][cls]: # ann = self.mRCNN_results[annIndex] # cm = self.dataset.annToMask(ann) # maskTotal[:cm.shape[0],:cm.shape[1]] += cm if hflip: maskTotal = maskTotal[:,::-1] mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask def getbyClass(self, cls): assert 0, 'This is not implemented' #allMasks = [] #for c in cls: # curr_obj = self.selected_attrs[c] # catId = self.dataset.getCatIds(curr_obj) # index = random.choice(self.attToImgId[curr_obj]) # maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) # if len(self.catsInImg[index]): # # Randomly sample an annotation # for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], catId)): # cm = self.dataset.annToMask(ann) # maskTotal[:cm.shape[0],:cm.shape[1]] += cm # maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] # allMasks.append(maskTotal[None,::]) #return torch.cat(allMasks,dim=0) def getbyIdAndclassBatch(self, imgid, cls): assert 0, 'This is not implemented' #allMasks = [] #for i,c in enumerate(cls): # curr_obj = self.selected_attrs[c] # catId = self.dataset.getCatIds(curr_obj) # index = self.imgId2idx[imgid[i]] # maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) # if len(self.catsInImg[index]) and (catId in self.catsInImg[index]): # # Randomly sample an annotation # for ann in self.dataset.loadAnns(self.dataset.getAnnIds(imgid[i], catId)): # cm = self.dataset.annToMask(ann) # maskTotal[:cm.shape[0],:cm.shape[1]] += cm # maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] # allMasks.append(maskTotal[None,::]) #return torch.cat(allMasks,dim=0) class MRCNN_MaskDataset(Dataset): def __init__(self, transform, mode, select_attrs=[], balance_classes=0, n_masks_perclass=-1): self.data_path = os.path.join('data','coco') self.transform = transform self.mode = mode filename = 'instances_train2014.json' if mode=='train' else 'instances_val2014.json' self.dataset = COCOTool(os.path.join(self.data_path, filename)) self.mRCNN_results = json.load(open(os.path.join(self.data_path,'maskRCNN_masks_X-101-64x4d-FPN_mAp37p5_coco_2014_minival_results.json'),'r')) self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs valid_ids = [] val_cat_ids = set(self.dataset.getCatIds(self.selected_attrs)) #remove irrelavant results self.mRCNN_results = [ann for ann in self.mRCNN_results if ann['category_id'] in val_cat_ids] for ann in self.mRCNN_results: valid_ids.append(ann['image_id']) self.valid_ids = list(set(valid_ids)) self.imgId2idx = {imid:i for i,imid in enumerate(self.valid_ids)} self.num_data = len(self.valid_ids) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.nc = 1 self.balance_classes = balance_classes self.n_masks_perclass = n_masks_perclass self.preprocess() print ('Loaded Mask Data') def preprocess(self): for atid in self.dataset.cats: self.catid2attr[self.dataset.cats[atid]['id']] = self.dataset.cats[atid]['name'] self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} self.labels = {} self.catsInImg = {} self.validAnnotations = {} self.imgSizes = {} self.validCatIds = self.dataset.getCatIds(self.selected_attrs) self.imgToAnns = defaultdict(list) self.imgToCatToAnns = defaultdict(dict) for i,ann in enumerate(self.mRCNN_results): self.imgToAnns[ann['image_id']].append(i) if self.catid2attr[ann['category_id']] not in self.imgToCatToAnns[ann['image_id']]: self.imgToCatToAnns[ann['image_id']][self.catid2attr[ann['category_id']]] = [] self.imgToCatToAnns[ann['image_id']][self.catid2attr[ann['category_id']]].append(i) selset = set(self.selected_attrs) for i, imgid in enumerate(self.valid_ids): self.labels[i] = np.zeros(len(selset)) self.labels[i][[self.sattr_to_idx[self.catid2attr[ann['category_id']]] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset]] = 1. self.catsInImg[i] = list(set([ann['category_id'] for ann in self.dataset.imgToAnns[imgid] if self.catid2attr[ann['category_id']] in selset])) self.imgSizes[i] = [self.dataset.imgs[imgid]['height'], self.dataset.imgs[imgid]['width']] if self.balance_classes: self.attToImgId = defaultdict(set) for i, imgid in enumerate(self.valid_ids): if len(self.catsInImg[i]): for attid in self.catsInImg[i]: self.attToImgId[self.catid2attr[attid]].add(i) else: import ipdb; ipdb.set_trace() self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} for ann in self.attToImgId: shuffle(self.attToImgId[ann]) if self.n_masks_perclass >0: self.attToImgId = {k:v[:self.n_masks_perclass] for k,v in self.attToImgId.iteritems()} def __getitem__(self, index): #image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.selected_attrs) index = random.choice(self.attToImgId[currCls]) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) label = np.zeros(len(self.selected_attrs)) if len(self.catsInImg[index]): # Randomly sample an annotation currObjId = random.choice(self.catsInImg[index]) for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], currObjId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm label[self.sattr_to_idx[self.catid2attr[currObjId]]] = 1. mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask, torch.FloatTensor(label) def __len__(self): return self.num_data def getfilename(self, index): return self.dataset['images'][index]['filename'] def getbyIdAndclass(self, imgid, cls, hflip=0): if (imgid not in self.imgId2idx) or (cls == 'bg'): maskTotal = np.zeros((128,128)) else: index= self.imgId2idx[imgid] catId = self.dataset.getCatIds(cls) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]) and (catId[0] in self.catsInImg[index]) and (cls in self.imgToCatToAnns[imgid]): # Randomly sample an annotation for annIndex in self.imgToCatToAnns[imgid][cls]: ann = self.mRCNN_results[annIndex] cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm if hflip: maskTotal = maskTotal[:,::-1] mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask def getbyClass(self, cls): allMasks = [] for c in cls: curr_obj = self.selected_attrs[c] catId = self.dataset.getCatIds(curr_obj) index = random.choice(self.attToImgId[curr_obj]) maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]): # Randomly sample an annotation for ann in self.dataset.loadAnns(self.dataset.getAnnIds(self.valid_ids[index], catId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] allMasks.append(maskTotal[None,::]) return torch.cat(allMasks,dim=0) def getbyIdAndclassBatch(self, imgid, cls): allMasks = [] for i,c in enumerate(cls): curr_obj = self.selected_attrs[c] catId = self.dataset.getCatIds(curr_obj) index = self.imgId2idx[imgid[i]] maskTotal = np.zeros((self.imgSizes[index][0], self.imgSizes[index][1])) if len(self.catsInImg[index]) and (catId in self.catsInImg[index]): # Randomly sample an annotation for ann in self.dataset.loadAnns(self.dataset.getAnnIds(imgid[i], catId)): cm = self.dataset.annToMask(ann) maskTotal[:cm.shape[0],:cm.shape[1]] += cm maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] allMasks.append(maskTotal[None,::]) return torch.cat(allMasks,dim=0) class PascalMaskDataset(Dataset): def __init__(self, transform, mode, select_attrs=[], balance_classes=0, n_masks_perclass=-1): self.data_path = os.path.join('data','pascalVoc') self.transform = transform self.mode = mode self.dataset = np.load(os.path.join(self.data_path, 'maskdata.npy')).item() self.selected_attrs = ['person', 'book', 'car', 'bird', 'chair'] if select_attrs== [] else select_attrs self.valid_ids = self.dataset.keys() self.num_data = len(self.valid_ids) self.attr2idx = {} self.idx2attr = {} self.catid2attr = {} self.nc = 1 self.balance_classes = 0 #balance_classes self.n_masks_perclass = n_masks_perclass self.preprocess() print ('Loaded Mask Data') def preprocess(self): self.sattr_to_idx = {att:i for i, att in enumerate(self.selected_attrs)} self.labels = {} self.catsInImg = {} self.validAnnotations = {} self.imgSizes = {} self.maskToObject = {1:'airplane', 2:'bicycle', 3:'bird', 4:'boat', 5:'bottle', 6:'bus', 7:'car' , 8:'cat', 9:'chair', 10:'cow', 11:'dining table', 12:'dog', 13:'horse', 14:'motorcycle', 15:'person', 16:'potted plant', 17:'sheep', 18:'couch', 19:'train', 20:'tv'} self.objectToMask = {self.maskToObject[v]:v for v in self.maskToObject} selset = set(self.selected_attrs) for i, imgid in enumerate(self.valid_ids): self.labels[i] = np.zeros(len(selset)) self.labels[i][[self.sattr_to_idx[ann] for ann in self.dataset[imgid]['label'] if ann in selset]] = 1. self.catsInImg[i] = list(set([ann for ann in self.dataset[imgid]['label'] if ann in selset])) self.attToImgId = defaultdict(set) for i, imgid in enumerate(self.valid_ids): if len(self.catsInImg[i]): for att in self.catsInImg[i]: self.attToImgId[att].add(i) #else: # import ipdb; ipdb.set_trace() self.attToImgId = {k:list(v) for k,v in self.attToImgId.iteritems()} for ann in self.attToImgId: shuffle(self.attToImgId[ann]) if self.n_masks_perclass >0: self.attToImgId = {k:v[:self.n_masks_perclass] for k,v in self.attToImgId.iteritems()} def __getitem__(self, index): #image = Image.open(os.path.join(self.image_path,self.dataset['images'][index]['filepath'], self.dataset['images'][index]['filename'])) # In this situation ignore index and sample classes uniformly if self.balance_classes: currCls = random.choice(self.selected_attrs) index = random.choice(self.attToImgId[currCls]) label = np.zeros(len(self.selected_attrs)) if len(self.catsInImg[index]): # Randomly sample an annotation currObj = random.choice(self.catsInImg[index]) maskTotal = (self.dataset[self.valid_ids[index]]['mask'] == self.objectToMask[currObj]).astype(np.float) label[self.sattr_to_idx[currObj]] = 1. else: print 'No obj in %s'%self.valid_ids[index] maskTotal = np.zeros((128, 128)) mask = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return mask, torch.FloatTensor(label) def getbyClass(self, cls): allMasks = [] for c in cls: curr_obj = self.selected_attrs[c] index = random.choice(self.attToImgId[curr_obj]) maskTotal = (self.dataset[self.valid_ids[index]]['mask'] == self.objectToMask[curr_obj]).astype(np.float) maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] allMasks.append(maskTotal[None,::]) return torch.cat(allMasks,dim=0) def getbyIdAndclass(self, imgid, cls): if imgid not in self.valid_ids: print 'Specified coco id not found' return index= self.valid_ids.index(imgid) maskTotal = (self.dataset[self.valid_ids[index]]['mask'] == self.objectToMask[cls]).astype(np.float) maskTotal = torch.FloatTensor(np.asarray(self.transform(Image.fromarray(np.clip(maskTotal,0,1)))))[None,::] return maskTotal def __len__(self): return self.num_data def getfilename(self, index): return self.valid_ids[index] class CelebDataset(Dataset): def __init__(self, image_path, metadata_path, transform, mode, select_attrs=[]): self.image_path = image_path self.transform = transform self.mode = mode self.lines = open(metadata_path, 'r').readlines() self.num_data = int(self.lines[0]) self.attr2idx = {} self.idx2attr = {} self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] if select_attrs== [] else select_attrs print ('Start preprocessing dataset..!') self.preprocess() print ('Finished preprocessing dataset..!') if self.mode == 'train': self.num_data = len(self.train_filenames) elif self.mode == 'test': self.num_data = len(self.test_filenames) def preprocess(self): attrs = self.lines[1].split() for i, attr in enumerate(attrs): self.attr2idx[attr] = i self.idx2attr[i] = attr self.train_filenames = [] self.train_labels = [] self.test_filenames = [] self.test_labels = [] lines = self.lines[2:] random.shuffle(lines) # random shuffling for i, line in enumerate(lines): splits = line.split() filename = splits[0] values = splits[1:] label = [] for idx, value in enumerate(values): attr = self.idx2attr[idx] if attr in self.selected_attrs: if value == '1': label.append(1) else: label.append(0) if (i+1) < 2000: self.test_filenames.append(filename) self.test_labels.append(label) else: self.train_filenames.append(filename) self.train_labels.append(label) def __getitem__(self, index): if self.mode == 'train': image = Image.open(os.path.join(self.image_path, self.train_filenames[index])) label = self.train_labels[index] elif self.mode in ['test']: image = Image.open(os.path.join(self.image_path, self.test_filenames[index])) label = self.test_labels[index] return self.transform(image), torch.FloatTensor(label) def __len__(self): return self.num_data def getfilename(self, index): if self.mode == 'train': return self.train_filenames[index] else: return self.test_filenames[index] def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dataset='CelebA', mode='train', select_attrs=[], datafile='datasetBoxAnn.json', bboxLoader=False, bbox_size = 64, randomrotate=0, randomscale=(0.5, 0.5), loadMasks=False, balance_classes=0, onlyrandBoxes=False, max_object_size=0., n_masks=-1, imagenet_norm=False, use_gt_mask = False, n_boxes = 1, square_resize = 0, filter_by_mincooccur = -1., only_indiv_occur = 0, augmenter_mode = 0): """Build and return data loader.""" transList = [transforms.Resize(image_size if not square_resize else [image_size, image_size]), transforms.CenterCrop(image_size)] if not loadMasks else [transforms.Resize(image_size if not square_resize else [image_size, image_size], interpolation=Image.NEAREST), transforms.RandomCrop(image_size)] if mode == 'train': transList.extend([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) else: transList.extend([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) if imagenet_norm: transList[-1] = transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)) if loadMasks: transform = transforms.Compose(transList[:-2]) elif bboxLoader: # Split the transforms into 3 parts. # First is applied on the entire image before cropping # Second part consists of random augments which needs special handling # second is applied to convert image to tensor applied sperately to image and crop transform = [transforms.Compose(transList[:2]), 'Flip' if mode=='train' else None, transforms.Compose(transList[-2:])] else: transform = transforms.Compose(transList) if loadMasks: if dataset == 'coco': dataset = CocoMaskDataset(transform, mode, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'mrcnn': dataset = MRCNN_MaskDataset(transform, mode, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'sdi': dataset = SDI_MaskDataset(transform, mode, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'pascal': dataset = PascalMaskDataset(transform, mode, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) else: if dataset == 'CelebA': dataset = CelebDataset(image_path, metadata_path, transform, mode, select_attrs=select_attrs) elif dataset == 'RaFD': dataset = ImageFolder(image_path, transform) elif dataset == 'coco': if bboxLoader: dataset = CocoDatasetBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes, square_resize = square_resize, filter_by_mincooccur = filter_by_mincooccur, only_indiv_occur = only_indiv_occur, augmenter_mode = augmenter_mode) else: dataset = CocoDataset(transform, mode, select_attrs=select_attrs, datafile=datafile, out_img_size=image_size, balance_classes=balance_classes) elif dataset == 'places2': dataset = Places2DatasetBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'ade20k': dataset = ADE20k(transform, mode, select_attrs, image_size, bbox_size, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate= randomrotate, n_boxes = n_boxes, square_resize = square_resize) elif dataset == 'flickrlogo': dataset = FlickrLogoBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'belgalogo': dataset = BelgaLogoBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'outofcontext': dataset = OutofContextBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'unrel': dataset = UnrelBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'pascal': if bboxLoader: dataset = PascalDatasetBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'mnist': dataset = MNISTDatasetBBoxSample(transform, mode, select_attrs, image_size, bbox_size, randomrotate=randomrotate, scaleRange=randomscale) elif dataset == 'celebbox': dataset = MNISTDatasetBBoxSample(transform, mode, select_attrs, image_size, bbox_size, randomrotate=randomrotate, scaleRange=randomscale, squareAspectRatio=True, use_celeb=True) shuffle = False if mode == 'train': shuffle = True data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=16 if not loadMasks else 2 if image_size==32 else 6, pin_memory=True) return data_loader def get_dataset(image_path, metadata_path, crop_size, image_size, dataset='CelebA', split='train', select_attrs=[], datafile='datasetBoxAnn.json', bboxLoader=False, bbox_size = 64, randomrotate=0, randomscale=(0.5, 0.5), loadMasks=False, balance_classes=0, onlyrandBoxes=False, max_object_size=0., n_masks=-1, imagenet_norm=False, use_gt_mask = False, mode='test', n_boxes = 1, square_resize = 0, filter_by_mincooccur = -1., only_indiv_occur = 0, augmenter_mode = 0): """Build and return data loader.""" transList = [transforms.Resize(image_size if not square_resize else [image_size, image_size]), transforms.CenterCrop(image_size)] if not loadMasks else [transforms.Resize(image_size if not square_resize else [image_size, image_size], interpolation=Image.NEAREST), transforms.RandomCrop(image_size)] if mode == 'train': transList.extend([ transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) else: if loadMasks: transList[-1] = transforms.CenterCrop(image_size) transList.extend([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) if imagenet_norm: transList[-1] = transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225)) if loadMasks: transform = transforms.Compose(transList[:-2]) elif bboxLoader: # Split the transforms into 3 parts. # First is applied on the entire image before cropping # Second part consists of random augments which needs special handling # second is applied to convert image to tensor applied sperately to image and crop transform = [transforms.Compose(transList[:2]), 'Flip' if mode=='train' else None, transforms.Compose(transList[-2:])] else: transform = transforms.Compose(transList) if loadMasks: if dataset == 'coco': dataset = CocoMaskDataset(transform, split, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'mrcnn': dataset = MRCNN_MaskDataset(transform, split, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'sdi': dataset = SDI_MaskDataset(transform, mode, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) elif dataset == 'pascal': dataset = PascalMaskDataset(transform, split, select_attrs=select_attrs, balance_classes=balance_classes, n_masks_perclass=n_masks) else: if dataset == 'CelebA': dataset = CelebDataset(image_path, metadata_path, transform, split) elif dataset == 'RaFD': dataset = ImageFolder(image_path, transform) elif dataset == 'coco': if bboxLoader: dataset = CocoDatasetBBoxSample(transform, split, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate= randomrotate, n_boxes = n_boxes, square_resize = square_resize, filter_by_mincooccur = filter_by_mincooccur, only_indiv_occur = only_indiv_occur, augmenter_mode = augmenter_mode) else: dataset = CocoDataset(transform, split, select_attrs=select_attrs, datafile=datafile, out_img_size=image_size, balance_classes=balance_classes) elif dataset == 'places2': dataset = Places2DatasetBBoxSample(transform, split, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'ade20k': dataset = ADE20k(transform, split, select_attrs, image_size, bbox_size, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate= randomrotate, n_boxes = n_boxes, square_resize = square_resize) elif dataset == 'flickrlogo': dataset = FlickrLogoBBoxSample(transform, split, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'outofcontext': dataset = OutofContextBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'unrel': dataset = UnrelBBoxSample(transform, mode, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'belgalogo': dataset = BelgaLogoBBoxSample(transform, split, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate = randomrotate, n_boxes = n_boxes) elif dataset == 'pascal': if bboxLoader: dataset = PascalDatasetBBoxSample(transform, split, select_attrs, datafile, image_size, bbox_size, balance_classes=balance_classes, onlyrandBoxes=onlyrandBoxes, max_object_size=max_object_size, use_gt_mask = use_gt_mask, boxrotate= randomrotate, n_boxes = n_boxes) elif dataset == 'mnist': dataset = MNISTDatasetBBoxSample(transform, split, select_attrs, image_size, bbox_size, randomrotate=randomrotate, scaleRange=randomscale) elif dataset == 'celebbox': dataset = MNISTDatasetBBoxSample(transform, split, select_attrs, image_size, bbox_size, randomrotate=randomrotate, scaleRange=randomscale, squareAspectRatio=True, use_celeb=True) return dataset