# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Trax input pipeline."""

import math
import random

from absl import logging
import gin
import numpy as np

from trax.fastmath import numpy as jnp


class Inputs(object):
  """Inputs bundle.

  Inputs bundle holds input streams and shapes for a training run.
  It contains stream-creating functions that return python generators
  of (input_batch, target_batch) tuples.

  * train_stream: training data that will be used for training
      may include all the augmentation or selection the training wants
      the shape of examples is [batch_fn.batch_size, ...]
  * train_eval_stream: training data used for evaluation
      examples from training data but usually without augmentation
      the shape of examples is [batch_fn.eval_batch_size, ...]
  * eval_stream: evaluation data stream
      examples from evaluation data, usually without augmentation
      the shape of examples is [batch_fn.eval_batch_size, ...]
  * input_shape: the shape of inputs
      the [...] above, without batch size
  * input_dtype: the data type of inputs
  * target_shape: the shape of targets
      the [...] above, without batch size
  * target_dtype: the data type of targets
  """

  def __init__(self, train_stream, eval_stream=None, train_eval_stream=None):
    """Initialize a new set of inputs.

    Note that construction calls `train_stream(1)` and consumes one batch from
    the resulting generator in order to record example shapes and dtypes.

    Args:
      train_stream: a function taking n_devices (an int) and returning
        a python generator of training batches.
      eval_stream: a function taking n_devices (an int) and returning
        a python generator of validation batches;
        if None, then the training generator will be used for evaluation.
      train_eval_stream: a function taking n_devices (an int) and returning
        a python generator of batches from
        the training set used for evaluation (if None, use train_stream).

    Raises:
      ValueError: if `train_stream` is not callable.
    """
    if not callable(train_stream):
      raise ValueError('Trax Inputs should be initialized with a function. '
                       'Did you forget the n_devices argument? If your inputs '
                       'do not use it, try lambda _: [your-inputs].')

    self._train_stream = train_stream
    self._eval_stream = eval_stream or self._train_stream

    # TODO(lukaszkaiser): should we get rid of this one day?
    self._train_eval_stream = train_eval_stream or self._train_stream

    # Peek into the train stream to get an example shape.
    example_train_batch = next(train_stream(1))
    # Inputs are the first element of a batch tuple, targets the last one;
    # index 0 of every shape is the batch dimension and is dropped here.
    self._input_shape = tuple(example_train_batch[0].shape)[1:]
    self._input_dtype = example_train_batch[0].dtype
    self._target_shape = tuple(example_train_batch[-1].shape)[1:]
    self._target_dtype = example_train_batch[-1].dtype
    # Full (batch-including) shapes and dtypes of every element of the batch.
    self._example_shape = [x.shape for x in example_train_batch]
    self._example_dtype = [x.dtype for x in example_train_batch]

  def train_stream(self, n_devices):
    """Returns a new generator of training batches."""
    return self._train_stream(n_devices)

  def eval_stream(self, n_devices):
    """Returns a new generator of evaluation batches."""
    return self._eval_stream(n_devices)

  def train_eval_stream(self, n_devices):
    """Returns a new generator of training-set batches used for evaluation."""
    # Use the dedicated train-eval stream; it already falls back to the
    # training stream in __init__ when none was supplied.
    return self._train_eval_stream(n_devices)

  @property
  def input_shape(self):
    """Example input shape, without batch dimension."""
    return self._input_shape

  @property
  def target_shape(self):
    """Example target shape, without batch dimension."""
    return self._target_shape

  @property
  def input_dtype(self):
    """Dtype of the input."""
    return self._input_dtype

  @property
  def target_dtype(self):
    """Dtype of the target."""
    return self._target_dtype

  @property
  def example_shape_dtype(self):
    """Shape and Dtype of an example batch."""
    return self._example_shape, self._example_dtype


