"""Copyright@Xianming Liu, University of Illinois at Urbana, Champaign

Implementation of Triplet Data Layer
"""

import atexit
from multiprocessing import Pipe
from multiprocessing import Process

import numpy as np

from BasePythonDataLayer import BasePythonDataLayer
from TripletSampler import TripletSampler
from utils.SampleIO import extract_sample

__authors__ = ['Xianming Liu(liuxianming@gmail.com)']

# Layer parameters that are forwarded to TripletSampler as keyword arguments.
_SAMPLER_PARAM_KEYS = ('k', 'm', 'n')


def _fetch_triplet(sampler, data, compressed, mean, resize):
    """Draw one sample from `sampler` and look up its entries in `data`.

    The first three elements of the sample are used as indices into `data`
    (anchor, positive, negative). When `compressed` is true, every looked-up
    entry is passed through `extract_sample(entry, mean, resize)`; otherwise
    the entries are used as stored. If the sample has exactly four elements,
    the last one (label / margin) is appended unchanged.

    Returns:
        A list of three data items, optionally followed by the label.
    """
    sample = sampler.sample()
    if compressed:
        datum = [extract_sample(data[idx], mean, resize)
                 for idx in sample[:3]]
    else:
        datum = [data[idx] for idx in sample[:3]]
    if len(sample) == 4:
        datum.append(sample[-1])
    return datum


def _build_minibatch(get_a_datum, batch_size):
    """Call `get_a_datum` `batch_size` times and stack the results.

    Args:
        get_a_datum: zero-argument callable returning
            [anchor, positive, negative] or
            [anchor, positive, negative, label].
        batch_size: number of triplets in the mini-batch.

    Returns:
        [anchors, positives, negatives] as numpy arrays, followed by a label
        array of shape (batch_size, 1, 1, 1) when labels were produced.
    """
    anchors = []
    positives = []
    negatives = []
    labels = []
    for _ in range(batch_size):
        datum = get_a_datum()
        anchors.append(datum[0])
        positives.append(datum[1])
        negatives.append(datum[2])
        if len(datum) == 4:
            # datum and label / margin
            labels.append(datum[-1])
    batch = [np.array(anchors), np.array(positives), np.array(negatives)]
    if labels:
        # The reshape requires a label for every triplet in the batch.
        batch.append(np.array(labels).reshape(batch_size, 1, 1, 1))
    return batch


