""" Turn a merged corpus into tfrecord files. NOTE: You will want to do this using several processes. I did this on an AWS machine with 72 CPUs using GNU parallel as that's where I had the deduplicated RealNews dataset. """ import argparse import ujson as json from sample.encoder import get_encoder, tokenize_for_grover_training, detokenize, sliding_window, create_int_feature import random import tensorflow as tf import collections import os from tempfile import TemporaryDirectory parser = argparse.ArgumentParser(description='SCRAPE!') parser.add_argument( '-fold', dest='fold', default=0, type=int, help='which fold we are on' ) parser.add_argument( '-num_folds', dest='num_folds', default=1, type=int, help='Number of folds (corresponding to both the number of training files and the number of testing files)', ) parser.add_argument( '-seed', dest='seed', default=1337, type=int, help='which seed to use' ) parser.add_argument( '-base_fn', dest='base_fn', default='realnews_', type=str, help='We will output files that are like {base_fn}_{n}.tfrecord for n in 0, ..., 1023' ) parser.add_argument( '-input_fn', dest='input_fn', default='realnews.jsonl', type=str, help='Base filename to use. THIS MUST BE A LOCAL FILE.' ) parser.add_argument( '-max_seq_length', dest='max_seq_length', default=1024, type=int, help='Max sequence length', ) parser.add_argument( '-add_extra_articles_to_end', dest='add_extra_articles_to_end', type=bool, action='store_true', help='Whether to minimize padding by adding extra articles to the end', ) args = parser.parse_args() random.seed(args.seed + args.fold) encoder = get_encoder() class S3TFRecordWriter(object): def __init__(self, fn): self.fn = fn if fn.startswith('s3://'): from boto3.s3.transfer import TransferConfig import boto3 self.gclient = None self.s3client = boto3.client('s3', ) self.storage_dir = TemporaryDirectory() self.writer = tf.python_io.TFRecordWriter(os.path.join(self.storage_dir.name, 'temp.tfrecord')) self.bucket_name, self.file_name = self.fn.split('s3://', 1)[1].split('/', 1) elif fn.startswith('gs://'): from google.cloud import storage self.s3client = None self.gclient = storage.Client() self.storage_dir = TemporaryDirectory() self.writer = tf.python_io.TFRecordWriter(os.path.join(self.storage_dir.name, 'temp.tfrecord')) self.bucket_name, self.file_name = self.fn.split('gs://', 1)[1].split('/', 1) else: self.s3client = None self.gclient = None self.bucket_name = None self.file_name = None self.storage_dir = None self.writer = tf.python_io.TFRecordWriter(fn) def write(self, x): self.writer.write(x) def close(self): self.writer.close() if self.s3client is not None: from boto3.s3.transfer import TransferConfig config = TransferConfig(multipart_threshold=1024 * 25, max_concurrency=10, multipart_chunksize=1024 * 25, use_threads=True) self.s3client.upload_file( os.path.join(self.storage_dir.name, 'temp.tfrecord'), self.bucket_name, self.file_name, ExtraArgs={'ACL': 'public-read'}, Config=config, ) self.storage_dir.cleanup() if self.gclient is not None: bucket = self.gclient.get_bucket(self.bucket_name) blob = bucket.blob(self.file_name) blob.upload_from_filename(os.path.join(self.storage_dir.name, 'temp.tfrecord')) self.storage_dir.cleanup() def __enter__(self): # Called when entering "with" context. return self def __exit__(self, *_): # Called when exiting "with" context. 
        # Upload the file, if applicable.
        print("CALLING CLOSE")
        self.close()


def article_iterator(encoder, final_desired_size=1025):
    """Iterate through the provided filename and tokenize each article."""
    assert os.path.exists(args.input_fn)
    with open(args.input_fn, 'r') as f:
        for l_no, l in enumerate(f):
            if l_no % args.num_folds == args.fold:
                article = json.loads(l)
                article['input_ids'] = tokenize_for_grover_training(encoder, article,
                                                                    desired_size=final_desired_size,
                                                                    unconditional_prob=.35)
                article['inst_index'] = (l_no // args.num_folds)
                if article['inst_index'] < 100:
                    print('---\nINPUT{}. {}\n---\nTokens: {}\n'.format(article['inst_index'],
                                                                       detokenize(encoder, article['input_ids']),
                                                                       article['input_ids'],
                                                                       ), flush=True)
                if len(article['input_ids']) == 0:
                    continue
                yield article


def _stream_from_buffer(buffer, current_desired_size, pad_token=0, add_articles_to_end=False):
    """Combines short articles that are in a buffer."""
    random.shuffle(buffer)
    i = 0
    while i < len(buffer):
        article = buffer[i]
        if add_articles_to_end:
            # Keep appending later articles (each separated by a padding token and a
            # reset-context token) until we reach the desired size, then truncate.
            for article2add in buffer[(i + 1):]:
                i += 1
                article['input_ids'].append(encoder.padding)
                article['input_ids'].append(encoder.reset_context)
                article['input_ids'].extend(article2add['input_ids'])
                if len(article['input_ids']) >= current_desired_size:
                    article['input_ids'] = article['input_ids'][:current_desired_size]
                    break
        # print(f"YIELD FROM BUFFER {i}")
        # Pad to the right length.
        amount_to_pad = current_desired_size - len(article['input_ids'])
        article['input_ids'].extend([pad_token] * amount_to_pad)
        article['sub_index'] = 0
        yield article
        i += 1


def buffered_and_sliding_window_article_iterator(encoder, current_desired_size, final_desired_size=1025):
    """We apply a sliding window to fix long sequences, and use a buffer that combines short sequences."""
    assert current_desired_size <= final_desired_size
    buffer = []
    for article in article_iterator(encoder, final_desired_size=final_desired_size):
        amount_to_pad = current_desired_size - len(article['input_ids'])
        if article['split'] == 'val' or amount_to_pad <= 0:
            for sub_index, sub_article in enumerate(sliding_window(article, max_seq_length=current_desired_size,
                                                                   pad_token=encoder.padding)):
                sub_article['sub_index'] = sub_index
                # print(f"AMT2PAD < 0 YIELD-{article['inst_index']} sliding window {sub_index}", flush=True)
                yield sub_article
        else:
            # Buffer time.
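            # Short articles accumulate here and are flushed every 100 articles.
            # _stream_from_buffer shuffles the buffer and, if
            # -add_extra_articles_to_end is set, concatenates articles to cut padding.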
            buffer.append(article)
            if len(buffer) % 100 == 0:
                yield from _stream_from_buffer(buffer,
                                               current_desired_size=current_desired_size,
                                               pad_token=encoder.padding,
                                               add_articles_to_end=args.add_extra_articles_to_end)
                buffer = []
    yield from _stream_from_buffer(buffer,
                                   current_desired_size=current_desired_size,
                                   pad_token=encoder.padding,
                                   add_articles_to_end=args.add_extra_articles_to_end)


# OK, now write the tfrecord files.
total_written = 0
train_file = args.base_fn + 'train{:04d}.tfrecord'.format(args.fold)
val_file = args.base_fn + 'val{:04d}.tfrecord'.format(args.fold)
with S3TFRecordWriter(train_file) as train_writer, S3TFRecordWriter(val_file) as val_writer:
    for article in buffered_and_sliding_window_article_iterator(encoder,
                                                                current_desired_size=args.max_seq_length + 1,
                                                                final_desired_size=max(args.max_seq_length + 1, 1025)):
        writer2use = train_writer if article['split'] == 'train' else val_writer
        assert len(article['input_ids']) == (args.max_seq_length + 1)

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(article['input_ids'])
        tf_example = tf.train.Example(features=tf.train.Features(feature=features))

        writer2use.write(tf_example.SerializeToString())
        total_written += 1

        # DEBUG
        if article['inst_index'] < 5:
            print("~~~\nSubindex{}. Index {}. ARTICLE: {}\n---\nTokens: {}\n\n".format(
                article['sub_index'],
                article['inst_index'],
                detokenize(encoder, article['input_ids']),
                article['input_ids']), flush=True)
        if article['inst_index'] % 1000 == 0:
            print("{} articles, {} written".format(article['inst_index'], total_written), flush=True)
print("DONE UPLOADING", flush=True)
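# Example invocation, sharding the work across folds with GNU parallel as the
# module docstring suggests. This is an illustrative sketch: the script name,
# fold count, job count, and paths below are assumptions, not taken from the
# original repo.
#
#   seq 0 1023 | parallel -j 72 \
#       python prepare_lm_data.py -fold {} -num_folds 1024 \
#           -input_fn realnews.jsonl \
#           -base_fn gs://YOUR_BUCKET/realnews_ \
#           -max_seq_length 1024 \
#           -add_extra_articles_to_end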