# Batching and input pipeline creation helpers.
@gin.configurable()
def batcher(data_streams=gin.REQUIRED, variable_shapes=True,
            batch_size_per_device=32, batch_size=None, eval_batch_size=32,
            bucket_length=32, buckets=None,
            buckets_include_inputs_in_length=False,
            batch_shuffle_size=None, max_eval_length=None,
            # TODO(afrozm): Unify padding logic.
            id_to_mask=None, strict_pad_on_len=False):
  """Batcher: create trax Inputs from single-example data-streams.

  Args:
    data_streams: either a pair `(train_stream, eval_stream)` of zero-argument
      functions returning generators of single examples, or a zero-argument
      function returning such a pair.
    variable_shapes: forwarded to `batch_fn`.
    batch_size_per_device: forwarded to `batch_fn`.
    batch_size: forwarded to `batch_fn`.
    eval_batch_size: forwarded to `batch_fn`.
    bucket_length: forwarded to `batch_fn`.
    buckets: forwarded to `batch_fn`.
    buckets_include_inputs_in_length: forwarded to `batch_fn`.
    batch_shuffle_size: forwarded to `batch_fn`.
    max_eval_length: forwarded to `batch_fn`.
    id_to_mask: forwarded to `batch_fn`.
    strict_pad_on_len: forwarded to `batch_fn`.

  Returns:
    An `Inputs` object whose train stream batches `train_stream` in training
    mode, whose eval stream batches `eval_stream` in eval mode and whose
    train-eval stream batches `train_stream` in eval mode.
  """
  # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
  # For now leaving the arguments as in batch_fn to reduce gin config changes.
  if callable(data_streams):  # If we pass a function, e.g., through gin, call.
    train_stream, eval_stream = data_streams()
  else:
    train_stream, eval_stream = data_streams

  def _batched(stream_fn, training):
    """Returns an n_devices -> batch generator function for `stream_fn`."""
    def stream(n_devices):
      return batch_fn(
          stream_fn(), training, n_devices, variable_shapes,
          batch_size_per_device, batch_size, eval_batch_size,
          bucket_length, buckets, buckets_include_inputs_in_length,
          batch_shuffle_size, max_eval_length, id_to_mask, strict_pad_on_len)
    return stream

  return Inputs(train_stream=_batched(train_stream, True),
                eval_stream=_batched(eval_stream, False),
                train_eval_stream=_batched(train_stream, False))


