from kerosene.torch_util import to_gpu import collections import numpy as np import os import torch import torch.utils.data from concurrent.futures.thread import ThreadPoolExecutor class SortSampler(torch.utils.data.sampler.Sampler): def __init__(self, data_source, key): self.data_source, self.key = data_source, key def __len__(self): return len(self.data_source) def __iter__(self): return iter(sorted(range(len(self.data_source)), key=self.key, reverse=True)) class SortishSampler(torch.utils.data.sampler.Sampler): def __init__(self, data_source, key, bs): self.data_source, self.key, self.bs = data_source, key, bs def __len__(self): return len(self.data_source) def __iter__(self): idxs = np.random.permutation(len(self.data_source)) sz = self.bs*50 ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)] sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) sz = self.bs ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)] max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) sort_idx = np.concatenate((ck_idx[0], sort_idx)) return iter(sort_idx) class DataLoader(object): def __init__( self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, pad_idx=0, num_workers=None, pin_memory=False, drop_last=False, pre_pad=True, half=False, transpose=False, transpose_y=False, ): self.dataset, self.batch_size, self.num_workers = dataset, batch_size, num_workers self.pin_memory, self.drop_last, self.pre_pad = pin_memory, drop_last, pre_pad self.transpose, self.transpose_y = transpose, transpose_y self.pad_idx, self.half = pad_idx, half if batch_sampler is not None: if batch_size > 1 or shuffle or sampler is not None or drop_last: raise ValueError('batch_sampler is mutually exclusive with ' 'batch_size, shuffle, sampler, and drop_last') if sampler is not None and shuffle: raise ValueError('sampler is mutually exclusive with shuffle') if batch_sampler is None: if sampler is None: if shuffle: sampler = torch.utils.data.sampler.RandomSampler(dataset) else: sampler = torch.utils.data.sampler.SequentialSampler(dataset) batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last) if num_workers is None: self.num_workers = num_cpus() self.sampler = sampler self.batch_sampler = batch_sampler def __len__(self): return len(self.batch_sampler) def jag_stack(self, b): if len(b[0].shape) not in (1, 2): return np.stack(b) ml = max(len(o) for o in b) if min(len(o) for o in b) == ml: return np.stack(b) res = np.zeros((len(b), ml), dtype=b[0].dtype) + self.pad_idx for i, o in enumerate(b): if self.pre_pad: res[i, -len(o):] = o else: res[i, :len(o)] = o return res def np_collate(self, batch): b = batch[0] if isinstance(b, (np.ndarray, np.generic)): return self.jag_stack(batch) elif isinstance(b, (int, float)): return np.array(batch) elif isinstance(b, (str, bytes)): return batch elif isinstance(b, collections.Mapping): return {key: self.np_collate([d[key] for d in batch]) for key in b} elif isinstance(b, collections.Sequence): return [self.np_collate(samples) for samples in zip(*batch)] raise TypeError(("batch must contain numbers, dicts or lists; found {}".format(type(b)))) def get_batch(self, indices): res = self.np_collate([self.dataset[i] for i in indices]) if self.transpose: res[0] = res[0].T if self.transpose_y: res[1] = res[1].T return res def __iter__(self): if self.num_workers == 0: for batch in map(self.get_batch, iter(self.batch_sampler)): yield self._get_tensor(batch, self.pin_memory, self.half) else: with ThreadPoolExecutor(max_workers=self.num_workers) as e: # avoid py3.6 issue where queue is infinite and can result in memory exhaustion for c in chunk_iter(iter(self.batch_sampler), self.num_workers*10): for batch in e.map(self.get_batch, c): yield self._get_tensor(batch) def _get_tensor(self, batch): if isinstance(batch, (np.ndarray, np.generic)): batch = T(batch, half=self.half, cuda=False).contiguous() if self.pin_memory: batch = batch.pin_memory() return to_gpu(batch) elif isinstance(batch, (bytes, str)): return batch elif isinstance(batch, collections.Mapping): return {k: self._get_tensor(sample) for k, sample in batch.items()} elif isinstance(batch, collections.Sequence): return [self._get_tensor(sample) for sample in batch] raise TypeError(f"batch must contain numbers, dicts or lists; found {type(batch)}") def T(a, half=False, cuda=True): if not torch.is_tensor(a): a = np.array(np.ascontiguousarray(a)) if a.dtype in (np.int8, np.int16, np.int32, np.int64): a = torch.LongTensor(a.astype(np.int64)) elif a.dtype in (np.float32, np.float64): a = torch.cuda.HalfTensor(a) if half else torch.FloatTensor(a) else: raise NotImplementedError(a.dtype) return to_gpu(a, async=True) if cuda else a def chunk_iter(iterable, chunk_size): while True: chunk = [] try: for _ in range(chunk_size): chunk.append(next(iterable)) yield chunk except StopIteration: if chunk: yield chunk break def to_np(v): if isinstance(v, (np.ndarray, np.generic)): return v if isinstance(v, (list, tuple)): return [to_np(o) for o in v] if isinstance(v, torch.autograd.Variable): v = v.data if isinstance(v, torch.cuda.HalfTensor): v = v.float() return v.cpu().numpy() def num_cpus(): try: return len(os.sched_getaffinity(0)) except AttributeError: return os.cpu_count()