import argparse
import copy
import unittest
import glob
import os
from collections import Counter

import torchtext

import onmt
import onmt.io
import opts
import preprocess

parser = argparse.ArgumentParser(description='preprocess.py')
opts.preprocess_opts(parser)

SAVE_DATA_PREFIX = 'data/test_preprocess'

default_opts = [
    '-data_type', 'text',
    '-train_src', 'data/src-train.txt',
    '-train_tgt', 'data/tgt-train.txt',
    '-valid_src', 'data/src-val.txt',
    '-valid_tgt', 'data/tgt-val.txt',
    '-save_data', SAVE_DATA_PREFIX
]

opt = parser.parse_known_args(default_opts)[0]


class TestData(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(TestData, self).__init__(*args, **kwargs)
        self.opt = opt

    def dataset_build(self, opt):
        fields = onmt.io.get_fields("text", 0, 0)

        train_data_files = preprocess.build_save_dataset('train', fields, opt)

        preprocess.build_save_vocab(train_data_files, fields, opt)

        preprocess.build_save_dataset('valid', fields, opt)

        # Remove the generated *.pt files.
        for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
            os.remove(pt)

    def test_merge_vocab(self):
        va = torchtext.vocab.Vocab(Counter('abbccc'))
        vb = torchtext.vocab.Vocab(Counter('eeabbcccf'))

        merged = onmt.io.merge_vocabs([va, vb], 2)

        self.assertEqual(Counter({'c': 6, 'b': 4, 'a': 2, 'e': 2, 'f': 1}),
                         merged.freqs)
        # 3 specials + 2 words (since we pass 2 to merge_vocabs).
        self.assertEqual(5, len(merged.itos))
        self.assertTrue('b' in merged.itos)


def _add_test(param_setting, methodname):
    """
    Adds a test to TestData according to settings.

    Args:
        param_setting: list of tuples of (param, setting)
        methodname: name of the method that gets called
    """

    def test_method(self):
        if param_setting:
            opt = copy.deepcopy(self.opt)
            for param, setting in param_setting:
                setattr(opt, param, setting)
        else:
            opt = self.opt
        getattr(self, methodname)(opt)

    if param_setting:
        name = 'test_' + methodname + "_" + "_".join(
            str(param_setting).split())
    else:
        name = 'test_' + methodname + '_standard'
    setattr(TestData, name, test_method)
    test_method.__name__ = name


test_databuild = [
    [],
    [('src_vocab_size', 1),
     ('tgt_vocab_size', 1)],
    [('src_vocab_size', 10000),
     ('tgt_vocab_size', 10000)],
    [('src_seq_length', 1)],
    [('src_seq_length', 5000)],
    [('src_seq_length_trunc', 1)],
    [('src_seq_length_trunc', 5000)],
    [('tgt_seq_length', 1)],
    [('tgt_seq_length', 5000)],
    [('tgt_seq_length_trunc', 1)],
    [('tgt_seq_length_trunc', 5000)],
    [('shuffle', 0)],
    [('lower', True)],
    [('dynamic_dict', True)],
    [('share_vocab', True)],
    [('dynamic_dict', True),
     ('share_vocab', True)],
]

# Test text preprocessing.
for p in test_databuild:
    _add_test(p, 'dataset_build')

# Test image preprocessing. Iterate over deep copies so the image settings
# are not appended to the lists captured by the tests registered above
# (test_method reads param_setting at call time).
for p in copy.deepcopy(test_databuild):
    p.append(('data_type', 'img'))
    p.append(('src_dir', '/tmp/im2text/images'))
    p.append(('train_src', '/tmp/im2text/src-train-head.txt'))
    p.append(('train_tgt', '/tmp/im2text/tgt-train-head.txt'))
    p.append(('valid_src', '/tmp/im2text/src-val-head.txt'))
    p.append(('valid_tgt', '/tmp/im2text/tgt-val-head.txt'))
    _add_test(p, 'dataset_build')

# Test audio preprocessing, again on deep copies for the same reason.
for p in copy.deepcopy(test_databuild):
    p.append(('data_type', 'audio'))
    p.append(('src_dir', '/tmp/speech/an4_dataset'))
    p.append(('train_src', '/tmp/speech/src-train-head.txt'))
    p.append(('train_tgt', '/tmp/speech/tgt-train-head.txt'))
    p.append(('valid_src', '/tmp/speech/src-val-head.txt'))
    p.append(('valid_tgt', '/tmp/speech/tgt-val-head.txt'))
    p.append(('sample_rate', 16000))
    p.append(('window_size', 0.04))
    p.append(('window_stride', 0.02))
    p.append(('window', 'hamming'))
    _add_test(p, 'dataset_build')
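
# Optional entry point (an addition, not required by pytest-style runners):
# lets the dynamically registered tests be run directly with
# `python test_preprocess.py` as well as through a test runner.
if __name__ == '__main__':
    unittest.main()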