import glob
import io
import os
import xml.etree.ElementTree as ET

from .. import data


class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation.

    Each example pairs one line of a source-language file with the line at
    the same position in a target-language file.
    """

    @staticmethod
    def sort_key(ex):
        # Key on both the source and the target length, combined by
        # data.interleave_keys.
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, path, exts, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.

        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language, given either as bare fields or as
                (name, field) pairs.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.

        Both files are read as UTF-8. Lines are paired by position; a pair is
        skipped when either side is empty after stripping whitespace, and any
        extra lines in the longer file are ignored (``zip`` semantics).
        """
        if not isinstance(fields[0], (tuple, list)):
            # Bare fields: name them 'src' and 'trg', the attribute names
            # sort_key reads from each example.
            fields = [('src', fields[0]), ('trg', fields[1])]

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

        examples = []
        # Decode explicitly as UTF-8 rather than with the platform default
        # encoding (which varies, e.g. cp1252 on Windows); IWSLT.clean also
        # writes its plain-text output as UTF-8.
        with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
                io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                if src_line != '' and trg_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line], fields))

        super(TranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the constructor of each
                split dataset.

        Returns:
            A tuple of the train, validation and test datasets, in that
            order, omitting every split whose prefix is None.
        """
        # cls.download is inherited from data.Dataset; the returned
        # directory is used as the parent of every split prefix.
        path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)


class Multi30k(TranslationDataset):
    """The small-dataset WMT 2016 multimodal task, also known as Flickr30k"""

    urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
            'http://www.quest.dcs.shef.ac.uk/'
            'wmt17_files_mmt/mmt_task1_test2016.tar.gz']
    name = 'multi30k'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='val', test='test2016', **kwargs):
        """Create dataset objects for splits of the Multi30k dataset.

        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test2016'.
            Remaining keyword arguments: Passed to the constructor of each
                split dataset.
        """
        return super(Multi30k, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)


class IWSLT(TranslationDataset):
    """The IWSLT 2016 TED talk translation task"""

    base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz'
    name = 'iwslt'
    base_dirname = '{}-{}'

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train', validation='IWSLT16.TED.tst2013',
               test='IWSLT16.TED.tst2014', **kwargs):
        """Create dataset objects for splits of the IWSLT dataset.

        Arguments:
            exts: A tuple containing the extension to path for each language,
                e.g. ('.de', '.en').
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data.
                Default: 'IWSLT16.TED.tst2013'.
            test: The prefix of the test data.
                Default: 'IWSLT16.TED.tst2014'.
            Remaining keyword arguments: Passed to the constructor of each
                split dataset.

        Each non-None prefix is suffixed with '.<src>-<trg>' (for example
        'train.de-en') before being resolved under the download directory.
        """
        # Language pair without the leading '.', e.g. ('.de', '.en') ->
        # 'de-en'. NOTE: dirname and urls are stored on the class itself, so
        # they reflect the most recent call to splits().
        cls.dirname = cls.base_dirname.format(exts[0][1:], exts[1][1:])
        cls.urls = [cls.base_url.format(exts[0][1:], exts[1][1:], cls.dirname)]
        check = os.path.join(root, cls.name, cls.dirname)
        path = cls.download(root, check=check)

        if train is not None:
            train = '.'.join([train, cls.dirname])
        if validation is not None:
            validation = '.'.join([validation, cls.dirname])
        if test is not None:
            test = '.'.join([test, cls.dirname])

        # clean() writes plain-text versions of the *.xml and train.tags*
        # files; the cleaned source-side train file marks that it has run.
        cleaned_train = os.path.join(
            path, '.'.join(['train', cls.dirname])) + exts[0]
        if not os.path.exists(cleaned_train):
            cls.clean(path)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)

    @staticmethod
    def clean(path):
        """Write plain-text, one-sentence-per-line copies of the raw files.

        For every ``*.xml`` file in ``path``, the text of each ``seg`` element
        (under the ``doc`` elements of the root's first child) is written,
        stripped, to the same path without the ``.xml`` extension. For every
        ``train.tags*`` file, all lines not containing one of the metadata
        tags below are written, stripped, to the same name with ``.tags``
        removed. Output is UTF-8. Each processed file name is printed.
        """
        for f_xml in glob.iglob(os.path.join(path, '*.xml')):
            print(f_xml)
            f_txt = os.path.splitext(f_xml)[0]
            # Parse before opening the output, so a malformed XML file does
            # not leave an empty or partial text file behind.
            # NOTE(review): xml.etree is not hardened against malicious XML
            # (e.g. entity expansion); the input here is a downloaded archive.
            container = ET.parse(f_xml).getroot()[0]
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
                for doc in container.findall('doc'):
                    for seg in doc.findall('seg'):
                        # Element.text is None for an empty <seg/>; write an
                        # empty line to keep lines aligned across languages.
                        fd_txt.write((seg.text or '').strip() + '\n')

        xml_tags = ('<url', '<keywords', '<talkid', '<description',
                    '<reviewer', '<translator', '<title', '<speaker')
        for f_orig in glob.iglob(os.path.join(path, 'train.tags*')):
            print(f_orig)
            f_txt = f_orig.replace('.tags', '')
            with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt, \
                    io.open(f_orig, mode='r', encoding='utf-8') as fd_orig:
                for line in fd_orig:
                    if not any(tag in line for tag in xml_tags):
                        fd_txt.write(line.strip() + '\n')


class WMT14(TranslationDataset):
    """The WMT 2014 English-German dataset, as preprocessed by Google Brain.

    Though this download contains test sets from 2015 and 2016, the train set
    differs slightly from WMT 2015 and 2016 and significantly from WMT 2017."""

    urls = [('https://drive.google.com/uc?export=download&'
             'id=0B_bZck-ksdkpM25jRUN2X2UxMm8', 'wmt16_en_de.tar.gz')]
    name = 'wmt14'
    dirname = ''

    @classmethod
    def splits(cls, exts, fields, root='.data',
               train='train.tok.clean.bpe.32000',
               validation='newstest2013.tok.bpe.32000',
               test='newstest2014.tok.bpe.32000', **kwargs):
        """Create dataset objects for splits of the WMT 2014 dataset.

        Arguments:
            exts: A tuple containing the extensions for each language. Must be
                either ('.en', '.de') or the reverse.
            fields: A tuple containing the fields that will be used for data
                in each language.
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default:
                'train.tok.clean.bpe.32000'.
            validation: The prefix of the validation data. Default:
                'newstest2013.tok.bpe.32000'.
            test: The prefix of the test data. Default:
                'newstest2014.tok.bpe.32000'.
            Remaining keyword arguments: Passed to the constructor of each
                split dataset.
        """
        return super(WMT14, cls).splits(
            exts, fields, root, train, validation, test, **kwargs)