# encoding=utf-8
"""Batching utilities and a chunked Dataset for translation sample pairs.

At import time the samples in ``data/samples_train.json`` are loaded into the
module-level ``samples`` list. Each sample is read as a mapping with ``'input'``
and ``'output'`` entries; these are padded with ``PAD_token`` and converted to
``torch.LongTensor``, so they are presumably lists of integer token indices
(confirm against the data-preparation script).
"""
import itertools
import json

import numpy as np
import torch
from torch.utils.data import Dataset

from config import *

samples_path = 'data/samples_train.json'
# Use a context manager so the handle is closed, and decode explicitly as UTF-8
# rather than relying on the platform default encoding.
with open(samples_path, 'r', encoding='utf-8') as _samples_file:
    samples = json.load(_samples_file)


def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a batch of sequences, padding shorter ones with ``fillvalue``.

    Args:
        l: list of sequences (the batch).
        fillvalue: value used to pad sequences shorter than the longest one.

    Returns:
        A list of tuples laid out as (max_len, batch): tuple ``t`` holds the
        ``t``-th element of every sequence in the batch.
    """
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
    """Return a 0/1 mask with the same nesting as ``l``.

    Each entry is 0 where the token equals ``PAD_token`` and 1 otherwise.
    """
    return [[0 if token == PAD_token else 1 for token in seq] for seq in l]


def inputVar(indexes_batch):
    """Return the padded input tensor and the original sequence lengths.

    Returns:
        padVar: LongTensor laid out as (max_len, batch) (see ``zeroPadding``).
        lengths: 1-D tensor with the unpadded length of each input sequence.
    """
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padVar = torch.LongTensor(zeroPadding(indexes_batch))
    return padVar, lengths


def outputVar(indexes_batch):
    """Return the padded target tensor, its padding mask and the max target length.

    Returns:
        padVar: LongTensor laid out as (max_len, batch).
        mask: BoolTensor of the same layout; True where the token is not padding.
        max_target_len: length of the longest target sequence.

    Raises:
        ValueError: if ``indexes_batch`` is empty (from ``max``).
    """
    max_target_len = max(len(indexes) for indexes in indexes_batch)
    padList = zeroPadding(indexes_batch)
    # Bool mask: uint8 (ByteTensor) masks are deprecated for masked_select /
    # boolean indexing in PyTorch.
    mask = torch.BoolTensor(binaryMatrix(padList))
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len


def batch2TrainData(pair_batch):
    """Convert a list of (input, output) pairs into padded training tensors.

    The pairs are ordered by input length, longest first (presumably so the
    encoder can use ``pack_padded_sequence`` with sorted lengths; confirm in
    the model code). The caller's list is not modified.

    Returns:
        (inp, lengths, output, mask, max_target_len), as produced by
        ``inputVar`` and ``outputVar``.
    """
    ordered = sorted(pair_batch, key=lambda pair: len(pair[0]), reverse=True)
    input_batch = [pair[0] for pair in ordered]
    output_batch = [pair[1] for pair in ordered]
    inp, lengths = inputVar(input_batch)
    output, mask, max_target_len = outputVar(output_batch)
    return inp, lengths, output, mask, max_target_len


class TranslationDataset(Dataset):
    """Map-style dataset whose items are whole batches of ``chunk_size`` pairs.

    The module-level ``samples`` are split at ``int(len(samples) * train_split)``:
    the leading part is 'train', the remainder is 'valid'. The split's samples
    are shuffled once at construction, so chunk composition is fixed for the
    dataset's lifetime. Trailing samples that do not fill a whole chunk are
    never served.
    """

    def __init__(self, split):
        """Select and shuffle the samples for ``split`` ('train' or 'valid').

        Raises:
            ValueError: if ``split`` is not 'train' or 'valid'.
        """
        if split not in {'train', 'valid'}:
            raise ValueError("split must be 'train' or 'valid', got {!r}".format(split))
        self.split = split
        print('loading {} samples'.format(split))

        train_count = int(len(samples) * train_split)
        if split == 'train':
            self.samples = samples[:train_count]
        else:
            self.samples = samples[train_count:]

        self.num_chunks = len(self.samples) // chunk_size
        # Slicing above produced a copy, so this does not reorder the module-level list.
        np.random.shuffle(self.samples)
        print('count: ' + str(len(self.samples)))

    def __getitem__(self, i):
        """Return the ``i``-th chunk as the tuple produced by ``batch2TrainData``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if i < 0:
            i += self.num_chunks
        if not 0 <= i < self.num_chunks:
            raise IndexError('chunk index out of range')
        start_idx = i * chunk_size
        chunk = self.samples[start_idx:start_idx + chunk_size]
        pair_batch = [(sample['input'], sample['output']) for sample in chunk]
        return batch2TrainData(pair_batch)

    def __len__(self):
        """Return the number of complete chunks available."""
        return self.num_chunks


if __name__ == '__main__':
    # Smoke test: build one batch from the first five loaded samples and print it.
    print('loading {} samples'.format('train'))
    pair_batch = [(sample['input'], sample['output']) for sample in samples[:5]]

    input_variable, lengths, target_variable, mask, max_target_len = batch2TrainData(pair_batch)
    print("input_variable:", input_variable)
    print("lengths:", lengths)
    print("target_variable:", target_variable)
    print("mask:", mask)
    print("max_target_len:", max_target_len)