# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Tests for trax.supervised.tf_inputs."""

import collections
import os

import gin
import numpy as np
from t5.data import preprocessors as t5_processors
from t5.data import test_utils as t5_test_utils
import tensorflow as tf
import tensorflow_datasets as tfds
from trax.supervised import inputs  # pylint: disable=unused-import
from trax.supervised import tf_inputs

pkg_dir, _ = os.path.split(__file__)
_TESTDATA = os.path.join(pkg_dir, 'testdata')


def _test_dataset_ints(inp_lengths, tgt_lengths):
  """Creates a test dataset of int64 tensors of the given shapes."""

  def generator():
    for inp_len, tgt_len in zip(inp_lengths, tgt_lengths):
      inp = np.ones([inp_len], dtype=np.int64)
      tgt = np.ones([tgt_len], dtype=np.int64)
      yield {'inputs': inp, 'targets': tgt}

  types = {'inputs': tf.int64, 'targets': tf.int64}
  shapes = {'inputs': tf.TensorShape([None]),
            'targets': tf.TensorShape([None])}
  return tf.data.Dataset.from_generator(
      generator, output_types=types, output_shapes=shapes)


def _load_dataset(name, split='train'):
  return tfds.load(
      name=name, split=split, data_dir=_TESTDATA, shuffle_files=False)


def _c4_dataset(split='train'):
  return _load_dataset('c4', split=split)


def _spm_path():
  return os.path.join(_TESTDATA, 'sentencepiece.model')


def _t5_gin_config():
  # The following pages' worth of gin configuration is required because a lot
  # of T5 functions have `gin.REQUIRED` in code, i.e. you cannot use these
  # functions at all without having configured gin.

  noise_density = 0.15
  max_input_length = 50

  # Which preprocessors to apply: we select a random chunk of the document if
  # it exceeds a certain length (`select_random_chunk`), then concatenate
  # multiple documents together to reduce padding (`reduce_concat_tokens`),
  # then split up long examples (`split_tokens`), and finally apply the
  # denoising objective (`denoise`).
  gin.bind_parameter('unsupervised.preprocessors', [
      t5_processors.select_random_chunk,
      t5_processors.reduce_concat_tokens,
      t5_processors.split_tokens,
      t5_processors.denoise,
  ])

  # select_random_chunk
  gin.bind_parameter('select_random_chunk.feature_key', 'targets')
  gin.bind_parameter('select_random_chunk.max_length', max_input_length)

  # random_spans_helper
  gin.bind_parameter('random_spans_helper.extra_tokens_per_span_inputs', 1)
  gin.bind_parameter('random_spans_helper.extra_tokens_per_span_targets', 1)
  gin.bind_parameter('random_spans_helper.inputs_length', max_input_length)
  gin.bind_parameter('random_spans_helper.mean_noise_span_length', 3.0)
  gin.bind_parameter('random_spans_helper.noise_density', noise_density)

  # split_tokens
  gin.bind_parameter('split_tokens.max_tokens_per_segment',
                     t5_processors.random_spans_tokens_length())

  # denoise
  gin.bind_parameter('denoise.inputs_fn',
                     t5_processors.noise_span_to_unique_sentinel)
  gin.bind_parameter('denoise.noise_density', noise_density)
  gin.bind_parameter('denoise.noise_mask_fn',
                     t5_processors.random_spans_noise_mask)
  gin.bind_parameter('denoise.targets_fn',
                     t5_processors.nonnoise_span_to_unique_sentinel)


class TFInputsTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    gin.clear_config()

  def test_c4_bare_preprocess_fn(self):
    dataset = _c4_dataset()

    example = list(tfds.as_numpy(dataset.take(1)))[0]

    # Targets are NOT in the example.
    self.assertNotIn('targets', example)
    self.assertIn('text', example)
    text = example['text']

    # This should convert the dataset into tokenized inputs/targets.
    dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path())

    example = list(tfds.as_numpy(dataset.take(1)))[0]

    # The earlier text is now stored in 'targets_plaintext'.
    self.assertIn('targets_plaintext', example)
    self.assertEqual(example['targets_plaintext'], text)

    # Targets are now tokenized.
    self.assertIn('targets', example)
    self.assertIsInstance(example['targets'], np.ndarray)
    self.assertEqual(example['targets'].dtype, np.int64)
    self.assertGreater(len(example['targets']), 0)
    self.assertEqual(example['targets'][-1], 1)  # We add EOS at the end.

    # Inputs exist but are empty because the T5 `unsupervised` preprocessor
    # wasn't gin-configured with any preprocessors.
    self.assertIn('inputs', example)
    self.assertEqual(len(example['inputs']), 0)

  def test_c4_preprocess(self):

    def load_c4_dataset(split='train'):
      dataset = _c4_dataset(split=split)
      return dataset.map(lambda example: (example, example['text']))

    def examine_processed_dataset(proc_dataset):
      count = 0
      lengths = []
      for example in tfds.as_numpy(proc_dataset):
        count += 1
        ex = example[0]
        # Targets are in the example.
        self.assertIn('targets', ex)
        self.assertEqual(ex['targets'].dtype, np.int64)
        lengths.append(len(ex['targets']))
      return count, lengths

    unfiltered_count = 0
    for example in tfds.as_numpy(load_c4_dataset()):
      unfiltered_count += 1
      # Targets are NOT in the example.
      self.assertNotIn('targets', example[0])

    proc_dataset = tf_inputs.c4_preprocess(load_c4_dataset(), False, 2048)

    # `examine_processed_dataset` has some asserts in it.
    proc_count, char_lengths = examine_processed_dataset(proc_dataset)

    # Both the original and filtered datasets have examples.
    self.assertGreater(unfiltered_count, 0)
    self.assertGreater(proc_count, 0)

    # The processed count is lower because we filter out some entries on
    # length.
    self.assertLess(proc_count, unfiltered_count)

    # Preprocess using the sentencepiece model in testdata.
    spc_proc_dataset = tf_inputs.c4_preprocess(
        load_c4_dataset(), False, 2048, tokenization='spc',
        spm_path=_spm_path())

    spc_proc_count, spc_lengths = examine_processed_dataset(spc_proc_dataset)

    # spc tokenization shortens the target sequences a lot, so on this test
    # data nothing gets filtered out and the count equals the unfiltered count.
    self.assertLessEqual(proc_count, spc_proc_count)
    self.assertEqual(unfiltered_count, spc_proc_count)

    # Assert all spc_lengths are no greater than their char counterparts.
    for spc_len, char_len in zip(spc_lengths, char_lengths):
      self.assertLessEqual(spc_len, char_len)

  def test_c4(self):
    gin.bind_parameter('c4_preprocess.max_target_length', 2048)
    gin.bind_parameter('c4_preprocess.tokenization', 'spc')
    gin.bind_parameter('c4_preprocess.spm_path', _spm_path())

    # Just make sure this doesn't throw.
    _ = tf_inputs.data_streams(
        'c4', data_dir=_TESTDATA, input_name='targets', target_name='text',
        preprocess_fn=tf_inputs.c4_preprocess)

  def test_c4_bare_preprocess_fn_denoising_objective(self):
    _t5_gin_config()

    dataset = _c4_dataset()
    dataset = tf_inputs.c4_bare_preprocess_fn(dataset, spm_path=_spm_path())

    example = list(tfds.as_numpy(dataset.take(1)))[0]

    # Assertions now.
    self.assertIn('targets', example)
    targets = example['targets']
    self.assertIsInstance(targets, np.ndarray)
    self.assertEqual(targets.dtype, np.int64)
    self.assertGreater(len(targets), 0)

    self.assertIn('inputs', example)
    _inputs = example['inputs']  # pylint: disable=invalid-name
    self.assertIsInstance(_inputs, np.ndarray)
    self.assertEqual(_inputs.dtype, np.int64)
    self.assertGreater(len(_inputs), 0)

    # With high probability, inputs will have the bulk of the text.
    self.assertGreater(len(_inputs), len(targets))

    # With high probability, there will be two sentinel tokens in the inputs
    # and targets.
    inputs_counter = collections.Counter(_inputs.tolist())
    targets_counter = collections.Counter(targets.tolist())
    self.assertEqual(1, inputs_counter[31999])
    self.assertEqual(1, inputs_counter[31998])
    self.assertEqual(1, targets_counter[31999])
    self.assertEqual(1, targets_counter[31998])

  def test_c4_pretrain(self):
    _t5_gin_config()

    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    gin.bind_parameter('batcher.batch_size_per_device', 8)
    gin.bind_parameter('batcher.eval_batch_size', 8)
    gin.bind_parameter('batcher.max_eval_length', 50)
    gin.bind_parameter('batcher.buckets', ([51], [8, 1]))

    # Just make sure this doesn't throw.
    _ = tf_inputs.data_streams(
        'c4', data_dir=_TESTDATA, input_name='inputs', target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)

  def test_generic_text_dataset_preprocess_fn(self):
    dataset = _load_dataset('squad')

    example, = tfds.as_numpy(dataset.take(1))

    self.assertNotIn('inputs', example)
    self.assertNotIn('targets', example)

    proc_dataset = tf_inputs.generic_text_dataset_preprocess_fn(
        dataset,
        spm_path=_spm_path(),
        text_preprocess_fns=[lambda ds, training: t5_processors.squad(ds)],
        copy_plaintext=True,
        debug_print_examples=True,
        debug_print_examples_rate=1.0)

    proc_example, = tfds.as_numpy(proc_dataset.take(1))

    self.assertIn('inputs', proc_example)
    self.assertIn('targets', proc_example)

    self.assertEqual(proc_example['inputs'].dtype, np.int64)
    self.assertEqual(proc_example['targets'].dtype, np.int64)

  # TODO(afrozm): Why does this test take so much time?
  def test_inputs_using_generic_text_dataset_preprocess_fn(self):
    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.spm_path', _spm_path())
    gin.bind_parameter(
        'generic_text_dataset_preprocess_fn.text_preprocess_fns',
        [lambda ds, training: t5_processors.squad(ds)])

    # Just make sure this doesn't throw.
    def data_streams():
      return tf_inputs.data_streams(
          'squad', data_dir=_TESTDATA, input_name='inputs',
          target_name='targets',
          bare_preprocess_fn=tf_inputs.generic_text_dataset_preprocess_fn,
          shuffle_buffer_size=1)

    n_devices = 3

    squad_inputs = inputs.batcher(
        data_streams=data_streams,
        max_eval_length=512,
        buckets=([513], [n_devices, n_devices])
    )

    eval_stream = squad_inputs.eval_stream(n_devices)
    inps, tgts, _ = next(eval_stream)

    # We can only assert that the batch dim gets divided by n_devices.
    self.assertEqual(inps.shape[0] % n_devices, 0)
    self.assertEqual(tgts.shape[0] % n_devices, 0)

  def test_filter_dataset_on_len(self):
    # Lengths (inputs, targets): (1, 2), (2, 4), (3, 6), ..., (10, 20).
    ds = _test_dataset_ints(range(1, 11), range(2, 21, 2))

    ds1 = tf_inputs.filter_dataset_on_len(
        ds, True, {'inputs': [4, 8], 'targets': [14, 20]})
    # Only (7, 14) and (8, 16) satisfy both bounds.
    self.assertLen(list(ds1.as_numpy_iterator()), 2)

    ds2 = tf_inputs.filter_dataset_on_len(
        ds, False,
        len_map={'inputs': [4, 8], 'targets': [14, 20]},
        filter_on_eval=False)
    # This is eval and we aren't supposed to filter it.
    self.assertLen(list(ds2.as_numpy_iterator()), 10)

    ds3 = tf_inputs.filter_dataset_on_len(
        ds, False,
        len_map={'inputs': [4, 8], 'targets': [14, 20]},
        filter_on_eval=True)
    # This is eval, but we are asked to filter it anyway.
    self.assertLen(list(ds3.as_numpy_iterator()), 2)

  def test_truncate_dataset_on_len(self):
    ds = _test_dataset_ints([5, 6, 7], [8, 9, 10])
    ds1 = tf_inputs.truncate_dataset_on_len(
        ds, True, len_map={'inputs': 6, 'targets': 4})
    expected_ds = _test_dataset_ints([5, 6, 6], [4, 4, 4])

    # Training: should truncate.
    t5_test_utils.assert_dataset(ds1, list(expected_ds.as_numpy_iterator()))

    # Not training: shouldn't truncate.
    ds2 = tf_inputs.truncate_dataset_on_len(
        ds, False, len_map={'inputs': 6, 'targets': 4})
    t5_test_utils.assert_dataset(ds2, list(ds.as_numpy_iterator()))

    # Not training, but explicitly asked to truncate: should truncate.
    ds3 = tf_inputs.truncate_dataset_on_len(
        ds, False, len_map={'inputs': 6, 'targets': 4}, truncate_on_eval=True)
    t5_test_utils.assert_dataset(ds3, list(expected_ds.as_numpy_iterator()))

  def test_get_t5_preprocessor_by_name(self):
    gin.clear_config()

    gin.parse_config("""
      get_t5_preprocessor_by_name.name = 'rekey'
      get_t5_preprocessor_by_name.fn_kwargs = {'key_map': {'inputs': 'other', 'targets': 'text'}}
    """)
    prep_rekey = tf_inputs.get_t5_preprocessor_by_name()

    og_dataset = tf.data.Dataset.from_tensors({
        'text': 'That is good.', 'other': 'That is bad.'})

    training = True
    dataset = prep_rekey(og_dataset, training)

    t5_test_utils.assert_dataset(
        dataset,
        {'inputs': 'That is bad.', 'targets': 'That is good.'})

  def test_pad_dataset_to_length(self):
    ds = _test_dataset_ints([5, 6, 7], [6, 7, 8])
    ds1 = tf_inputs.pad_dataset_to_length(
        ds, True, len_map={'inputs': 7, 'targets': 10})

    expected_ds = [
        {
            'inputs': np.array([1, 1, 1, 1, 1, 0, 0], dtype=np.int64),
            'targets': np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
                                dtype=np.int64),
        },
        {
            'inputs': np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.int64),
            'targets': np.array([1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
                                dtype=np.int64),
        },
        {
            'inputs': np.array([1, 1, 1, 1, 1, 1, 1], dtype=np.int64),
            'targets': np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
                                dtype=np.int64),
        },
    ]

    t5_test_utils.assert_dataset(ds1, expected_ds)

  def test_lm_token_preprocessing(self):
    ds = _test_dataset_ints([1, 2, 3], [3, 2, 1])
    ds1 = tf_inputs.lm_token_preprocessing(ds, True)

    # Each expected example is the concatenation inputs + [0] + targets, with
    # the mask covering only the targets part.
    # pylint: disable=bad-whitespace
    expected_ds = [
        {
            'inputs':  np.array([1, 0, 1, 1, 1], dtype=np.int64),
            'targets': np.array([1, 0, 1, 1, 1], dtype=np.int64),
            'mask':    np.array([0, 0, 1, 1, 1], dtype=np.int64),
        },
        {
            'inputs':  np.array([1, 1, 0, 1, 1], dtype=np.int64),
            'targets': np.array([1, 1, 0, 1, 1], dtype=np.int64),
            'mask':    np.array([0, 0, 0, 1, 1], dtype=np.int64),
        },
        {
            'inputs':  np.array([1, 1, 1, 0, 1], dtype=np.int64),
            'targets': np.array([1, 1, 1, 0, 1], dtype=np.int64),
            'mask':    np.array([0, 0, 0, 0, 1], dtype=np.int64),
        },
    ]
    # pylint: enable=bad-whitespace
    t5_test_utils.assert_dataset(ds1, expected_ds)


if __name__ == '__main__':
  tf.test.main()