LM Pretraining with PyTorch/TPU

This repo makes it easy to train language models on PyTorch/TPU. It relies on two libraries: PyTorch/XLA, to run PyTorch code on TPUs, and pytorch-transformers, for the language model implementations.

How to use

Create Cloud TPU

When using TPUs, all computation runs on Google Cloud. Use the ctpu command to create a TPU (values in square brackets are examples to replace with your own, and -preemptible is optional):

ctpu up -tf-version=pytorch-0.5 -name=[lm_tpu] -tpu-size=[v3-8] -tpu-only -zone=[us-central1-a] -gcp-network=[default] -project=[my_proj] [-preemptible]
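If you create TPUs from a script, the same invocation can be assembled programmatically. The sketch below is a hypothetical helper (ctpu_up_command is not part of this repo) that uses only the flags from the command above, takes the bracketed example values as defaults, and appends -preemptible only when requested.

```python
import shlex
from typing import List


def ctpu_up_command(name: str, tpu_size: str = "v3-8", zone: str = "us-central1-a",
                    network: str = "default", project: str = "my_proj",
                    preemptible: bool = False) -> List[str]:
    """Hypothetical helper: build the `ctpu up` argument list shown in this README."""
    cmd = [
        "ctpu", "up",
        "-tf-version=pytorch-0.5",
        f"-name={name}",
        f"-tpu-size={tpu_size}",
        "-tpu-only",  # create only the Cloud TPU, no companion VM
        f"-zone={zone}",
        f"-gcp-network={network}",
        f"-project={project}",
    ]
    if preemptible:
        cmd.append("-preemptible")  # optional flag from the README command
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command line (shlex.join needs Python 3.8+).
    print(shlex.join(ctpu_up_command("lm_tpu", preemptible=True)))
```

Returning a list rather than a string keeps each flag a separate argument, so it can be passed directly to subprocess.run without shell quoting issues.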

Set up environment

Run LM pretraining

cd /code
python -m pretrain --pregenerated_data data/pregenerated_training_data/ --output_dir finetuned_roberta_base --epochs 4 --bert_model roberta-base --train_batch_size 24

This fine-tunes the roberta-base model on the sample pregenerated training data in data/pregenerated_training_data/. Each epoch takes around 15 minutes. Note that the first few steps are usually slower than the rest, because the graph is compiled for the TPU during the first steps and the cached compiled graph is reused for subsequent steps.
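The compile-once, reuse-afterwards behaviour can be illustrated with a toy simulation. This is not the PyTorch/XLA API: compile_graph and run_step are hypothetical stand-ins, and the sketch only models the assumption that a compiled graph is cached per input shape, so steps with an already-seen shape skip compilation while a new shape pays the cost again.

```python
import functools

COMPILE_COUNT = 0  # how many times the (simulated) expensive compilation ran


@functools.lru_cache(maxsize=None)
def compile_graph(input_shape):
    """Hypothetical stand-in for graph compilation, cached per input shape."""
    global COMPILE_COUNT
    COMPILE_COUNT += 1
    return f"compiled graph for shape {input_shape}"


def run_step(batch_size, seq_len):
    # Look up the compiled graph by shape; only an unseen shape triggers a compile.
    return compile_graph((batch_size, seq_len))


# Fixed shapes (every batch padded to the same length): one compile, then cache hits.
for _ in range(100):
    run_step(24, 512)
print(COMPILE_COUNT)  # 1

# Varying sequence lengths: each new shape is compiled again; 512 is already cached.
for seq_len in (128, 256, 384, 512):
    run_step(24, seq_len)
print(COMPILE_COUNT)  # 4
```

Under this model, keeping input shapes fixed (for example, padding every sequence to the same maximum length) limits compilation to the first steps, which matches the slowdown described above.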

Pregenerate training data

The pretraining code expects pregenerated training data, produced by the script pytorch_transformers_lm_finetuning/pregenerate_training_data.py. This script is adapted, with some modifications, from the one in pytorch-transformers. It takes raw text as input and outputs the format expected by the pretraining script. The input is a glob of text files, each with one sentence per line and an empty line between documents.
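The input format can be made concrete with a small reader. This is a minimal sketch, not the parsing code in pregenerate_training_data.py: read_documents is a hypothetical helper that expands the glob, treats each non-empty line as a sentence, and closes a document at each empty line.

```python
import glob
from typing import List

Document = List[str]  # one document = its sentences, in order


def read_documents(pattern: str) -> List[Document]:
    """Hypothetical reader: one sentence per line, empty lines separate documents."""
    documents: List[Document] = []
    for path in sorted(glob.glob(pattern)):
        current: Document = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                sentence = line.strip()
                if sentence:
                    current.append(sentence)
                elif current:  # an empty line closes the current document
                    documents.append(current)
                    current = []
        if current:  # the last document may not end with an empty line
            documents.append(current)
    return documents


if __name__ == "__main__":
    docs = read_documents("data/sentences_150k.txt")
    print(len(docs), "documents")
```

Consecutive empty lines are tolerated, since an empty line only closes a document when that document already has at least one sentence.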

python pytorch_transformers_lm_finetuning/pregenerate_training_data.py --train_corpus "data/sentences_150k.txt" --output data/pregenerated_training_data --bert_model roberta-base --do_whole_word_mask --epochs_to_generate 2 --max_seq_len 512 --max_predictions_per_seq 75
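The --do_whole_word_mask and --max_predictions_per_seq options can be illustrated with a simplified sketch. This is not the script's implementation: whole_word_mask is hypothetical, it assumes a 15% masking rate, it uses the BERT WordPiece "##" convention to detect word continuations (roberta-base's BPE tokenizer marks word boundaries differently), and it always substitutes [MASK] instead of mixing in random or unchanged tokens.

```python
import random
from typing import List, Optional, Tuple

SPECIAL_TOKENS = {"[CLS]", "[SEP]"}


def whole_word_mask(tokens: List[str], mask_prob: float = 0.15,
                    max_predictions: int = 75,
                    rng: Optional[random.Random] = None) -> Tuple[List[str], List[int]]:
    """Hypothetical sketch: mask whole words, never part of a word."""
    rng = rng or random.Random()
    # Group sub-token positions into words; "##" marks a WordPiece continuation.
    words: List[List[int]] = []
    for i, tok in enumerate(tokens):
        if tok in SPECIAL_TOKENS:
            continue
        if words and tok.startswith("##"):
            words[-1].append(i)
        else:
            words.append([i])
    # Aim for ~mask_prob of the sequence, capped by max_predictions.
    budget = min(max_predictions, max(1, round(len(tokens) * mask_prob)))
    rng.shuffle(words)
    masked: List[int] = []
    for word in words:
        if len(masked) >= budget:
            break
        if len(masked) + len(word) > budget:
            continue  # skip words that would exceed the budget
        masked.extend(word)
    masked.sort()
    output = list(tokens)
    for i in masked:
        output[i] = "[MASK]"  # simplification: always [MASK]
    return output, masked


if __name__ == "__main__":
    toks = ["[CLS]", "un", "##believ", "##able", "results", "[SEP]"]
    print(whole_word_mask(toks, mask_prob=0.5, rng=random.Random(0)))
```

Under the 15% assumption, the cap of 75 used above is close to 15% of --max_seq_len 512 (0.15 × 512 = 76.8), so it rarely cuts the number of masked tokens for full-length sequences.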


Debugging and common issues

Performance Evaluation

We compared the performance of TPUs and GPUs on PyTorch and TensorFlow; the table below summarizes the results.


The performance numbers show that:

1- A TPU v3-8 (the smallest TPU, with 8 cores) is faster than 8 V100 GPUs with the same total amount of memory.

2- Running PyTorch on TPUs is still 5x slower than running TensorFlow on TPUs. Switching to the multiprocessing (MP) interface of PyTorch/XLA should reduce this gap, but reaching TensorFlow's level of performance will likely require some model-specific tuning.