"""
This module implements the data structures returned by Data.load() to support
the unique data loading requirements of different deep learning libraries.
"""

import random
import itertools
from typing import Callable, Any

import numpy as np


def _rand_batch_ixs(num_samples: int, batch_size: int, fetch_size: int,
                    random_seed: int):
    """A generator which yields a list of tuples (offset, size) in random
    order. This list will be used by the data loader to efficiently load
    samples and pass it to the model during training.

    Every yielded list describes exactly ``batch_size`` samples. When there
    are no samples at all the generator finishes immediately instead of
    looping forever without yielding.

    :param num_samples: Number of available samples.
    :param batch_size: The size of the batch to fill.
    :param fetch_size: Desired fetch_size.
    :param random_seed: RNG seed.
    """
    if num_samples <= 0:
        # Nothing to index; looping would never produce a batch.
        return

    rng = random.Random(random_seed)
    batch, batch_count = [], 0
    while True:
        if fetch_size * 3 < num_samples:
            offset = rng.randint(0, fetch_size)
        else:
            # if the number of samples is too small, having a random offset
            # makes no sense
            offset = 0
        ixs = list(range(offset, num_samples - offset, fetch_size))
        rng.shuffle(ixs)

        # collect enough samples to fill the batch; a partially filled batch
        # is carried over into the next pass over the shuffled offsets
        for next_fetch in ixs:
            # calculate the next fetch size depending on the samples remaining
            # and the number of samples required to fill the batch
            next_fetch_size = min(fetch_size,
                                  num_samples - next_fetch,
                                  batch_size - batch_count)
            batch.append((next_fetch, next_fetch_size))
            batch_count += next_fetch_size
            if batch_count == batch_size:
                yield batch
                batch, batch_count = [], 0


def _ser_batch_ixs(num_samples: int, batch_size: int):
    """A generator which yields a list of tuples (offset, size) in serial
    order.

    The index wraps around to 0 once the end of the data is reached, so a
    batch may consist of a tail fetch followed by a head fetch. When there are
    no samples at all the generator finishes immediately instead of looping
    forever without yielding.

    :param num_samples: Number of available samples.
    :param batch_size: The size of the batch to fill.
    """
    if num_samples <= 0:
        # Nothing to index; looping would never produce a batch.
        return

    current_index = 0
    batch, batch_count = [], 0
    while True:
        next_fetch = current_index
        next_fetch_size = min(batch_size - batch_count,
                              num_samples - next_fetch)
        batch.append((next_fetch, next_fetch_size))
        batch_count += next_fetch_size
        if batch_count == batch_size:
            # If we have enough samples to fill the batch size, yield
            # the indices and reset the batch count.
            yield batch
            batch, batch_count = [], 0
        current_index += next_fetch_size
        if current_index == num_samples:
            current_index = 0


def _pumpfn(ix_gen):
    """Flatten a generator of batches into a stream of (offset, size) tuples.

    Ends cleanly when ``ix_gen`` is exhausted.

    :param ix_gen: Iterator yielding lists of (offset, size) tuples.
    """
    for batch in ix_gen:
        yield from batch


class BatchView:  # pylint: disable=R0902
    """Generator that returns data as batches (optionally infinite).

    The view asks the loader for the number of samples in ``split``, derives
    the number of full batches from it, and hands the loader a second,
    identical stream of fetch indices via ``loader.pump``.
    """

    def __init__(self,  # pylint: disable=R0913
                 loader,
                 split,
                 layout: str = 'tuples',
                 batch_size: int = 64,
                 fetch_size: int = 8,
                 infinite: bool = False,
                 with_meta: bool = False,
                 randomize: bool = False,
                 random_seed: int = 42,
                 transform_x: Callable[[Any], Any] = lambda x: x,
                 transform_y: Callable[[Any], Any] = lambda y: y):
        """
        :param loader: Object providing ``begin_read_samples``,
            ``end_read_samples``, ``num_samples``, ``read_samples`` and
            ``pump``.
        :param split: Identifier of the data split, passed through to the
            loader.
        :param layout: ``'tuples'`` returns a list of per-sample tuples,
            ``'lists'`` a tuple of per-field lists, ``'arrays'`` the same
            with x and y converted to numpy arrays.
        :param batch_size: Number of samples per batch.
        :param fetch_size: Fetch granularity used when ``randomize`` is set.
        :param infinite: Keep producing batches past ``len(self)``.
        :param with_meta: Include each sample's ``meta`` in the result.
        :param randomize: Use randomized instead of serial fetch order.
        :param random_seed: Seed for the randomized fetch order.
        :param transform_x: Applied to every sample's ``x``.
        :param transform_y: Applied to every sample's ``y``.
        """
        self.loader = loader
        self.split = split

        self.loader.begin_read_samples()
        try:
            num_samples = self.loader.num_samples(self.split)
        finally:
            # always close the read section, even if the loader raises
            self.loader.end_read_samples()

        self.infinite = infinite
        self.with_meta = with_meta
        self.transform_x = transform_x
        self.transform_y = transform_y
        self.layout = layout
        self.num_batches = num_samples // batch_size
        self.current_batch = 0

        if randomize:
            def ix_fn():
                return _rand_batch_ixs(num_samples, batch_size, fetch_size,
                                       random_seed)
        else:
            def ix_fn():
                return _ser_batch_ixs(num_samples, batch_size)

        # We generate two identical ix generators - one for the view and
        # one for the loader
        self.ix_gen = ix_fn()
        self.loader.pump(self.split, _pumpfn(ix_fn()))

    def __iter__(self):
        self.current_batch = 0
        return self

    def __len__(self):
        return self.num_batches

    def __next__(self):
        """Read and return the next batch in the configured layout.

        :raises StopIteration: After ``len(self)`` batches unless the view is
            infinite, or when there are no fetch indices left (empty split).
        """
        if self.current_batch >= self.num_batches and not self.infinite:
            raise StopIteration

        batch_ixs = next(self.ix_gen, None)
        if batch_ixs is None:
            # The index generator only finishes when the split is empty.
            raise StopIteration

        # BEGIN loading samples from the data loader
        res = []
        self.loader.begin_read_samples()
        try:
            for index, n_samples in batch_ixs:
                samples = self.loader.read_samples(self.split, index,
                                                   n_samples)
                for sample in samples:
                    # pylint: disable=C0103
                    x, y, m = sample.x, sample.y, sample.meta
                    x, y = self.transform_x(x), self.transform_y(y)
                    res.append((x, y, m) if self.with_meta else (x, y))
        finally:
            # a failing read or transform must not leave the read section open
            self.loader.end_read_samples()
        # END loading samples

        self.current_batch += 1

        # rearrange the result according to the configured layout
        if self.layout in ('lists', 'arrays'):
            res = tuple(map(list, zip(*res)))
        if self.layout == 'arrays':
            # pylint: disable=C0103
            xs, ys, *meta = res
            res = tuple([np.array(xs), np.array(ys)] + meta)
        return res