def batch_fn(dataset, training, n_devices, variable_shapes,
             batch_size_per_device=32, batch_size=None, eval_batch_size=32,
             bucket_length=32, buckets=None,
             buckets_include_inputs_in_length=False,
             batch_shuffle_size=None, max_eval_length=None,
             id_to_mask=None, strict_pad_on_len=False):
  """Batching function.

  Args:
    dataset: python generator of single examples (tuples of arrays).
    training: bool; if True use `batch_size` and allow batch shuffling,
      otherwise use `eval_batch_size`.
    n_devices: int; all batch sizes are rounded down to a multiple of this
      number (but never below `n_devices`).
    variable_shapes: bool; if True and `buckets` is None, heuristic length
      buckets based on `bucket_length` are created.
    batch_size_per_device: int; used to derive `batch_size` when it is not set.
    batch_size: int or None; total training batch size.
    eval_batch_size: int; total batch size when `training` is False.
    bucket_length: int; base length for the heuristic bucket boundaries.
    buckets: None or a pair `(boundaries, batch_sizes)` of lists, where
      `batch_sizes` has one more element than `boundaries`.
    buckets_include_inputs_in_length: bool; if True the bucketing length of an
      example is the max of input and target length, else target length only.
    batch_shuffle_size: int or None; if set (and training), batches are
      shuffled with a queue of this size.
    max_eval_length: int or None; largest heuristic bucket boundary in eval
      mode (defaults to `bucket_length * 32`).
    id_to_mask: int or None; forwarded to `add_loss_weights`.
    strict_pad_on_len: bool; forwarded to `bucket_by_length`.

  Returns:
    A generator of batches, with loss weights added by `add_loss_weights`.
  """
  # TODO(lukaszkaiser, jonni): revisit arguments, their semantics and naming.
  # After that, create a proper doc-string; we may also not need to pass both
  # training and eval arguments here, as batcher calls the function separately
  # now and it's not under gin-config any more -- consider reducing args.
  batch_size = batch_size or batch_size_per_device * n_devices
  # If bucketing is not specified, check if target shapes are variable.
  cur_batch_size = batch_size if training else eval_batch_size
  # Make cur_batch_size divisible by n_devices.
  cur_batch_size = max(cur_batch_size // n_devices, 1) * n_devices
  # Create heuristic buckets if none are specified.
  if buckets is None:
    logging.info('Heuristically setting bucketing to %s based on shapes '
                 'of target tensors.', variable_shapes)
    if variable_shapes:
      bucket_boundaries = [bucket_length // 4, bucket_length // 2,
                           bucket_length, bucket_length * 2,
                           bucket_length * 4, bucket_length * 8,
                           bucket_length * 16]
      if not training:
        max_eval_length = max_eval_length or bucket_length * 32
        # Set last bucket boundary to be max_eval_length, cut off boundaries
        # that are larger than this. The boundary is added exactly once so
        # that batch sizes (one more than boundaries) stay aligned with it.
        bucket_boundaries = (
            [b for b in bucket_boundaries if b < max_eval_length] +
            [max_eval_length]
        )
      # We will pad to boundaries which pads to bucket_boundary - 1: add 1 here.
      bucket_boundaries = [b + 1 for b in bucket_boundaries]
      bucket_batch_sizes = [cur_batch_size * 4, cur_batch_size * 2,
                            cur_batch_size, cur_batch_size // 2,
                            cur_batch_size // 4, cur_batch_size // 8,
                            cur_batch_size // 16, 1]
      if not training:
        # The last bucket batch size is always 1, but the one-but-last is
        # sized to accommodate the final length = bucket_boundaries[-1], which
        # we changed for eval above -- so adjusting here too.

        # Resize if needed, since bucket_batch_sizes may not be the same size
        # anymore.
        bucket_batch_sizes = bucket_batch_sizes[:len(bucket_boundaries)] + [1]
        bucket_batch_sizes[-2] = cur_batch_size // max_eval_length
      # Make batch sizes divisible by n_devices.
      bucket_batch_sizes = [max(b // n_devices, 1) * n_devices
                            for b in bucket_batch_sizes]
      buckets = (bucket_boundaries, bucket_batch_sizes)

  if buckets:
    logging.info('Bucketing with buckets %s.', str(buckets))

    def example_length(x):
      """The length function used by bucket_by_sequence_length to bucket."""
      # The input x is a tuple to go on the stack, typically either
      # (input, target) or (input, target, mask).
      example_inputs, target = x[0], x[1]
      # Length is the shape of axis 0 here (no batch yet).
      other_length = 0  # We include input length only if asked.
      if buckets_include_inputs_in_length:
        other_length = example_inputs.shape[0]
      return max(target.shape[0], other_length)

    boundaries, batch_sizes = buckets
    dataset = bucket_by_length(
        dataset, example_length, boundaries, batch_sizes, strict_pad_on_len)
  else:
    logging.info('Not Bucketing cur_batch_size %d.', cur_batch_size)
    dataset = batch_data(dataset, cur_batch_size)
  if training and batch_shuffle_size is not None:
    dataset = shuffle_data(dataset, batch_shuffle_size)
  return add_loss_weights(dataset, id_to_mask)


def shuffle_data(samples, queue_size):
  """Shuffles a sample stream using a random-out next-in queue of given size.

  Args:
    samples: Stream of samples for eventual use as training data or eval data.
    queue_size: Minimum number of samples within which the streamed shuffling
        takes place.

  Yields:
    Shuffled stream of samples, ready for further processing, e.g., grouping
    into batches.

  Raises:
    ValueError: if `queue_size` is less than 1 (raised when the generator is
      first advanced).
  """
  if queue_size < 1:
    raise ValueError(f'Arg queue_size ({queue_size}) is less than 1.')
  if queue_size == 1:
    logging.warning('Queue size of 1 results in no shuffling.')
  # Accept any iterable, not only iterators; `next` below needs an iterator
  # and both loops must share the same position in the stream.
  samples = iter(samples)
  queue = []
  try:
    # Prep: fill the queue.
    for _ in range(queue_size):
      queue.append(next(samples))

    # Core streaming shuffle: yield sample from random location in queue, then
    # fill that location with new sample from input stream.
    for sample in samples:
      i = np.random.randint(queue_size)
      yield queue[i]
      queue[i] = sample
  except StopIteration:
    # Only get here if the initial queue fill fails.
    logging.warning(
        'Not enough samples (%d) to fill initial queue (size %d).',
        len(queue), queue_size)

  # No new samples coming in; shuffle and drain the queue.
  np.random.shuffle(queue)
  for sample in queue:
    yield sample


def batch_data(generator, batch_size):
  """Batch and pad generator as in tf.data.Dataset.padded_batch.

  Examples are grouped in order into batches of exactly `batch_size`; a
  trailing group with fewer than `batch_size` examples is not yielded.

  Args:
    generator: python generator of examples; each example is a tuple of
      arrays, and corresponding arrays must be stackable with `np.stack`.
    batch_size: positive int, number of examples per batch.

  Yields:
    Tuples of arrays, one array per example component, each with a new
    leading batch dimension of size `batch_size`.
  """
  buf = []
  # TODO(lukaszkaiser): convert to ValueError
  assert batch_size > 0, f'Batch size must be positive, but is {batch_size}'
  for example in generator:
    buf.append(example)  # Examples are tuples of tensors.
    if len(buf) == batch_size:
      # buf is a list of tuples, e.g., [(in1, tgt1), (in2, tgt2), (in3, tgt3)]
      # batch is a tuple of arrays: ([in1, in2, in3], [tgt1, tgt2, tgt3])
      batch = tuple(np.stack(x) for x in zip(*buf))
      # Note that it's the same shape as each example with added batch dim.
      yield batch
      buf = []


def pad_to_max_dims(tensors, boundary=None, strict_pad_on_len=False):
  """Pad a tuple of tensors to a joint dimension and return their batch.

  For example, a pair of tensors of shape (2, 10) and (3, 9) will be padded
  to (3, 10) both and the returned tensor will have shape (2, 3, 10).

  When boundary is specified, we try to pad all unknown dimensions to boundary
  if possible, which can help reduce the number of different shapes occurring
  in the tensors and speed up XLA compilation. So, for example, a pair of
  tensors of shapes (8, 10), (8, 9) with boundary=12 will be padded to (8, 12).

  One special case occurs when boundary is much higher than the padding length
  that we'd use without boundary. For example, tensors (2, 10) and (3, 9) with
  boundary=12 could end up padded to (12, 12), but this is very wasteful in
  the first dimension. In that case, we will use the closest power-of-2 instead
  of the boundary, so we will end up padding to (4, 12) instead of (12, 12).

  Args:
    tensors: a tuple or list of tensors to pad
    boundary: int or None; if given, expand the padded dimensions to this size
    strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
      strictly as a multiple of boundary.

  Returns:
    a tensor, the tensors padded together

  Raises:
    ValueError: if `boundary` is a list or tuple whose length differs from the
      number of dimensions of the tensors.
  """
  # TODO(afrozm): Unify this later.
  if ((boundary is not None) and
      (strict_pad_on_len or isinstance(boundary, (list, tuple)))):
    ndim = tensors[0].ndim
    if not isinstance(boundary, (list, tuple)):
      boundary = [boundary] * ndim

    if ndim != len(boundary):
      raise ValueError(f'ndim != len(boundary) - '
                       f'ndim({ndim}) vs boundary({boundary}) '
                       f'len(boundary) = {len(boundary)}.')

    max_len_per_dim = [0] * ndim
    for tensor in tensors:
      max_len_per_dim = [
          max(e, s) for e, s in zip(tensor.shape, max_len_per_dim)]

    # Round everything up to a multiple of boundary in the respective dimension.
    # A falsy boundary entry (0 or None) leaves that dimension at its max.
    len_per_dim = [
        max_len_per_dim[i] if not b else b * math.ceil(max_len_per_dim[i] / b)
        for i, b in enumerate(boundary)]

    padded_tensors = [
        np.pad(t, [(0, len_per_dim[i] - t.shape[i]) for i in range(ndim)],
               mode='constant', constant_values=t.dtype.type(0))
        for t in tensors]
    return np.stack(padded_tensors)

  max_len_to_pad = []
  padding_needed = False
  dim = len(tensors[0].shape)
  for i in range(dim):
    max_len = max(t.shape[i] for t in tensors)
    min_len = min(t.shape[i] for t in tensors)
    if max_len == min_len:  # No padding needed.
      max_len_to_pad.append(max_len)
    elif boundary is None:
      max_len_to_pad.append(max_len)
      padding_needed = True
    else:
      padding_needed = True
      cur_boundary = max(max_len, boundary)
      # Boundary is more than twice what is needed: fall back to the next
      # power of 2 to avoid wasteful padding.
      if 2 * max_len < cur_boundary:
        cur_boundary = 2**int(np.ceil(np.log2(max_len)))
      max_len_to_pad.append(cur_boundary)
  if not padding_needed:
    return np.stack(tensors)
  padded_tensors = []
  for t in tensors:
    pad_widths = [(0, max_len_to_pad[i] - t.shape[i]) for i in range(dim)]
    padded_t = np.pad(t, pad_widths, mode='constant',
                      constant_values=t.dtype.type(0))
    padded_tensors.append(padded_t)
  return np.stack(padded_tensors)


def bucket_by_length(generator, length_fn, boundaries, batch_sizes,
                     strict_pad_on_len=False):
  """Bucket by length, like tf.data.experimental.bucket_by_sequence_length.

  Examples left in a bucket that never fills up are not yielded.

  Args:
    generator: python generator to draw data from.
    length_fn: a function taking the example and returning the length.
    boundaries: a list of bucket boundaries.
    batch_sizes: a list of batch sizes, one per bucket; it needs one more
      element than `boundaries` (for examples longer than every boundary).
    strict_pad_on_len: bool; if true we pad on the length dimension, dim[0]
      strictly as a multiple of boundary.

  Yields:
    An input batch, which comes from one of the buckets.
  """
  buckets = [[] for _ in range(len(batch_sizes))]
  # Max boundary is unlimited. `list(...)` makes a copy and accepts tuples.
  boundaries = list(boundaries) + [math.inf]
  for example in generator:
    length = length_fn(example)
    # The first boundary strictly above `length` picks the bucket; a match
    # always exists since boundaries is right padded by `math.inf`.
    bucket_idx = next(i for i, b in enumerate(boundaries) if length < b)
    buckets[bucket_idx].append(example)
    if len(buckets[bucket_idx]) == batch_sizes[bucket_idx]:
      batch = zip(*buckets[bucket_idx])
      # Boundaries are exclusive, so the largest length in this bucket (and
      # hence the padding target) is boundary - 1.
      boundary = boundaries[bucket_idx] - 1
      boundary = None if boundary == math.inf else boundary
      padded_batch = tuple(
          pad_to_max_dims(x, boundary, strict_pad_on_len) for x in batch)
      yield padded_batch
      buckets[bucket_idx] = []


def add_loss_weights(generator, id_to_mask=None):
  """Add weights to inputs without weights and masks by id if requested.

  The generator stream is augmented in the following way:

  - If the stream consists of pairs `(inputs, targets)`, a loss mask is added
    that is created as a tensor of ones of the same shape as targets.
  - If `id_to_mask` is not `None`, and the stream (after the previous point)
    has triples `(inputs, targets, weights)`, the weights are multiplied by a
    0/1 mask that is 0 iff targets is equal to `id_to_mask` (1 otherwise).

  Examples with fewer than 2 or more than 3 elements are passed through
  unchanged, which is only allowed when `id_to_mask` is `None`.

  Args:
    generator: Stream of tuples.
    id_to_mask: If not None, int-valued id that represents padding, as opposed
        to true target id's.

  Yields:
    Examples from the augmented stream.
  """
  for example in generator:
    if len(example) > 3 or len(example) < 2:
      assert id_to_mask is None, 'Cannot automatically mask this stream.'
      yield example
      continue
    if len(example) == 2:
      weights = np.ones_like(example[1]).astype(np.float32)
    else:
      # `astype` copies, so the caller's weights are never modified in place.
      weights = example[2].astype(np.float32)
    if id_to_mask is not None:
      mask = 1.0 - np.equal(example[1], id_to_mask).astype(np.float32)
      weights *= mask
    yield (example[0], example[1], weights)


# Example input functions.


@gin.configurable()
def random_inputs(
    input_shape=gin.REQUIRED, input_dtype=jnp.int32, input_range=(0, 255),
    output_shape=gin.REQUIRED, output_dtype=jnp.int32, output_range=(0, 9)):
  """Make random Inputs for debugging.

  Args:
    input_shape: the shape of inputs (including batch dimension).
    input_dtype: the type of the inputs (int32 by default).
    input_range: the range of inputs (defaults to (0, 255)).
    output_shape: the shape of outputs (including batch dimension).
    output_dtype: the type of the outputs (int32 by default).
    output_range: the range of outputs (defaults to (0, 9)).

  Returns:
    trax.inputs.Inputs
  """
  def random_integers(low, high, shape):
    """Samples integers uniformly from the inclusive range [low, high]."""
    # Replacement for the deprecated `np.random.random_integers`: `randint`
    # excludes its upper bound, hence the `+ 1`.
    return np.random.randint(low, high + 1, shape)

  def random_minibatches(n_devices):
    """Generate a stream of random mini-batches."""
    # The batch dimension must be splittable evenly across devices.
    assert input_shape[0] % n_devices == 0
    # Note: the sampler is chosen by `input_dtype` and is used for both the
    # inputs and the outputs.
    if input_dtype in [jnp.float16, jnp.float32, jnp.float64]:
      rand = np.random.uniform
    else:
      rand = random_integers
    while True:
      inp = rand(input_range[0], input_range[1], input_shape)
      inp = inp.astype(input_dtype)
      out = rand(output_range[0], output_range[1], output_shape)
      out = out.astype(output_dtype)
      yield inp, out

  return Inputs(random_minibatches)


def _pad_to_multiple_of(x, y, axis):
  """Pads x to multiple of y on the given axis.

  Args:
    x: array to pad.
    y: number whose multiple the length of `axis` is rounded up to.
    axis: int, the axis of `x` to pad.

  Returns:
    `x` padded with zeros (of its own dtype) at the end of `axis`.
  """
  pad_len = np.ceil(x.shape[axis] / float(y)) * y
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = (0, int(pad_len - x.shape[axis]))
  return np.pad(x, pad_widths, mode='constant',
                constant_values=x.dtype.type(0))


@gin.configurable()
def sequence_copy_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED,
    eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED, reverse=False,
    pad_to_multiple=32):
  """Inputs for the sequence copy problem: 0w0w for w in [1..vocab_size-1]*.

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: maximum length of w for training.
    eval_min_length: minimum length of w for eval.
    eval_max_length : maximum length of w for eval.
    reverse: bool (optional, false by default): reverse the second sequence.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  def random_minibatches(length_list):
    """Generate a stream of random mini-batches."""
    while True:
      # `length` is the total length of the unpadded sequence 0w0w.
      length = random.choice(length_list)
      assert length % 2 == 0
      w_length = (length // 2) - 1
      w = np.random.randint(low=1, high=vocab_size-1,
                            size=(batch_size, w_length))
      zero = np.zeros([batch_size, 1], np.int32)
      # Loss is computed only on the second copy of w.
      loss_weights = np.concatenate([np.zeros((batch_size, w_length+2)),
                                     np.ones((batch_size, w_length))], axis=1)
      if reverse:
        x = np.concatenate([zero, w, zero, np.flip(w, axis=1)], axis=1)
      else:
        x = np.concatenate([zero, w, zero, w], axis=1)
      x = _pad_to_multiple_of(x, pad_to_multiple, 1)
      loss_weights = _pad_to_multiple_of(loss_weights, pad_to_multiple, 1)
      yield (x, x, loss_weights)  # Here inputs and targets are the same.

  train_lengths = [2*(i+2) for i in range(train_length - 1)]
  eval_lengths = [2*(i+1) for i in range(eval_min_length, eval_max_length)]
  return Inputs(
      train_stream=lambda _: random_minibatches(train_lengths),
      eval_stream=lambda _: random_minibatches(eval_lengths)
  )


def lower_endian_to_number(l, base):
  """Helper function: convert a list of digits in the given base to a number."""
  return sum(d * (base**i) for i, d in enumerate(l))


def number_to_lower_endian(n, base):
  """Helper function: convert a number to a list of digits in the given base."""
  if n < base:
    return [n]
  return [n % base] + number_to_lower_endian(n // base, base)


def random_number_lower_endian(length, base):
  """Helper function: generate a random number as a lower-endian digits list."""
  if length == 1:  # Last digit can be 0 only if length is 1.
    return [np.random.randint(base)]
  prefix = [np.random.randint(base) for _ in range(length - 1)]
  return prefix + [np.random.randint(base - 1) + 1]  # Last digit is not 0.


@gin.configurable()
def addition_inputs(
    vocab_size=gin.REQUIRED, batch_size=gin.REQUIRED, train_length=gin.REQUIRED,
    eval_min_length=gin.REQUIRED, eval_max_length=gin.REQUIRED,
    pad_to_multiple=32):
  """Inputs for the add problem: <S>x+y<S>(x+y).

  Args:
    vocab_size: how many symbols to use.
    batch_size: how large are the batches.
    train_length: maximal length of w for training.
    eval_min_length: minimal length of w for eval.
    eval_max_length: maximal length of w for eval.
    pad_to_multiple: int, pad length to be multiple of this number.

  Returns:
    trax.inputs.Inputs
  """
  base = vocab_size - 3  # We use 0 to pad, base+1 as "+" and base+2 as "<S>".

  def single_example(max_length, min_length):
    """Generate one random example as a pair (token list, loss weight list)."""
    add_len = (min_length - 1) // 2
    l1 = np.random.randint((max_length - add_len + 1) // 2) + add_len
    l2 = np.random.randint(max_length - l1 - 1) + 1
    n1 = random_number_lower_endian(l1, base)
    n2 = random_number_lower_endian(l2, base)
    result = lower_endian_to_number(n1, base) + lower_endian_to_number(
        n2, base)
    # `base` stands for "+" here; all symbols are shifted by 1 below so that
    # 0 stays reserved for padding.
    inp = n1 + [base] + n2
    tgt = number_to_lower_endian(result, base)
    x = [base+2] + [i+1 for i in inp] + [base+2] + [i+1 for i in tgt]
    # Loss is computed only on the digits of the result.
    weights = ([0] * (len(inp) + 2)) + ([1] * len(tgt))
    return (x, weights)

  def batches(max_length, min_length):
    """Batches of examples."""
    if max_length < 3:
      raise ValueError('Maximum length must be at least 3.')
    while True:
      res = [single_example(max_length, min_length) for _ in range(batch_size)]
      # Zero-pad all examples in the batch to the longest one.
      l = max(len(x[0]) for x in res)
      xs = np.array([x[0] + [0] * (l - len(x[0])) for x in res])
      ws = np.array([x[1] + [0] * (l - len(x[1])) for x in res],
                    dtype=np.float32)
      xs = _pad_to_multiple_of(xs, pad_to_multiple, 1)
      ws = _pad_to_multiple_of(ws, pad_to_multiple, 1)
      yield (xs, xs, ws)

  return Inputs(
      train_stream=lambda _: batches(train_length, 3),
      eval_stream=lambda _: batches(eval_max_length, eval_min_length)
  )