# -*- coding: utf-8 -*-
# This file includes functionality for (multi-threaded) batch generation.
# Author: Stefan Kahl, 2018, University of Technology Chemnitz

import sys
sys.path.append("..")

import numpy as np

import config as cfg
from utils import image

RANDOM = cfg.getRandomState()

#################### IMAGE HANDLING #####################
def loadImageAndTarget(sample, augmentation):
    """Load one image and build its one-hot target vector.

    Args:
        sample: Indexable pair; ``sample[0]`` is handed to
            ``image.openImage`` and ``sample[1]`` is the class label that is
            looked up in ``cfg.CLASSES``.
        augmentation: If truthy, ``image.augment`` is applied after resizing.

    Returns:
        Tuple ``(img, target)`` where ``img`` is the output of
        ``image.prepare`` and ``target`` is a float32 vector of length
        ``len(cfg.CLASSES)`` with a single 1.0 at the label's index.

    Raises:
        Whatever ``cfg.CLASSES.index`` raises for an unknown label
        (``ValueError`` if ``cfg.CLASSES`` is a list), plus anything raised
        by the ``utils.image`` helpers.
    """

    # Load image
    img = image.openImage(sample[0], cfg.IM_DIM)

    # Resize image
    img = image.resize(img, cfg.IM_SIZE[0], cfg.IM_SIZE[1], mode=cfg.RESIZE_MODE)

    # Do image augmentation
    if augmentation:
        img = image.augment(img, cfg.IM_AUGMENTATION, cfg.AUGMENTATION_COUNT, cfg.AUGMENTATION_PROBABILITY)

    # Prepare image for net input
    img = image.normalize(img, cfg.ZERO_CENTERED_NORMALIZATION)
    img = image.prepare(img)

    # Get target (one-hot encoding of the class label)
    label = sample[1]
    index = cfg.CLASSES.index(label)
    target = np.zeros((len(cfg.CLASSES)), dtype='float32')
    target[index] = 1.0

    return img, target

#################### BATCH HANDLING #####################
def getDatasetChunk(split):
    """Yield consecutive slices of ``split`` of at most ``cfg.BATCH_SIZE`` items.

    The last slice is shorter when ``len(split)`` is not a multiple of the
    batch size; an empty ``split`` yields nothing.
    """

    # `range` instead of `xrange` so this runs on Python 2 and Python 3
    for i in range(0, len(split), cfg.BATCH_SIZE):
        yield split[i:i + cfg.BATCH_SIZE]

def getNextImageBatch(split, augmentation=True):
    """Yield ``(x_b, y_b)`` batches built from the samples in ``split``.

    Args:
        split: Sliceable sequence of samples as accepted by
            ``loadImageAndTarget``.
        augmentation: Forwarded to ``loadImageAndTarget``.

    Yields:
        Tuple ``(x_b, y_b)`` of float32 arrays. ``x_b`` has shape
        ``(n, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0])`` and ``y_b`` has
        shape ``(n, len(cfg.CLASSES))``, where ``n`` is the number of samples
        of the chunk that loaded successfully. ``n`` can be smaller than
        ``cfg.BATCH_SIZE`` and is 0 if every sample of a chunk failed.
    """

    # Fill batch
    for chunk in getDatasetChunk(split):

        # Allocate numpy arrays for image data and targets
        x_b = np.zeros((cfg.BATCH_SIZE, cfg.IM_DIM, cfg.IM_SIZE[1], cfg.IM_SIZE[0]), dtype='float32')
        y_b = np.zeros((cfg.BATCH_SIZE, len(cfg.CLASSES)), dtype='float32')

        ib = 0
        for sample in chunk:

            try:
                # Load image data and class label from path
                x, y = loadImageAndTarget(sample, augmentation)

                # Pack into batch array
                x_b[ib] = x
                y_b[ib] = y
                ib += 1

            except Exception:
                # Best effort: a sample that cannot be loaded is skipped.
                # Only `Exception` is caught so that KeyboardInterrupt and
                # SystemExit still stop the program.
                continue

        # Trim to actual size
        x_b = x_b[:ib]
        y_b = y_b[:ib]

        # Instead of return, we use yield
        yield x_b, y_b

# Loading images with CPU background threads during GPU forward passes saves a lot of time
# Credit: J.
# Schlüter (https://github.com/Lasagne/Lasagne/issues/12)
def threadedGenerator(generator, num_cached=32):
    """Iterate ``generator`` in a background thread and yield its items here.

    Items are passed through a bounded queue so the producer can run ahead
    of the consumer by at most ``num_cached`` items.

    Args:
        generator: Any iterable; it is consumed by the background thread.
        num_cached: Maximum number of items buffered in the queue.

    Yields:
        The items of ``generator``, in their original order.

    Raises:
        Any ``Exception`` raised while iterating ``generator``. It is
        re-raised in the consuming thread after all items produced before
        the failure have been yielded.
    """

    # The module was renamed from `Queue` (Python 2) to `queue` (Python 3)
    try:
        import queue as queue_module
    except ImportError:
        import Queue as queue_module
    import threading

    buffer = queue_module.Queue(maxsize=num_cached)
    sentinel = object()  # guaranteed unique reference
    failure = []  # holds the producer's exception, if any

    # Define producer (putting items into queue)
    def producer():
        try:
            for item in generator:
                buffer.put(item)
        except Exception as err:
            # Hand the error over to the consumer instead of dying silently
            failure.append(err)
        finally:
            # Always signal the end of the stream; otherwise the consumer
            # would block forever if the generator raised.
            buffer.put(sentinel)

    # Start producer (in a background thread)
    thread = threading.Thread(target=producer)
    thread.daemon = True
    thread.start()

    # Run as consumer (read items from queue, in current thread)
    while True:
        item = buffer.get()
        if item is sentinel:
            break
        yield item
        buffer.task_done()

    # Surface a producer failure to the caller
    if failure:
        raise failure[0]

def nextBatch(split, augmentation=True, threaded=True):
    """Yield ``(x, y)`` batches for ``split``.

    Args:
        split: Sequence of samples, forwarded to ``getNextImageBatch``.
        augmentation: Forwarded to ``getNextImageBatch``.
        threaded: If truthy, batches are produced in a background thread via
            ``threadedGenerator``; otherwise they are produced in the calling
            thread.

    Yields:
        The ``(x, y)`` pairs produced by ``getNextImageBatch``.
    """

    batches = getNextImageBatch(split, augmentation)
    if threaded:
        batches = threadedGenerator(batches)

    for x, y in batches:
        yield x, y