'''Parallel data loading functions'''
import sys
import time
import traceback
from multiprocessing import Process, Event

import numpy as np
import theano
from PIL import Image
from six.moves import queue

from lib.config import cfg
from lib.data_augmentation import preprocess_img
from lib.data_io import get_voxel_file, get_rendering_file
from lib.binvox_rw import read_as_3d_array


def print_error(func):
    '''Flush out error messages. Mainly used for debugging separate processes.'''

    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            traceback.print_exception(*sys.exc_info())
            sys.stdout.flush()

    return func_wrapper


class DataProcess(Process):

    def __init__(self, data_queue, data_paths, repeat=True):
        '''
        data_queue : multiprocessing queue used to transfer loaded mini-batches
        data_paths : list of (data, label) pairs used to load data
        repeat     : if True, keep returning batches until `exit` is set
        '''
        super(DataProcess, self).__init__()
        # Queue to transfer the loaded mini-batches
        self.data_queue = data_queue
        self.data_paths = data_paths
        self.num_data = len(data_paths)
        self.repeat = repeat

        self.batch_size = cfg.CONST.BATCH_SIZE
        self.exit = Event()
        self.shuffle_db_inds()

    def shuffle_db_inds(self):
        # Randomly permute the data indices at the start of each epoch
        if self.repeat:
            self.perm = np.random.permutation(np.arange(self.num_data))
        else:
            self.perm = np.arange(self.num_data)
        self.cur = 0

    def get_next_minibatch(self):
        if (self.cur + self.batch_size) >= self.num_data and self.repeat:
            self.shuffle_db_inds()

        db_inds = self.perm[self.cur:min(self.cur + self.batch_size, self.num_data)]
        self.cur += self.batch_size
        return db_inds

    def shutdown(self):
        self.exit.set()

    @print_error
    def run(self):
        iteration = 0
        # Run the loop until the exit flag is set
        while not self.exit.is_set() and self.cur <= self.num_data:
            # Ensure that the network sees (almost) all data per epoch
            db_inds = self.get_next_minibatch()

            data_list = []
            label_list = []
            for batch_id, db_ind in enumerate(db_inds):
                datum = self.load_datum(self.data_paths[db_ind])
                label = self.load_label(self.data_paths[db_ind])

                data_list.append(datum)
                label_list.append(label)

            batch_data = np.array(data_list).astype(np.float32)
            batch_label = np.array(label_list).astype(np.float32)

            # Block until the queue has a free slot
            self.data_queue.put((batch_data, batch_label), block=True)
            iteration += 1

    def load_datum(self, path):
        pass

    def load_label(self, path):
        pass
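
# A minimal sketch of how a subclass plugs into the base run() loop above:
# implement load_datum() and load_label(), and DataProcess handles epoch
# shuffling, batching, and the queue hand-off. NpyDataProcess and the .npy
# pair layout are illustrative assumptions, not part of this module.
#
#     class NpyDataProcess(DataProcess):
#         def load_datum(self, path):
#             return np.load(path[0])  # path is a (data, label) pair
#
#         def load_label(self, path):
#             return np.load(path[1])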

class ReconstructionDataProcess(DataProcess):

    def __init__(self, data_queue, category_model_pair, background_imgs=None,
                 repeat=True, train=True):
        self.repeat = repeat
        self.train = train
        self.background_imgs = background_imgs if background_imgs is not None else []
        super(ReconstructionDataProcess, self).__init__(
            data_queue, category_model_pair, repeat=repeat)

    @print_error
    def run(self):
        # Set up constants
        img_h = cfg.CONST.IMG_H
        img_w = cfg.CONST.IMG_W
        n_vox = cfg.CONST.N_VOX
        # Maximum number of views per model
        n_views = cfg.CONST.N_VIEWS

        while not self.exit.is_set() and self.cur <= self.num_data:
            # Ensure that the network sees (almost) all images per epoch
            db_inds = self.get_next_minibatch()

            # Decide how many views to sample for this mini-batch
            if cfg.TRAIN.RANDOM_NUM_VIEWS:
                curr_n_views = np.random.randint(n_views) + 1
            else:
                curr_n_views = n_views

            # These arrays will be fed into the queue; create a new batch
            # every time so consumers never see a mutated buffer
            batch_img = np.zeros(
                (curr_n_views, self.batch_size, 3, img_h, img_w),
                dtype=theano.config.floatX)
            batch_voxel = np.zeros(
                (self.batch_size, n_vox, 2, n_vox, n_vox),
                dtype=theano.config.floatX)

            # Load each data instance
            for batch_id, db_ind in enumerate(db_inds):
                category, model_id = self.data_paths[db_ind]
                # Sample rendering ids (with replacement)
                image_ids = np.random.choice(cfg.TRAIN.NUM_RENDERING, curr_n_views)

                # Load multi-view images
                for view_id, image_id in enumerate(image_ids):
                    im = self.load_img(category, model_id, image_id)
                    # (channel, height, width)
                    batch_img[view_id, batch_id, :, :, :] = \
                        im.transpose((2, 0, 1)).astype(theano.config.floatX)

                voxel = self.load_label(category, model_id)
                voxel_data = voxel.data

                # Channel 0 marks empty space, channel 1 marks occupancy
                batch_voxel[batch_id, :, 0, :, :] = voxel_data < 1
                batch_voxel[batch_id, :, 1, :, :] = voxel_data

            # Block until the queue has a free slot
            self.data_queue.put((batch_img, batch_voxel), block=True)

        print('Exiting')

    def load_img(self, category, model_id, image_id):
        image_fn = get_rendering_file(category, model_id, image_id)
        im = Image.open(image_fn)
        t_im = preprocess_img(im, self.train)
        return t_im

    def load_label(self, category, model_id):
        voxel_fn = get_voxel_file(category, model_id)
        with open(voxel_fn, 'rb') as f:
            voxel = read_as_3d_array(f)
        return voxel


def kill_processes(data_queue, processes):
    print('Signaling processes to shut down')
    for p in processes:
        p.shutdown()

    print('Emptying the queue')
    while not data_queue.empty():
        time.sleep(0.5)
        data_queue.get(False)

    print('Killing processes')
    for p in processes:
        p.terminate()


def make_data_processes(data_queue, data_paths, num_workers, repeat=True, train=True):
    '''Make a set of data processes for parallel data loading.'''
    processes = []
    for _ in range(num_workers):
        process = ReconstructionDataProcess(data_queue, data_paths,
                                            repeat=repeat, train=train)
        process.start()
        processes.append(process)
    return processes


def get_while_running(data_process, data_queue, sleep_time=0):
    while True:
        time.sleep(sleep_time)
        try:
            batch_data, batch_label = data_queue.get_nowait()
        except queue.Empty:
            if not data_process.is_alive():
                break
            else:
                continue
        yield batch_data, batch_label


def test_process():
    from multiprocessing import Queue
    from lib.data_io import category_model_id_pair

    cfg.TRAIN.PAD_X = 10
    cfg.TRAIN.PAD_Y = 10

    data_queue = Queue(2)
    category_model_pair = category_model_id_pair(dataset_portion=[0, 0.1])

    data_process = ReconstructionDataProcess(data_queue, category_model_pair)
    data_process.start()
    batch_img, batch_voxel = data_queue.get()

    kill_processes(data_queue, [data_process])


if __name__ == '__main__':
    test_process()
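
# Usage sketch: a typical consumer-side loop built from make_data_processes
# and get_while_running. The queue capacity, worker count, dataset split, and
# train_on_batch() are illustrative assumptions, not part of this module. With
# repeat=True the workers produce batches until kill_processes() is called.
#
#     from multiprocessing import Queue
#     from lib.data_io import category_model_id_pair
#
#     data_queue = Queue(4)  # bounded queue applies backpressure to workers
#     data_paths = category_model_id_pair(dataset_portion=[0, 0.8])
#     processes = make_data_processes(data_queue, data_paths, num_workers=4)
#
#     for batch_img, batch_voxel in get_while_running(processes[0], data_queue):
#         train_on_batch(batch_img, batch_voxel)  # hypothetical training step
#
#     kill_processes(data_queue, processes)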