Tune Transformers using PyTorch Lightning and HuggingFace


Many tasks in Natural Language Processing (NLP) have been heavily researched in recent years, mainly because of their widespread industry applications (e.g. document summarisation, entity tagging). With the boom in large pre-trained machine learning models, and the programming complexity involved in training them (particularly if you are new to the field), it can be difficult to combine the two effectively. Luckily, PyTorch Lightning and HuggingFace make it easy to implement machine learning models for an array of tasks.
Let’s walk through an example for document summarisation.
First, we need to set up some code and make sure the right packages are installed (the scripts below import torch, pytorch-lightning, transformers, rouge-score, nlp and test-tube). The easiest way to work with PyTorch Lightning here is to split the code into three separate scripts for tuning the Transformers model:
- main.py: main script to run the code
- summarisation_lightning_model.py: script to host the PyTorch Lightning class to create and train the model
- summarisation_dataset.py: script to handle loading a dataset from HuggingFace
Hosting the Model in PyTorch Lightning
The summarisation_lightning_model.py script subclasses PyTorch Lightning's base `LightningModule` class. It implements five basic methods: `__init__`, `forward`, `training_step`, `validation_step` and `configure_optimizers`. You can add more methods, and modify any of them to handle the processes you require:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from transformers.optimization import get_linear_schedule_with_warmup, Adafactor
from rouge_score import rouge_scorer
import pytorch_lightning as pl
from summarisation_dataset import SummarizationDataset
from torch.nn.parallel import DistributedDataParallel
class LmForSummarisation(pl.LightningModule):
    """Lightning module that fine-tunes a HuggingFace seq2seq model for summarisation.

    ``params`` is a plain dict of settings (see ``args`` in main.py). Keys read here:
    'tokenizer', 'model_path', 'attention_dropout' (optional), 'max_output_len', 'lr',
    'dataset_size', 'epochs', 'gpus' (optional), 'grad_accum', 'batch_size' and 'warmup'.
    """

    def __init__(self, params):
        super().__init__()
        self.args = params
        # NOTE(review): presumably stored so load_from_checkpoint can rebuild the module
        # with the same params; save_hyperparameters() is Lightning's documented API.
        self.hparams['params'] = params
        self.tokenizer = AutoTokenizer.from_pretrained(self.args['tokenizer'], use_fast=True)
        # Load the config explicitly so training settings can override it before the
        # pre-trained weights are loaded.
        config = AutoConfig.from_pretrained(self.args['model_path'])
        if 'attention_dropout' in self.args:
            config.attention_dropout = self.args['attention_dropout']
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.args['model_path'], config=config)
        # DataLoaders are built lazily and cached by the *_dataloader hooks.
        self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
        # Padding positions in the labels are excluded from the loss.
        self.ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)

    def _prepare_input(self, input_ids):
        """Return ``(input_ids, attention_mask)``; the mask is 0 at padding positions, 1 elsewhere."""
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
        attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
        return input_ids, attention_mask

    def forward(self, input_ids, output_ids):
        """Compute the teacher-forced cross-entropy loss for a batch and return ``[loss]``."""
        input_ids, attention_mask = self._prepare_input(input_ids)
        # Teacher forcing: the decoder reads target tokens 0..n-2 and predicts tokens 1..n-1.
        decoder_input_ids = output_ids[:, :-1]
        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
        labels = output_ids[:, 1:].clone()
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            use_cache=False,
        )
        lm_logits = outputs[0]
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        loss = self.ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        return [loss]

    def training_step(self, batch, batch_nb):
        """Compute the loss for one ``(input_ids, output_ids)`` batch plus logging values."""
        output = self.forward(*batch)
        loss = output[0]
        lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
        tensorboard_logs = {'train_loss': loss, 'lr': lr,
                            'input_size': batch[0].numel(),
                            'output_size': batch[1].numel(),
                            'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3
                            if torch.cuda.is_available() else 0}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        """Return the loss and batch-mean ROUGE-1/2/L F-measures for one batch.

        Summaries are generated greedily (``num_beams=1``) up to ``max_output_len`` tokens.
        The returned dict holds tensors keyed 'vloss', 'rouge1', 'rouge2' and 'rougeL'.
        """
        # Lightning already runs validation with gradients disabled, so the parameters
        # are not frozen here. Freezing them (requires_grad=False) without a matching
        # unfreeze would make the next training step's backward pass fail.
        vloss = self.forward(*batch)[0]
        input_ids, output_ids = batch
        input_ids, attention_mask = self._prepare_input(input_ids)
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            use_cache=True, max_length=self.args['max_output_len'],
                                            num_beams=1)
        generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
        gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)
        scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
        rouge1 = rouge2 = rougel = 0.0
        for ref, pred in zip(gold_str, generated_str):
            score = scorer.score(ref, pred)
            rouge1 += score['rouge1'].fmeasure
            rouge2 += score['rouge2'].fmeasure
            rougel += score['rougeL'].fmeasure
        num_examples = len(generated_str)
        # 1-element tensors on the loss's device, so per-batch results can be stacked later.
        return {'vloss': vloss,
                'rouge1': vloss.new_zeros(1) + rouge1 / num_examples,
                'rouge2': vloss.new_zeros(1) + rouge2 / num_examples,
                'rougeL': vloss.new_zeros(1) + rougel / num_examples}

    def configure_optimizers(self):
        """Return Adam plus a linear warm-up/decay LR schedule stepped every optimiser step."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args['lr'])
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        requested_gpus = self.args.get('gpus')
        if isinstance(requested_gpus, int) and requested_gpus > 0:
            # Only the GPUs the Trainer is asked to use take part in training; counting
            # every visible device would shorten the schedule on multi-GPU machines.
            num_gpus = min(num_gpus, requested_gpus)
        # NOTE(review): dividing by num_gpus assumes each GPU sees a distinct shard of the
        # data (DDP-style); with DataParallel the step count does not shrink with GPUs.
        num_steps = int(self.args['dataset_size'] * self.args['epochs'] / num_gpus
                        / self.args['grad_accum'] / self.args['batch_size'])
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args['warmup'], num_training_steps=max(1, num_steps)
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]


# There are also several additional functions you can include in this class to help
# facilitate tuning the model:
# Additional useful functions
def validation_epoch_end(self, outputs):
    """Average the per-batch validation metrics, log them and return them.

    ``outputs`` is the list of dicts returned by ``validation_step`` (keys 'vloss',
    'rouge1', 'rouge2', 'rougeL'). Under DDP, each process's mean is summed across
    processes and divided by the world size.
    """
    # Re-enable gradients on every model parameter (undoing any freeze applied during
    # validation) so training can continue.
    for p in self.model.parameters():
        p.requires_grad = True
    names = ['vloss', 'rouge1', 'rouge2', 'rougeL']
    metrics = []
    for name in names:
        metric = torch.stack([x[name] for x in outputs]).mean()
        if self.trainer.accelerator_connector.use_ddp:
            torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
            metric /= self.trainer.world_size
        metrics.append(metric)
    logs = dict(zip(names, metrics))
    # 'validation_loss' is the key the ModelCheckpoint callback monitors.
    self.log("validation_loss", logs['vloss'], prog_bar=True)
    # Send the ROUGE scores through the same self.log API as the loss.
    for name in names[1:]:
        self.log(name, logs[name])
    print(logs)
    return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs}
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
def test_epoch_end(self, outputs):
result = self.validation_epoch_end(outputs)
print(result)
def _get_dataloader(self, current_dataloader, split_name, is_train):
    """Return ``current_dataloader`` if it exists, otherwise build one for ``split_name``.

    Under DDP a DistributedSampler shards the split (shuffling only when ``is_train``).
    Without DDP the DataLoader itself shuffles, again only when ``is_train``; the
    validation and test splits are read in order.
    """
    if current_dataloader is not None:
        return current_dataloader
    dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer,
                                   max_input_len=self.args['max_input_len'],
                                   max_output_len=self.args['max_output_len'])
    if self.trainer.accelerator_connector.use_ddp:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train)
    else:
        sampler = None
    return DataLoader(dataset, batch_size=self.args['batch_size'],
                      # DataLoader rejects shuffle=True together with an explicit sampler.
                      shuffle=is_train and sampler is None,
                      num_workers=self.args['num_workers'], sampler=sampler,
                      collate_fn=SummarizationDataset.collate_fn)
def train_dataloader(self):
    """Build the 'train' DataLoader on first use, cache it on the module and return it."""
    loader = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
    self.train_dataloader_object = loader
    return loader
def val_dataloader(self):
    """Build the 'validation' DataLoader on first use, cache it on the module and return it."""
    loader = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False)
    self.val_dataloader_object = loader
    return loader
def test_dataloader(self):
self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
return self.test_dataloader_object
def configure_ddp(self, model, device_ids):
    """Wrap ``model`` in DistributedDataParallel on ``device_ids`` and return the wrapper.

    ``find_unused_parameters=False`` tells DDP not to search the autograd graph for
    parameters that received no gradient, which assumes every parameter contributes
    to the loss on each step.
    """
    return DistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=False)


# Creating a dataset loader
Writing the summarisation_dataset.py script separately is useful because it keeps data handling out of the main PyTorch Lightning class, de-cluttering the main code. It will usually contain:
import torch
from torch.utils.data import DataLoader, Dataset
class SummarizationDataset(Dataset):
    """Map-style dataset yielding ``(input_ids, output_ids)`` token tensors.

    Each entry of ``hf_dataset`` is indexed with the keys ``'document'`` and ``'summary'``.
    These are tokenised with ``tokenizer`` and truncated to ``max_input_len`` and
    ``max_output_len`` tokens respectively.
    """

    def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        """Tokenise entry ``idx`` and return ``(input_ids, output_ids)`` as 1-D tensors."""
        entry = self.hf_dataset[idx]
        input_ids = self.tokenizer.encode(entry['document'], truncation=True, max_length=self.max_input_len)
        output_ids = self.tokenizer.encode(entry['summary'], truncation=True, max_length=self.max_output_len)
        if self.tokenizer.bos_token_id is None:  # pegasus
            # The first target token only serves as decoder input (it is not predicted),
            # so tokenisers without a BOS token get the pad token as a start symbol.
            output_ids = [self.tokenizer.pad_token_id] + output_ids
        return torch.tensor(input_ids), torch.tensor(output_ids)

    @staticmethod
    def collate_fn(batch, pad_token_id=None):
        """Pad ``(input_ids, output_ids)`` pairs into two ``(batch, max_len)`` tensors.

        Args:
            batch: sequence of ``(input_ids, output_ids)`` pairs of 1-D tensors.
            pad_token_id: padding id to use. If omitted, it is inferred from the last
                input token of the first example: 2 (BART) -> 1, 1 (Pegasus) -> 0.

        Raises:
            ValueError: if ``pad_token_id`` is omitted and cannot be inferred.
        """
        if pad_token_id is None:
            # A hack to tell BART from Pegasus. DDP doesn't like global variables nor
            # class-level member variables, so the ids are hard-coded here.
            last_token = batch[0][0][-1].item()
            if last_token == 2:
                pad_token_id = 1  # AutoTokenizer.from_pretrained('facebook/bart-base').pad_token_id
            elif last_token == 1:
                pad_token_id = 0  # AutoTokenizer.from_pretrained('google/pegasus-large').pad_token_id
            else:
                # An `assert False` here would be stripped under `python -O`, leaving
                # pad_token_id unset, so fail explicitly instead.
                raise ValueError(
                    f"Cannot infer pad_token_id from final input token {last_token}; "
                    "pass pad_token_id explicitly (e.g. with functools.partial)."
                )
        input_ids, output_ids = list(zip(*batch))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
        return input_ids, output_ids


# Tuning the model
Now that you have the basic structure to run the code, you can create a main.py script to tie everything together. Start by importing packages:
# Standard library
import os
import random

# Third-party
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TestTubeLogger
import nlp  # HuggingFace `nlp` library (later renamed `datasets`), used to load the dataset

# Local
from summarisation_lightning_model import LmForSummarisation

os.getcwd()  # result is discarded; the call has no effect on the script

# Next we need to define arguments to help train the model. Play with these values to
# see what works best for your Transformers model (these have worked well for BART-base):
# Settings shared by the Lightning module, its data loaders and the Trainer.
args = dict(
    max_input_len=512,    # truncation length (tokens) for each input document
    max_output_len=128,   # truncation length for summaries and max generation length
    save_dir='output_models',
    tokenizer='facebook/bart-base',
    model_path='facebook/bart-base',
    epochs=5,
    batch_size=8,
    grad_accum=1,         # gradient-accumulation steps
    lr=3e-5,
    warmup=1000,          # warm-up steps for the linear LR schedule
    gpus=1,
    precision=16,
    cache_dir='dataset_cache',
    attention_dropout=0.1,
    debug=False,
    num_workers=0,
)

# Initialise the Lightning module to prepare the model for training. For the dataset
# step, if you have your own dataset you wish to train, please review the following
# link here (link missing). We are using a multi-document summarisation dataset
# called Multi News:
# Initialize with a seed
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Define PyTorch Lightning model
model = LmForSummarisation(args)

# Include datasets
model.hf_datasets = nlp.load_dataset('multi_news', cache_dir=args['cache_dir'])

# Dataset size - needed to compute the number of steps for the lr scheduler.
# Only the training split is iterated during fit(), so validation rows are not
# counted; including them would stretch the schedule past the end of training.
# The model sees this key because model.args is this same dict object.
args['dataset_size'] = model.hf_datasets['train'].num_rows

# Define logger
logger = TestTubeLogger(
    save_dir=args['save_dir'],
    name='training',
    version=0,  # always use version=0
)

# Define checkpoint saver: keep the single checkpoint with the lowest 'validation_loss'.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(args['save_dir'], "training", "checkpoints"),
    filename='check-{epoch:02d}-{validation_loss:.2f}',
    save_top_k=1,
    verbose=True,
    monitor='validation_loss',
    mode='min',
    period=1,
)
print(args)

# Without CUDA, fall back to CPU training: no GPUs, no DataParallel and 32-bit
# precision (16-bit native AMP needs a GPU).
use_gpu = torch.cuda.is_available()

# Define lightning trainer
trainer = pl.Trainer(gpus=args['gpus'] if use_gpu else 0,
                     distributed_backend='dp' if use_gpu else None,
                     track_grad_norm=-1,
                     max_epochs=args['epochs'],
                     max_steps=None,
                     replace_sampler_ddp=False,
                     accumulate_grad_batches=args['grad_accum'],
                     gradient_clip_val=1.0,
                     val_check_interval=1.0,
                     num_sanity_val_steps=2,
                     check_val_every_n_epoch=1,
                     logger=logger,
                     callbacks=[checkpoint_callback],
                     progress_bar_refresh_rate=10,
                     precision=args['precision'] if use_gpu else 32,
                     # NOTE(review): amp_level looks apex-specific; confirm it is needed
                     # with amp_backend='native'.
                     amp_backend='native', amp_level='O2',
                     )

# Train and test the model
To train the model and then evaluate it on the test set, simply run:
# Fit on the training split (validating every epoch), then run the test loop.
# NOTE(review): passing `model` to test() evaluates its in-memory weights, which may
# differ from the best checkpoint saved by ModelCheckpoint; confirm which one you want.
trainer.fit(model=model)
trainer.test(model=model)

# After training, you will have the best model saved in your save directory. To use it
# effectively, we can create a simple inference step and place it in our
# summarisation_lightning_model.py script to show the effectiveness of the trained model:
def summarise_example(self, input_document):
    """Generate a summary for a single document string.

    The document is tokenised and truncated to ``args['max_input_len']`` tokens. A
    summary of at most ``args['max_output_len']`` tokens is then decoded greedily
    (``num_beams=1``). Returns the list from ``tokenizer.batch_decode``: a one-element
    list holding the summary string.
    """
    # Tokenize the document
    token_ids = self.tokenizer.encode(input_document, truncation=True, max_length=self.args['max_input_len'])
    # Build a batch of one on the model's device, so generation also works once the
    # model has been moved to a GPU.
    doc_ids = torch.tensor([token_ids], dtype=torch.long, device=self.model.device)
    # 1 for real tokens, 0 for padding ids (same rule as in forward()/collate_fn()).
    doc_attention_mask = (doc_ids != self.tokenizer.pad_token_id).long()
    # Generate in eval mode (dropout off), then restore the model's previous mode.
    was_training = self.model.training
    self.model.eval()
    try:
        generated_ids = self.model.generate(input_ids=doc_ids, attention_mask=doc_attention_mask,
                                            use_cache=True, max_length=self.args['max_output_len'],
                                            num_beams=1)
    finally:
        self.model.train(was_training)
    # Decode to string
    return self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
Once you have identified the checkpoint path of the trained model, you can load the model directly using load_from_checkpoint:
# Define PyTorch Lightning model
bart_model = LmForSummarisation.load_from_checkpoint("path_to_trained_model.ckpt")
bart_model.eval()  # inference only: disable dropout
# The text to summarise; replace this placeholder with your own document.
document = "Replace this string with the document you want to summarise."
bart_summary = bart_model.summarise_example(document)  # one-element list holding the summary

# If you liked the content that was provided above, please check out the rest of the
# code at https://github.com/ijauregiCMCRC/ALTA2021_tutorial/tree/main/summarisation.






