import os
import json
import pdb
import numpy as np
from collections import defaultdict


class Dataset(object):
    ''' Base class for a dataset. To be overloaded.

    Contains:
        - images                 --> get_image(i) --> image
        - image labels           --> get_label(i)
        - list of image queries  --> get_query(i) --> image
        - list of query ROIs     --> get_query_roi(i)

    Creation: Use dataset.create( "..." ) to instantiate one.
        db = dataset.create( "ImageList('path/to/list.txt')" )

    Attributes:
        root:       image directory root
        nimg:       number of images == len(self)
        nclass:     number of classes
    '''
    root = ''
    img_dir = ''
    nimg = 0
    nclass = 0
    ninstance = 0
    classes = []           # all class names (len == nclass)
    labels = []            # all image labels (len == nimg)
    c_relevant_idx = {}    # images belonging to each class (c_relevant_idx[cl_name] = [idx list])

    def __len__(self):
        return self.nimg

    def get_filename(self, img_idx, root=None):
        return os.path.join(root or self.root, self.img_dir, self.get_key(img_idx))

    def get_key(self, img_idx):
        raise NotImplementedError()

    def key_to_index(self, key):
        # build the reverse mapping lazily, on first use
        if not hasattr(self, '_key_to_index'):
            self._key_to_index = {self.get_key(i): i for i in range(len(self))}
        return self._key_to_index[key]

    def get_image(self, img_idx, resize=None):
        from PIL import Image
        img = Image.open(self.get_filename(img_idx)).convert('RGB')
        if resize:
            # LANCZOS when shrinking (Image.ANTIALIAS is deprecated), BICUBIC when enlarging
            img = img.resize(resize, Image.LANCZOS if np.prod(resize) < np.prod(img.size) else Image.BICUBIC)
        return img

    def get_image_size(self, img_idx):
        return self.imsize  # subclasses are expected to define self.imsize

    def get_label(self, img_idx, toint=False):
        raise NotImplementedError()

    def has_label(self):
        try:
            self.get_label(0)
            return True
        except NotImplementedError:
            return False

    def get_query_db(self):
        raise NotImplementedError()

    def get_query_groundtruth(self, query_idx, what='AP'):
        query_db = self.get_query_db()
        assert self.nclass == query_db.nclass
        if what == 'AP':
            res = -np.ones(self.nimg, dtype=np.int8)                     # all negatives
            res[self.c_relevant_idx[query_db.get_label(query_idx)]] = 1  # positives
            if query_db == self:
                res[query_idx] = 0                                       # query is junk
        elif what == 'label':
            res = query_db.get_label(query_idx)
        else:
            raise ValueError("Unknown ground-truth type: %s" % what)
        return res

    def eval_query_AP(self, query_idx, scores):
        """ Evaluates the average precision (AP) for a given query. """
        from ..utils.evaluation import compute_AP
        gt = self.get_query_groundtruth(query_idx, 'AP')  # labels in {-1, 0, 1}
        assert gt.shape == scores.shape, "scores should have shape %s" % str(gt.shape)
        assert -1 <= gt.min() and gt.max() <= 1, "bad ground-truth labels"
        keep = (gt != 0)  # remove null labels
        if sum(gt[keep] > 0) == 0:
            return -1  # exclude queries with no relevant images from the evaluation
        return compute_AP(gt[keep] > 0, scores[keep])

    def eval_query_top(self, query_idx, scores, k=(1, 5, 10, 20, 50, 100)):
        """ Evaluates top-k accuracy for a given query. """
        if not self.labels:
            raise NotImplementedError()
        q_label = self.get_query_groundtruth(query_idx, 'label')
        correct = np.array([l == q_label for l in self.labels], dtype=bool)  # np.bool8 is deprecated
        correct = correct[(-scores).argsort()]
        return {k_: float(correct[:k_].any()) for k_ in k if k_ < len(correct)}

    def original(self):
        return self  # overload this when the dataset is derived from another one

    def __repr__(self):
        res = 'Dataset: %s\n' % self.__class__.__name__
        res += '  %d images' % len(self)
        if self.nclass:
            res += ', %d classes' % self.nclass
        if self.ninstance:
            res += ', %d instances' % self.ninstance
        try:
            res += ', %d queries' % self.get_query_db().nimg
        except NotImplementedError:
            pass
        res += '\n  root: %s...' % self.root
        return res
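
# Usage sketch (illustration only): a minimal concrete subclass showing which
# members must be overloaded. `_ToyDataset` and its constructor arguments are
# hypothetical, assuming labels are plain strings kept in memory.
class _ToyDataset(Dataset):
    ''' Toy dataset backed by an in-memory list of (key, label) pairs. '''
    def __init__(self, root, keys_and_labels):
        self.root = root
        self.keys, self.labels = map(list, zip(*keys_and_labels))
        self.classes = sorted(set(self.labels))
        self.nimg, self.nclass = len(self.keys), len(self.classes)
        self.c_relevant_idx = defaultdict(list)
        for i, label in enumerate(self.labels):
            self.c_relevant_idx[label].append(i)

    def get_key(self, img_idx):
        return self.keys[img_idx]

    def get_label(self, img_idx, toint=False):
        label = self.labels[img_idx]
        return self.classes.index(label) if toint else label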
def split(dataset, train_prop, val_prop=0, method='balanced'):
    ''' Split a dataset into several subsets: train, val and test.

    method = 'hash':     splits are reliable, i.e. unaffected by adding or
                         removing images, but they might be uneven
                         (not respecting the proportions well).
    method = 'balanced': splits are balanced (they respect the proportions
                         well), but not stable to modifications of the dataset.

    Returns:
        (train, val, test)
        (train, test)  if val_prop == 0
    '''
    assert 0 <= train_prop <= 1
    assert 0 <= val_prop < 1
    assert train_prop + val_prop <= 1
    train = []
    val = []
    test = []

    # redefine hash(), because the built-in is not session-consistent anymore
    import hashlib
    hash = lambda x: int(hashlib.md5(bytes(x, "ascii")).hexdigest(), 16)

    if method == 'balanced':
        test_prop = 1 - train_prop - val_prop
        perclass = [[] for i in range(dataset.nclass)]
        for i in range(len(dataset)):
            label = dataset.get_label(i, toint=True)
            h = hash(dataset.get_key(i))
            perclass[label].append((h, i))

        for imgs in perclass:
            nn = len(imgs)
            imgs.sort()  # randomize order consistently with hash
            if imgs:
                imgs = list(list(zip(*imgs))[1])  # discard hash
            if imgs and train_prop > 0:
                train.append(imgs.pop())  # ensure at least 1 training sample
            for i in range(int(0.9999 + val_prop * nn)):
                if imgs: val.append(imgs.pop())
            for i in range(int(0.9999 + test_prop * nn)):
                if imgs: test.append(imgs.pop())
            if imgs:
                train += imgs  # remaining images go to the training set

        train.sort()
        val.sort()
        test.sort()

    elif method == 'hash':
        val_prop2 = train_prop + val_prop
        for i in range(len(dataset)):
            fname = dataset.get_key(i)
            # compute the file hash, mapped to [0, 1)
            h = (hash(fname) % 100) / 100.0
            if h < train_prop:
                train.append(i)
            elif h < val_prop2:
                val.append(i)
            else:
                test.append(i)
    else:
        raise ValueError("bad split method " + method)

    train = SubDataset(dataset, train)
    val = SubDataset(dataset, val)
    test = SubDataset(dataset, test)
    if val_prop == 0:
        return train, test
    else:
        return train, val, test


class SubDataset(Dataset):
    ''' Contains a sub-part of another dataset. '''
    def __init__(self, dataset, indices):
        self.root = dataset.root
        self.img_dir = dataset.img_dir
        self.dataset = dataset
        self.indices = indices
        self.nimg = len(self.indices)
        self.nclass = self.dataset.nclass

    def get_key(self, i):
        return self.dataset.get_key(self.indices[i])

    def get_label(self, i, **kw):
        return self.dataset.get_label(self.indices[i], **kw)

    def get_bbox(self, i, **kw):
        if hasattr(self.dataset, 'get_bbox'):
            return self.dataset.get_bbox(self.indices[i], **kw)
        else:
            raise NotImplementedError()

    def __repr__(self):
        res = 'SubDataset(%s)\n' % self.dataset.__class__.__name__
        res += '  %d/%d images, %d classes\n' % (len(self), len(self.dataset), self.nclass)
        res += '  root: %s...' % os.path.join(self.root, self.img_dir)
        return res

    def viz_distr(self):
        ''' Visualize the class distribution of this subset. '''
        from matplotlib import pyplot as pl
        pl.ion()
        count = [0] * self.nclass
        for i in range(self.nimg):
            count[self.get_label(i, toint=True)] += 1
        cid = list(range(self.nclass))
        pl.bar(cid, count)
        pdb.set_trace()
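
# Usage sketch (illustration only): splitting a labeled dataset into
# train/val/test subsets. `db` stands for any Dataset implementing
# get_key() and get_label(..., toint=True).
def _example_split(db):
    train, val, test = split(db, train_prop=0.7, val_prop=0.1, method='balanced')
    print(train)  # each is a SubDataset wrapping ~70% / 10% / 20% of `db`
    return train, val, test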
class CatDataset(Dataset):
    ''' Concatenation of several datasets. '''
    def __init__(self, *datasets):
        assert len(datasets) >= 1
        self.datasets = datasets
        db = datasets[0]
        self.root = os.path.normpath(os.path.join(db.root, db.img_dir)) + os.sep
        self.labels = self.imgs = None  # cannot be accessed the normal way
        self.classes = db.classes
        self.nclass = db.nclass
        self.c_relevant_idx = defaultdict(list)

        offsets = [0]
        full_root = lambda db: os.path.normpath(os.path.join(db.root, db.img_dir))
        for db in datasets:
            assert db.nclass == self.nclass, 'All datasets must have the same number of classes'
            assert db.classes == self.classes, 'All datasets must have the same classes'
            # look for a common root
            self.root = os.path.commonprefix((self.root, full_root(db) + os.sep))
            assert self.root, 'no common root between datasets'
            self.root = self.root[:self.root.rfind(os.sep)] + os.sep
            offset = sum(offsets)
            for label, rel in db.c_relevant_idx.items():
                self.c_relevant_idx[label] += [i + offset for i in rel]
            offsets.append(db.nimg)

        self.roots = [full_root(db)[len(self.root):] for db in datasets]
        self.offsets = np.cumsum(offsets)
        self.nimg = self.offsets[-1]

    def which(self, i):
        ''' Returns (dataset index, image index within that dataset). '''
        pos = np.searchsorted(self.offsets, i, side='right') - 1
        assert i < self.nimg, 'Bad image index %d >= %d' % (i, self.nimg)
        return pos, i - self.offsets[pos]

    def get(self, i, attr):
        b, j = self.which(i)
        return getattr(self.datasets[b], attr)

    def __getattr__(self, name):
        # try getting the attribute from the first dataset,
        # then check that it is consistent across all datasets
        val = getattr(self.datasets[0], name)
        assert not callable(val), 'CatDataset: %s is not a shared attribute, use call()' % str(name)
        for db in self.datasets[1:]:
            assert np.all(val == getattr(db, name)), 'CatDataset: inconsistent shared attribute %s, use get()' % str(name)
        return val

    def call(self, i, func, *args, **kwargs):
        b, j = self.which(i)
        return getattr(self.datasets[b], func)(j, *args, **kwargs)

    def get_key(self, i):
        b, i = self.which(i)
        key = self.datasets[b].get_key(i)
        return os.path.join(self.roots[b], key)

    def get_label(self, i, toint=False):
        b, i = self.which(i)
        return self.datasets[b].get_label(i, toint=toint)

    def get_bbox(self, i):
        b, i = self.which(i)
        return self.datasets[b].get_bbox(i)

    def get_polygons(self, i, **kw):
        b, i = self.which(i)
        return self.datasets[b].get_polygons(i, **kw)
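
# Usage sketch (illustration only): concatenating two datasets that share the
# same classes; indices of the second dataset are shifted by len(db1),
# as which() demonstrates.
def _example_cat(db1, db2):
    cat = CatDataset(db1, db2)
    b, j = cat.which(len(db1))  # first image of db2
    assert (b, j) == (1, 0)
    return cat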
def deploy(dataset, target_dir, transforms=None, redo=False, ext=None, **savekwargs):
    if not target_dir:
        return dataset
    from PIL import Image
    from fcntl import flock, LOCK_EX

    if transforms is not None:
        # identify the transform with a unique hash
        import hashlib
        def get_params(trf):
            if type(trf).__name__ == 'Compose':
                return [get_params(t) for t in trf.transforms]
            else:
                return {type(trf).__name__: vars(trf)}
        params = get_params(transforms)
        unique_key = json.dumps(params, sort_keys=True).encode('utf-8')
        h = hashlib.md5(unique_key).hexdigest()
        target_dir = os.path.join(target_dir, h)
    print("Deploying in '%s'" % target_dir)

    imsizes_path = os.path.join(target_dir, 'imsizes.json')
    try:
        imsize_file = open(imsizes_path, 'r+')
    except IOError:
        try:
            os.makedirs(os.path.split(imsizes_path)[0])
        except OSError:
            pass
        imsize_file = open(imsizes_path, 'w+')

    # block access to this file, only one process can continue
    flock(imsize_file, LOCK_EX)

    try:
        imsizes = json.load(imsize_file)
        imsizes = {img: tuple(size) for img, size in imsizes.items()}
    except ValueError:  # empty or corrupted cache
        imsizes = {}

    def check_one_image(i):
        key = dataset.get_key(i)
        target = os.path.join(target_dir, key)
        if ext:
            target = os.path.splitext(target)[0] + '.' + ext
        updated = 0
        if redo or (not os.path.isfile(target)) or key not in imsizes:
            # load the image and transform it
            img = Image.open(dataset.get_filename(i)).convert('RGB')
            imsizes[key] = img.size
            if transforms is not None:
                img = transforms(img)
            odir = os.path.split(target)[0]
            try:
                os.makedirs(odir)
            except FileExistsError:
                pass
            img.save(target, **savekwargs)
            updated = 1
        if (i % 100) == 0:
            # periodically checkpoint the image-size cache
            imsize_file.seek(0)  # go to the beginning
            json.dump(dict(imsizes), imsize_file)
            imsize_file.truncate()
            updated = 0
        return updated

    from nltools.gutils import job_utils
    for i in range(len(dataset)):
        updated = check_one_image(i)  # first try without any threads
        if updated: break
    if i + 1 < len(dataset):
        updated += sum(job_utils.parallel_threads(
            range(i + 1, len(dataset)), check_one_image,
            desc='Deploying dataset', n_threads=0, front_num=0))

    if updated:
        imsize_file.seek(0)  # go to the beginning
        json.dump(dict(imsizes), imsize_file)
        imsize_file.truncate()
    imsize_file.close()  # now, other processes can access the file too

    return DeployedDataset(dataset, target_dir, imsizes, trfs=transforms, ext=ext)


class DeployedDataset(Dataset):
    ''' A deployed dataset: same content, but with a different root
        and image extension.
    '''
    def __init__(self, dataset, root, imsizes=None, trfs=None, ext=None):
        self.dataset = dataset
        if root[-1] != '/': root += '/'
        self.root = root
        self.ext = ext
        self.imsizes = imsizes or json.load(open(root + 'imsizes.json'))
        self.trfs = trfs or (lambda x: x)
        assert isinstance(self.imsizes, dict)
        assert len(self.imsizes) >= dataset.nimg, 'deployed dataset is missing some image sizes'
        self.nimg = dataset.nimg
        self.nclass = dataset.nclass
        self.labels = dataset.labels
        self.c_relevant_idx = dataset.c_relevant_idx
        self.get_label = dataset.get_label
        self.classes = dataset.classes

        if '/query_db/' not in root:
            try:
                query_db = dataset.get_query_db()
                if query_db is not dataset:
                    self.query_db = deploy(query_db, os.path.join(root, 'query_db'), transforms=trfs, ext=ext)
                    self.get_query_db = lambda: self.query_db
            except NotImplementedError:
                pass
        self.get_query_groundtruth = dataset.get_query_groundtruth
        if hasattr(dataset, 'eval_query_AP'):
            self.eval_query_AP = dataset.eval_query_AP
        if hasattr(dataset, 'true_pairs'):
            self.true_pairs = dataset.true_pairs
            self.get_false_pairs = dataset.get_false_pairs

    def __repr__(self):
        res = self.dataset.__repr__()
        res += ' deployed at %s/...%s' % (self.root, self.ext or '')
        return res

    def __len__(self):
        return self.nimg

    def get_key(self, i):
        key = self.dataset.get_key(i)
        if self.ext:
            key = os.path.splitext(key)[0] + '.' + self.ext
        return key

    def get_something(self, what, i, *args, **fmt):
        try:
            get_func = getattr(self.dataset, 'get_' + what)
        except AttributeError:
            raise NotImplementedError()
        imsize = self.imsizes[self.dataset.get_key(i)]
        sth = get_func(i, *args, **fmt)
        # apply the deployment transforms to the annotation as well
        return self.trfs({'imsize': imsize, what: sth})[what]

    def get_bbox(self, i, **kw):
        return self.get_something('bbox', i, **kw)

    def get_polygons(self, i, *args, **kw):
        return self.get_something('polygons', i, *args, **kw)

    def get_label_map(self, i, *args, **kw):
        assert 'polygons' in kw, "you need to supply polygons because the image has been transformed"
        return self.dataset.get_label_map(i, *args, **kw)

    def get_instance_map(self, i, *args, **kw):
        assert 'polygons' in kw, "you need to supply polygons because the image has been transformed"
        return self.dataset.get_instance_map(i, *args, **kw)

    def get_angle_map(self, i, *args, **kw):
        assert 'polygons' in kw, "you need to supply polygons because the image has been transformed"
        return self.dataset.get_angle_map(i, *args, **kw)

    def original(self):
        return self.dataset
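
# Usage sketch (illustration only): deploying a dataset to shared memory.
# When transforms are given, transformed copies are written under
# <deploy_dir>/<md5-of-transforms>/ and image sizes are cached in imsizes.json.
# `some_transforms` is a placeholder for any torchvision-style transform.
def _example_deploy(db, some_transforms=None):
    deployed = deploy(db, '/dev/shm', transforms=some_transforms, ext='jpg', quality=95)
    print(deployed)  # original repr, plus the deployment root
    return deployed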
def deploy_and_split(trainset, deploy_trf=None, deploy_dir='/dev/shm',
                     valset=None, split_val=0.0, img_ext='jpg', img_quality=95,
                     **_useless):
    ''' Deploy and split a dataset into train / val.

    If valset is not provided, then trainset is automatically split
    into train/val based on the split_val proportion.
    '''
    # first, deploy the training set
    traindb = deploy(trainset, deploy_dir, transforms=deploy_trf, ext=img_ext, quality=img_quality)

    if valset:
        # load a validation db
        valdb = deploy(valset, deploy_dir, transforms=deploy_trf, ext=img_ext, quality=img_quality)
    else:
        if split_val > 0:
            # automatic split into train/val
            traindb, valdb = split(traindb, train_prop=1 - split_val)
        else:
            valdb = None

    print("\n>> Training set:"); print(traindb)
    print("\n>> Validation set:"); print(valdb)
    return traindb, valdb


class CropDataset(Dataset):
    """ Dataset of rectified crops extracted from another dataset.

    list_of_imgs_and_crops = [(img_key, polygons), ...]
    where polygons is a list of 4-point quadrilaterals for each image
    (despite the name, a crop is not a plain (l, t, r, b) rectangle).
    """
    def __init__(self, dataset, list_of_imgs_and_crops):
        self.dataset = dataset
        self.root = dataset.root
        self.img_dir = dataset.img_dir
        self.imgs, self.crops = zip(*list_of_imgs_and_crops)
        self.nimg = len(self.imgs)

    def get_image(self, img_idx):
        # even if the image has multiple signage polygons,
        # only the first crop is returned for now
        org_img = self.dataset.get_image(img_idx)
        crop_signs = self.crop_image(org_img, self.crops[img_idx])
        return crop_signs[0]  # temporary: has to change to handle multiple signages

    def get_filename(self, img_idx):
        return self.dataset.get_filename(img_idx)

    def get_key(self, img_idx):
        return self.dataset.get_key(img_idx)

    def crop_image(self, img, polygons):
        ''' Rectify each 4-point polygon with a homography, then crop it. '''
        import cv2
        from PIL import Image
        crop_signs = []
        if len(polygons) == 0:
            pdb.set_trace()
        for Polycc in polygons:
            rgbimg = np.array(img.copy())  # PIL to cv2
            Poly_s = np.array(Polycc)
            # rearrange the corners so that they are consistently ordered
            if Poly_s[0, 1] < Poly_s[1, 1]:
                temp = Poly_s[1, :].copy()
                Poly_s[1, :] = Poly_s[0, :]
                Poly_s[0, :] = temp
            if Poly_s[2, 1] > Poly_s[3, 1]:
                temp = Poly_s[3, :].copy()
                Poly_s[3, :] = Poly_s[2, :]
                Poly_s[2, :] = temp
            # axis-aligned destination rectangle with the same center and size
            cy_s = np.mean(Poly_s[:, 0])
            cx_s = np.mean(Poly_s[:, 1])
            w_s = np.abs(Poly_s[0][1] - Poly_s[1][1])
            h_s = np.abs(Poly_s[0][0] - Poly_s[2][0])
            Poly_d = np.array([(cy_s - h_s/2, cx_s + w_s/2),
                               (cy_s - h_s/2, cx_s - w_s/2),
                               (cy_s + h_s/2, cx_s - w_s/2),
                               (cy_s + h_s/2, cx_s + w_s/2)]).astype(int)  # np.int is deprecated
            M, mask = cv2.findHomography(Poly_s, Poly_d)
            warpimg = Image.fromarray(cv2.warpPerspective(rgbimg, M, (645, 800)))
            crop_sign = warpimg.crop([np.min(Poly_d[:, 0]), np.min(Poly_d[:, 1]),
                                      np.max(Poly_d[:, 0]), np.max(Poly_d[:, 1])])
            crop_signs.append(crop_sign)
        return crop_signs
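
# Usage sketch (illustration only): wrapping a dataset so get_image() returns
# a rectified crop. The 4-point polygon below is made up, in the (y, x)
# corner order that crop_image() expects.
def _example_crops(db):
    quad = [(10, 120), (10, 20), (90, 20), (90, 120)]  # one quadrilateral
    crops = [(db.get_key(0), [quad])]                  # polygons for image #0
    return CropDataset(db, crops)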