import json import logging import os.path as osp from Queue import Empty, Queue from threading import Thread, current_thread import numpy as np from config import SHAPENET_IM from loader import read_camera, read_depth, read_im, read_quat, read_vol def get_split(split_js='data/splits.json'): dir_path = osp.dirname(osp.realpath(__file__)) with open(osp.join(dir_path, split_js), 'r') as f: js = json.load(f) return js class ShapeNet(object): def __init__(self, im_dir=SHAPENET_IM, split_file='data/splits.json', vox_dir=None, shape_ids=None, num_renders=20, rng_seed=0): self.vox_dir = vox_dir self.im_dir = im_dir self.split_file = split_file self.splits_all = get_split(split_file) self.shape_ids = (self.splits_all.keys() if shape_ids is None else shape_ids) self.splits = {k: self.splits_all[k] for k in self.shape_ids} self.shape_cls = [ self.splits[x]['name'].split(',')[0] for x in self.shape_ids ] self.rng = rng_seed self.num_renders = num_renders self.load_func = { 'im': self.get_im, 'depth': self.get_depth, 'K': self.get_K, 'R': self.get_R, 'quat': self.get_quat, 'vol': self.get_vol, 'shape_id': self.get_sid, 'model_id': self.get_mid, 'view_idx': self.get_view_idx } self.all_items = self.load_func.keys() self.logger = logging.getLogger('mview3d.' + __name__) np.random.seed(self.rng) def get_mids(self, sid): return self.splits[sid] def get_smids(self, split): smids = [] for k, v in self.splits.iteritems(): smids.extend([(k, m) for m in v[split]]) smids = np.random.permutation(smids) return smids def get_sid(self, sid, mid, idx=None): return np.array([sid]) def get_view_idx(self, sid, mid, idx): return idx def get_mid(self, sid, mid, idx=None): return np.array([mid]) def get_K(self, sid, mid, idx): rand_idx = idx cams = [] for ix in rand_idx: f = osp.join(self.im_dir, sid, mid, 'camera_{:d}.mat'.format(ix)) cams.append(read_camera(f)) camK = np.stack([c[0] for c in cams], axis=0) return camK def get_R(self, sid, mid, idx): rand_idx = idx cams = [] for ix in rand_idx: f = osp.join(self.im_dir, sid, mid, 'camera_{:d}.mat'.format(ix)) cams.append(read_camera(f)) camR = np.stack([c[1] for c in cams], axis=0) return camR def get_quat(self, sid, mid, idx): rand_idx = idx cams = [] for ix in rand_idx: f = osp.join(self.im_dir, sid, mid, 'camera_{:d}.mat'.format(ix)) cams.append(read_quat(f)) camq = np.stack(cams, axis=0) return camq def get_depth(self, sid, mid, idx): rand_idx = idx depths = [] for ix in rand_idx: f = osp.join(self.im_dir, sid, mid, 'depth_{:d}.png'.format(ix)) depths.append(read_depth(f)) return np.stack(depths, axis=0) def get_im(self, sid, mid, idx): rand_idx = idx ims = [] for ix in rand_idx: f = osp.join(self.im_dir, sid, mid, 'render_{:d}.png'.format(ix)) ims.append(read_im(f)) return np.stack(ims, axis=0) def get_vol(self, sid, mid, idx=None, tsdf=False): if self.vox_dir is None: self.logger.error('Voxel dir not defined') f = osp.join(self.vox_dir, sid, mid) return read_vol(f, tsdf) def fetch_data(self, smids, items, im_batch): with self.coord.stop_on_exception(): while not self.coord.should_stop(): data = {} try: data_idx = self.queue_idx.get(timeout=0.5) except Empty: self.logger.debug('Index queue empty - {:s}'.format( current_thread().name)) continue view_idx = np.random.choice( self.num_renders, size=(im_batch, ), replace=False) sid, mid = smids[data_idx] for i in items: data[i] = self.load_func[i](sid, mid, view_idx) self.queue_data.put(data) if self.loop_data: self.queue_idx.put(data_idx) def init_queue(self, smids, im_batch, items, coord, nepochs=None, qsize=32, nthreads=4): self.coord = coord self.queue_data = Queue(maxsize=qsize) if nepochs is None: nepochs = 1 self.loop_data = True else: self.loop_data = False self.total_items = nepochs * len(smids) self.queue_idx = Queue(maxsize=self.total_items) for nx in range(nepochs): for rx in range(len(smids)): self.queue_idx.put(rx) self.qthreads = [] self.logger.info('Starting {:d} prefetch threads'.format(nthreads)) for qx in range(nthreads): worker = Thread( target=self.fetch_data, args=(smids, items, im_batch)) worker.start() self.coord.register_thread(worker) self.qthreads.append(worker) def close_queue(self, e=None): self.logger.debug('Closing queue') self.coord.request_stop(e) try: while True: self.queue_idx.get(block=False) except Empty: self.logger.debug('Emptied idx queue') try: while True: self.queue_data.get(block=False) except Empty: self.logger.debug("Emptied data queue") def next_batch(self, items, batch_size, timeout=0.5): data = [] cnt = 0 while cnt < batch_size: try: dt = self.queue_data.get(timeout=timeout) self.total_items -= 1 data.append(dt) except Empty: self.logger.debug('Example queue empty') if self.total_items <= 0 and not self.loop_data: # Exhausted all data self.close_queue() break else: continue cnt += 1 if len(data) == 0: return batch_data = {} for k in items: batch_data[k] = [] for dt in data: batch_data[k].append(dt[k]) batched = np.stack(batch_data[k]) batch_data[k] = batched return batch_data def reset(self): np.random.seed(self.rng)