class IteratorView:  # pylint: disable=R0902
    """Generator that returns one sample at a time.

    After a full pass the view reshuffles and, unless it is infinite, raises
    ``StopIteration`` once; the following call starts the next pass.
    """

    def __init__(self,  # pylint: disable=R0913
                 loader,
                 split,
                 fetch_size=8,
                 infinite=False,
                 with_meta=False,
                 randomize=False,
                 random_seed=42,
                 transform_x=lambda x: x,
                 transform_y=lambda y: y):
        """
        :param loader: Object providing ``begin_read_samples``,
            ``end_read_samples``, ``num_samples`` and ``read_samples``.
        :param split: Identifier of the data split, passed through to the
            loader.
        :param fetch_size: Size of the contiguous index runs that are kept
            together when shuffling.
        :param infinite: Continue with a new pass instead of stopping.
        :param with_meta: Include each sample's ``meta`` in the result.
        :param randomize: Shuffle the sample order on every pass.
        :param random_seed: Seed for the shuffle.
        :param transform_x: Applied to every sample's ``x``.
        :param transform_y: Applied to every sample's ``y``.
        """
        self.loader = loader
        self.split = split

        self.loader.begin_read_samples()
        try:
            self.num_samples = self.loader.num_samples(self.split)
        finally:
            # always close the read section, even if the loader raises
            self.loader.end_read_samples()

        self.infinite = infinite
        self.with_meta = with_meta
        self.transform_x = transform_x
        self.transform_y = transform_y
        self.fetch_size = fetch_size
        self.rng = random.Random(random_seed) if randomize else None
        self.ixs = None
        self.current_index = 0
        self._shuffle()

    def _shuffle(self):
        """Rebuild ``self.ixs``, the sample order used for the next pass."""
        ixs = range(0, self.num_samples)
        if self.rng:
            # When randomizing sample order, make sure to lay out samples
            # according to fetch size to improve performance.
            chunks = [ixs[i:i + self.fetch_size]
                      for i in range(0, len(ixs), self.fetch_size)]
            self.rng.shuffle(chunks)
            ixs = list(itertools.chain.from_iterable(chunks))
        self.ixs = ixs

    def __iter__(self):
        return self

    def __len__(self):
        return self.num_samples

    def __next__(self):
        """Read and return the next sample as ``(x, y)`` or ``(x, y, meta)``.

        :raises StopIteration: At the end of a pass unless the view is
            infinite.
        """
        if self.current_index >= self.num_samples:
            # end of pass: rewind and pick a new order before (maybe) stopping
            self.current_index = 0
            self._shuffle()
            if not self.infinite:
                raise StopIteration

        index = self.ixs[self.current_index]
        # NOTE(review): unlike BatchView.__next__, this read is not bracketed
        # by begin_read_samples()/end_read_samples() - confirm whether the
        # loader requires that.
        sample = self.loader.read_samples(self.split, index, 1)[0]

        # pylint: disable=C0103
        x, y, m = sample.x, sample.y, sample.meta
        x, y = self.transform_x(x), self.transform_y(y)
        res = (x, y, m) if self.with_meta else (x, y)

        self.current_index += 1
        return res