class TripletDataLayer(BasePythonDataLayer):
    """Triplet Data Layer:
    Provide data batches for Triplet Network (using ranking loss)

    Data: 3 * batch_size * channels * width * height
          anchor image, positive and negative ones
    Label: Relative Similarity (Optional)

    Implementation is based on BasePythonDataLayer, need to implement:
    1. get_next_minibatch(self) function
    2. sampling functions for randomly sampling and guided sampling

    NOTE(review): the attributes _layer_params, _label, _data, _mean,
    _resize, _batch_size and _compressed are read here but never assigned in
    this file; presumably BasePythonDataLayer.setup populates them - confirm.
    """

    def setup(self, bottom, top):
        """Read the layer parameters and create the sampler or prefetcher.

        Layer parameters read here:
            type     - sampling type, default 'RANDOM'
            prefetch - generate batches in a separate process, default False
            k        - number of candidates when hard negative sampling
            m        - similarity graph filename for hard negative sampling
            n        - number of iterations before hard negative sampling
        The keys k / m / n are matched case-insensitively and forwarded to
        TripletSampler as keyword arguments.
        """
        # setup functions from super class
        super(TripletDataLayer, self).setup(bottom, top)
        print("Using Triplet Python Data Layer")
        self._sampling_type = self._layer_params.get('type', 'RANDOM')
        # prefetch or not: default = False
        self._prefetch = self._layer_params.get('prefetch', False)

        # items() instead of iteritems() so this runs on Python 2 and 3.
        kwargs = {
            key.lower(): value
            for key, value in self._layer_params.items()
            if key.lower() in _SAMPLER_PARAM_KEYS
        }

        if self._prefetch:
            # using prefetch to generate mini-batches
            self._conn, conn = Pipe()
            self._prefetch_process = TripletPrefetcher(
                conn,
                self._label, self._data,
                self._mean, self._resize, self._batch_size,
                self._sampling_type,
                **kwargs
            )
            print("Start Prefetching Process...")
            self._prefetch_process.start()

            def cleanup():
                # run() never returns, so the child has to be terminated
                # explicitly when the interpreter exits.
                print("Terminating Prefetching Processs...")
                self._prefetch_process.terminate()
                self._prefetch_process.join()
                self._conn.close()
            atexit.register(cleanup)
        else:
            self._sampler = TripletSampler(
                self._sampling_type, self._label, **kwargs)
        self.reshape(bottom, top)

    def get_a_datum(self):
        """Get a datum:

        Sampling -> decode images -> stack numpy array

        Only usable when prefetching is disabled: the in-process sampler is
        not created in prefetch mode.
        """
        return _fetch_triplet(
            self._sampler, self._data, self._compressed,
            self._mean, self._resize)

    def get_next_minibatch(self):
        """Return the next mini-batch.

        With prefetching enabled the batch is received from the prefetcher
        process (this blocks until one is available); otherwise it is built
        in-process.

        Returns:
            [anchors, positives, negatives] as numpy arrays, plus a label
            array of shape (batch_size, 1, 1, 1) when the sampler returns
            labels.
        """
        if self._prefetch:
            # get mini-batch from prefetcher
            return self._conn.recv()
        # generate using in-thread functions
        return _build_minibatch(self.get_a_datum, self._batch_size)


class TripletPrefetcher(Process):
    """TripletPrefetcher:
    Use a separate process to sample triplets, following the same function
    implementations as TripletDataLayer
    """

    def __init__(self,
                 conn,
                 labels, data, mean, resize, batch_size,
                 # sampling related parameters
                 sampling_type, **kwargs):
        """Store the data set and build the sampler.

        Args:
            conn: connection end that run() sends the batches to.
            labels: labels handed to TripletSampler.
            data: indexable collection of samples; must not be empty, since
                data[0] is inspected to decide whether decoding is needed.
            mean, resize: forwarded to extract_sample for compressed data.
            batch_size: number of triplets per mini-batch.
            sampling_type: sampling type handed to TripletSampler.
            **kwargs: extra sampling options handed to TripletSampler.
        """
        super(TripletPrefetcher, self).__init__()
        self._conn = conn
        self._labels = labels
        self._data = data
        # String entries are treated as compressed samples that have to be
        # decoded with extract_sample.
        # NOTE(review): on Python 3 encoded samples may be bytes, which this
        # check does not treat as compressed - confirm against the data
        # loader.
        self._compressed = isinstance(self._data[0], str)
        self._batch_size = batch_size
        self._mean = mean
        self._resize = resize
        self._sampling_type = sampling_type
        # kwargs is a dictionary related with sampling
        self._sampler = TripletSampler(
            self._sampling_type, self._labels, **kwargs)

    def type(self):
        """Return the name of this prefetcher."""
        return "TripletPrefetcher"

    def get_a_datum(self):
        """Get a datum:

        Sampling -> decode images -> stack numpy array
        """
        return _fetch_triplet(
            self._sampler, self._data, self._compressed,
            self._mean, self._resize)

    def get_next_minibatch(self):
        """Build one mini-batch inside the prefetching process.

        Returns:
            [anchors, positives, negatives] as numpy arrays, plus a label
            array of shape (batch_size, 1, 1, 1) when the sampler returns
            labels.
        """
        return _build_minibatch(self.get_a_datum, self._batch_size)

    def run(self):
        """Process entry point: build batches forever and send each one."""
        print("Prefetcher Started...")
        while True:
            batch = self.get_next_minibatch()
            self._conn.send(batch)