"""Data module"""

import signal

import torch
from progressbar import Bar, ProgressBar, Percentage, Timer, ETA

from .noise import get_noise_var
from .. import exp

__author__ = 'R Devon Hjelm'
__author_email__ = 'erroneus@gmail.com'


def _ignore_sigint(worker_id):
    """Make a DataLoader worker process ignore SIGINT.

    Passed as ``worker_init_fn``. It is a module-level function rather than
    a lambda so that it can be pickled when workers are started with the
    ``spawn`` multiprocessing method.

    Args:
        worker_id: Worker index supplied by the DataLoader (unused).

    """
    signal.signal(signal.SIGINT, signal.SIG_IGN)


class DataHandler:
    """Serves batches from one or more registered sources plus noise variables.

    Expected call order, as required by the methods below:
    ``set_batch_size`` -> ``add_dataset`` (and optionally ``add_noise``)
    -> ``reset(mode)`` -> iterate.

    """

    def __init__(self):
        self.dims = {}  # source/noise key -> dims dict (or noise size)
        self.input_names = {}  # source -> names for each loader output
        self.noise = {}  # noise key -> {'train': var, 'test': var}
        self.loaders = {}  # source -> {split: DataLoader}
        self.batch = None  # most recent batch produced by __next__
        self.iterators = {}  # source -> generator built by make_iterator
        self.pbar = None
        self.u = 0  # number of batches served since the last reset
        self.inputs = dict()  # alias -> batch key, used by __getitem__
        # Set properly by set_batch_size / reset; initialized here so a
        # premature call fails with a clear message instead of an
        # AttributeError.
        self.batch_size = None
        self.skip_last_batch = False
        self.mode = None

    def set_batch_size(self, batch_size, skip_last_batch=False):
        """Set the batch size used by datasets added afterwards.

        Args:
            batch_size: An int applied to every split, or a dict mapping
                split name -> batch size.
            skip_last_batch: If True, iteration stops instead of yielding a
                batch smaller than the batch size.

        """
        self.batch_size = batch_size
        self.skip_last_batch = skip_last_batch

    def set_inputs(self, **kwargs):
        """Register aliases (alias=batch_key) resolved by ``__getitem__``."""
        self.inputs.update(**kwargs)

    def _resolve_batch_size(self, split):
        """Return the batch size for ``split``, recording it per split.

        On first use, a scalar ``self.batch_size`` is remembered in
        ``self.batch_size_`` as the default, and ``self.batch_size`` is
        converted to a per-split dict.

        Raises:
            KeyError: ``split`` has no entry and no scalar default exists.

        """
        if not isinstance(self.batch_size, dict):
            self.batch_size_ = self.batch_size
            self.batch_size = {}
        if split not in self.batch_size:
            try:
                default = self.batch_size_
            except AttributeError:
                raise KeyError('No batch size set for split `{}`'
                               .format(split)) from None
            self.batch_size[split] = default
        return self.batch_size[split]

    def add_dataset(self, source, dataset_entrypoint, n_workers=4,
                    shuffle=True, DataLoader=None):
        """Create one DataLoader per split of ``dataset_entrypoint``.

        Args:
            source: Name under which the loaders, dims and input names are
                stored.
            dataset_entrypoint: Object exposing ``_datasets`` (split ->
                dataset), ``_dims``, ``_input_names`` and
                ``_dataloader_class``.
            n_workers: ``num_workers`` for each DataLoader.
            shuffle: ``shuffle`` flag for each DataLoader.
            DataLoader: Loader class to use. Defaults to the entrypoint's
                ``_dataloader_class``, then ``torch.utils.data.DataLoader``.

        Raises:
            ValueError: The entrypoint has no datasets.
            RuntimeError: ``set_batch_size`` was never called.
            KeyError: A split has no batch size and no scalar default exists.

        """
        DataLoader = (DataLoader or dataset_entrypoint._dataloader_class
                      or torch.utils.data.DataLoader)

        if len(dataset_entrypoint._datasets) == 0:
            raise ValueError('No datasets found in entrypoint')
        if self.batch_size is None:
            raise RuntimeError(
                'Batch size not set; call `set_batch_size` first')

        loaders = {}
        for k, dataset in dataset_entrypoint._datasets.items():
            N = len(dataset)
            # Record the split size among the entrypoint's dims (e.g. N_train).
            dataset_entrypoint._dims['N_' + k] = N
            batch_size = self._resolve_batch_size(k)
            loaders[k] = DataLoader(dataset, batch_size=batch_size,
                                    shuffle=shuffle, num_workers=n_workers,
                                    worker_init_fn=_ignore_sigint)

        self.dims[source] = dataset_entrypoint._dims
        self.input_names[source] = dataset_entrypoint._input_names
        self.loaders[source] = loaders

    def add_noise(self, key, dist=None, size=None, **kwargs):
        """Register a noise variable that is sampled into every batch.

        Builds a variable for each of the 'train' and 'test' splits via
        ``get_noise_var``, with leading dimension equal to that split's
        batch size. ``self.batch_size`` must therefore already be a dict
        containing both 'train' and 'test'.

        Args:
            key: Batch key under which samples appear.
            dist: Distribution identifier forwarded to ``get_noise_var``.
            size: Int or tuple giving the per-sample shape. Required.
            **kwargs: Forwarded to ``get_noise_var``.

        Raises:
            ValueError: ``size`` is None.

        """
        if size is None:
            raise ValueError('`size` must be provided for noise `{}`'
                             .format(key))
        dim = size
        if not isinstance(size, tuple):
            size = (size,)

        train_size = (self.batch_size['train'],) + size
        test_size = (self.batch_size['test'],) + size

        var = get_noise_var(dist, train_size, **kwargs)
        var_t = get_noise_var(dist, test_size, **kwargs)

        self.noise[key] = dict(train=var, test=var_t)
        self.dims[key] = dim

    def __iter__(self):
        return self

    def __next__(self):
        """Fetch the next batch from every source for the current mode.

        With one source, the batch dict maps input names directly to
        values. With several sources, it maps source -> {input name: value}.
        Each noise key maps to a fresh sample, trimmed to the actual batch
        size when the last batch is short.

        Raises:
            StopIteration: A source is exhausted, or a short batch is seen
                while ``skip_last_batch`` is set.

        """
        output = {}
        sources = self.loaders.keys()
        batch_size = self.batch_size[self.mode]

        for source in sources:
            data = next(self.iterators[source])
            if data[0].size()[0] < batch_size:
                if self.skip_last_batch:
                    raise StopIteration
                batch_size = data[0].size()[0]
            data = dict((k, v) for k, v in zip(self.input_names[source], data))
            if len(sources) > 1:
                output[source] = data
            else:
                output.update(**data)

        for k, n_vars in self.noise.items():
            n_var = n_vars[self.mode]
            n_var = n_var.sample()
            n_var = n_var.to(exp.DEVICE)

            # Noise is sized for a full batch; trim it to match a short one.
            if n_var.size()[0] != batch_size:
                n_var = n_var[0:batch_size]
            output[k] = n_var

        self.batch = output
        self.u += 1
        self.update_pbar()

        return self.batch

    def next(self):
        """Python 2 style alias for ``__next__``."""
        return self.__next__()

    def __getitem__(self, item):
        """Return the entry ``item`` (after alias resolution) of the batch.

        Raises:
            RuntimeError: No batch has been produced yet.
            KeyError: The resolved key is not in the current batch.

        """
        if self.batch is None:
            raise RuntimeError('Batch not set')

        item = self.inputs.get(item, item)

        if item not in self.batch.keys():
            raise KeyError('Data with label `{}` not found. Available: {}'
                           .format(item, tuple(self.batch.keys())))
        batch = self.batch[item]

        return batch

    def get_batch(self, *item):
        """Return one or more entries of the current batch.

        A name of the form ``'<n>.<name>'`` selects ``name`` from the n-th
        (1-based) top-level entry of the batch. This is useful when several
        sources are present. Other names are looked up directly.

        Returns:
            The single value when one name is given, else a list of values.

        Raises:
            RuntimeError: No batch has been produced yet.
            IndexError: A ``'<n>.'`` index is outside ``1..len(batch)``.
            KeyError: A plain name is not in the current batch.

        """
        if self.batch is None:
            raise RuntimeError('Batch not set')

        batch = []
        for i in item:
            if '.' in i:
                # Split only on the first dot so sub-keys may contain dots.
                j, i_ = i.split('.', 1)
                j = int(j)
                sources = list(self.batch.keys())
                # Reject 0/negative indices, which would silently wrap.
                if not 1 <= j <= len(sources):
                    raise IndexError(
                        'Source index {} in `{}` out of range; expected 1 '
                        'to {}'.format(j, i, len(sources)))
                batch.append(self.batch[sources[j - 1]][i_])
            elif i not in self.batch.keys():
                raise KeyError('Data with label `{}` not found. Available: {}'
                               .format(i, tuple(self.batch.keys())))
            else:
                batch.append(self.batch[i])

        if len(batch) == 1:
            return batch[0]
        return batch

    def get_dims(self, *q):
        """Look up one or more dimension values.

        If the first query is a top-level key of ``self.dims`` (a source or
        noise key), all queries are resolved against ``self.dims`` itself.
        Otherwise they are resolved against the dims of the first non-noise
        entry.

        Returns:
            A single value for one query, else a list.

        Raises:
            KeyError: A query cannot be resolved.

        """
        if q[0] in self.dims:
            dims = self.dims
        else:
            key = [k for k in self.dims if k not in self.noise][0]
            dims = self.dims[key]
        try:
            d = [dims[q_] for q_ in q]
        except KeyError as err:
            raise KeyError('Cannot resolve dimensions {}, provided {}'
                           .format(q, dims)) from err
        if len(d) == 1:
            return d[0]
        return d

    def get_label_names(self, source=None):
        """Return label names '0'..'n-1', with n taken from dims['labels'].

        Args:
            source: Source whose dims are used. Defaults to the first
                registered loader source.

        """
        # TODO(Devon): This needs to
        # incorporate specific label
        # names from the dataset plugin.
        source = source or list(self.loaders.keys())[0]
        names = ['{}'.format(i) for i in range(self.dims[source]['labels'])]
        return names

    def make_iterator(self, source):
        """Return a generator over the current mode's loader for ``source``.

        Each yielded item is a list of the loader's outputs, each moved to
        ``exp.DEVICE``.

        """
        loader = self.loaders[source][self.mode]

        def iterator():
            for inputs in loader:
                yield [inp.to(exp.DEVICE) for inp in inputs]

        return iterator()

    def update_pbar(self):
        """Advance the progress bar, if any, to the current batch count."""
        if self.pbar:
            self.pbar.update(self.u)

    def reset(self, mode, make_pbar=True, string=''):
        """Start a new pass over split ``mode``.

        Resets the batch counter, optionally creates a progress bar, and
        rebuilds one iterator per source.

        Args:
            mode: Split name (a key of each source's loaders).
            make_pbar: Whether to create a progress bar.
            string: Label shown as the first progress-bar widget.

        """
        self.mode = mode
        self.u = 0

        if make_pbar:
            widgets = [string, Timer(), ' | ', Percentage(), ' | ', ETA(),
                       Bar()]
            if not self.loaders:
                maxval = 1000
            else:
                # Iteration stops at the shortest source.
                maxval = min(len(loader[self.mode])
                             for loader in self.loaders.values())
            self.pbar = ProgressBar(widgets=widgets, maxval=maxval).start()
        else:
            self.pbar = None

        self.iterators = {source: self.make_iterator(source)
                          for source in self.loaders}