# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch

from . import data_utils, FairseqDataset


def collate(samples, pad_idx, eos_idx):
    """Merge a list of samples into a single mini-batch dictionary.

    Args:
        samples: list of dicts, each with the keys 'id', 'source' and
            'target' (the format produced by MonolingualDataset.__getitem__).
        pad_idx: padding index forwarded to data_utils.collate_tokens.
        eos_idx: end-of-sentence index forwarded to
            data_utils.collate_tokens.

    Returns:
        An empty dict when ``samples`` is empty; otherwise a dict with:
            'id': LongTensor holding the sample ids,
            'ntokens': total length of all targets in the batch,
            'net_input': {'src_tokens': collated 'source' entries},
            'target': collated 'target' entries.
    """
    if len(samples) == 0:
        return {}

    def merge(key):
        # Collate one field of every sample; left_pad is always False here.
        return data_utils.collate_tokens(
            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
        )

    return {
        'id': torch.LongTensor([s['id'] for s in samples]),
        'ntokens': sum(len(s['target']) for s in samples),
        'net_input': {
            'src_tokens': merge('source'),
        },
        'target': merge('target'),
    }


class MonolingualDataset(FairseqDataset):
    """A wrapper around torch.utils.data.Dataset for monolingual data."""

    def __init__(self, dataset, sizes, vocab, shuffle):
        """Store the wrapped dataset and batching metadata.

        Args:
            dataset: indexable object; ``dataset[i]`` must unpack into a
                ``(source, target)`` pair.
            sizes: per-example sizes; converted to a numpy array.
            vocab: object providing ``pad()``, ``eos()`` and
                ``dummy_sentence(length)``.
            shuffle: if true, ``ordered_indices`` breaks ties between
                equal-sized examples randomly instead of by index.
        """
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.shuffle = shuffle

    def __getitem__(self, index):
        """Return example ``index`` as an {'id', 'source', 'target'} dict."""
        source, target = self.dataset[index]
        return {'id': index, 'source': source, 'target': target}

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        return collate(samples, self.vocab.pad(), self.vocab.eos())

    def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
        """Build a batch of identical dummy sentences.

        Args:
            num_tokens: token budget; the batch size is
                ``num_tokens // tgt_len`` (but never less than 1).
            max_positions: int or float upper bound applied to ``tgt_len``.
            tgt_len: requested length of each dummy source/target.

        Returns:
            The result of ``self.collater`` on the dummy samples.
        """
        assert isinstance(max_positions, (int, float))
        # max_positions may be a float; min() would then yield a float, which
        # breaks range() and the sentence length below, so force an int.
        tgt_len = int(min(tgt_len, max_positions))
        # Always emit at least one sample: with bsz == 0 the collater would
        # return an empty dict rather than a usable batch.
        bsz = max(int(num_tokens // tgt_len), 1)
        # One extra token so that source and target are the same sentence
        # shifted by one position.
        target = self.vocab.dummy_sentence(tgt_len + 1)
        source, target = target[:-1], target[1:]
        return self.collater([
            {'id': i, 'source': source, 'target': target}
            for i in range(bsz)
        ])

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        source, _ = self.dataset[index]
        return len(source)

    def ordered_indices(self):
        """Ordered indices for batching.

        Indices are sorted by ``self.sizes``. Ties are broken by a random
        permutation when ``self.shuffle`` is set, otherwise by index.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # np.lexsort treats the LAST key as the primary sort key.
        order.append(self.sizes)
        return np.lexsort(order)

    def valid_size(self, index, max_positions):
        """Check if an example's size is valid according to max_positions."""
        assert isinstance(max_positions, (int, float))
        return self.sizes[index] <= max_positions