Jacob Parnell

Summary

The provided content outlines a method for fine-tuning Transformer models for document summarization tasks using PyTorch Lightning and HuggingFace libraries.

Abstract

The article delves into the process of implementing machine learning models for natural language processing tasks, specifically document summarization. It emphasizes the use of PyTorch Lightning for structuring the training process and HuggingFace for accessing pre-trained models. The author guides readers through setting up the codebase with three separate scripts: main.py, summarisation_lightning_model.py, and summarisation_dataset.py. These scripts handle the main execution, the PyTorch Lightning model class, and dataset management, respectively. The summarisation_lightning_model.py script is detailed with essential functions for training, validation, and testing, as well as additional functions for further tuning. The summarisation_dataset.py script abstracts data handling to keep the main code clean. The article also provides a configuration for training arguments and instructions for initializing the model, logger, and checkpoint callback. Finally, it concludes with commands to train and test the model and a method for performing inference with the trained model.

Opinions

  • The author suggests that combining PyTorch Lightning and HuggingFace simplifies the implementation of complex machine learning models for NLP tasks.
  • It is implied that the use of separate scripts for different components of the model (e.g., main execution, model class, dataset management) is a good practice for maintaining organized and readable code.
  • The article posits that the provided code structure and training arguments have been effective for the BART-base model, indicating that these settings could serve as a useful starting point for practitioners.
  • The inclusion of a simple inference step at the end of the training process is presented as a practical way to demonstrate the effectiveness of the trained model.
  • The author encourages readers to explore the full codebase available on GitHub for a more comprehensive understanding of the tutorial.

Tune Transformers using PyTorch Lightning and HuggingFace


Many tasks in Natural Language Processing (NLP), such as document summarisation and entity tagging, have been heavily researched in recent years, largely because of their wide application in industry. With the boom of large pre-trained models, combining them with a well-structured training loop can be difficult, particularly if you are new to the field. Luckily, PyTorch Lightning and HuggingFace make it easy to implement machine learning models for an array of tasks.

Let’s walk through an example for document summarisation.

First, we need to set up some code and ensure the required packages are installed. A convenient way to organise a PyTorch Lightning project is to split it into three separate scripts for tuning the Transformers model:

  • main.py: main script to run the code
  • summarisation_lightning_model.py: script to host the PyTorch Lightning class to create and train the model
  • summarisation_dataset.py: script that wraps a HuggingFace dataset as a PyTorch Dataset (tokenisation and batching)
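Before running anything, a quick check that the dependencies are present can save a confusing import error later. The sketch below uses only the standard library's `importlib.metadata`; the list of PyPI names is an assumption inferred from the imports used in this tutorial (`nlp` is the HuggingFace dataset library that was later renamed `datasets`, and `test-tube` backs `TestTubeLogger`).

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the imports used in this tutorial
REQUIRED_PACKAGES = ["torch", "pytorch-lightning", "transformers", "rouge-score", "nlp", "test-tube"]


def missing_packages(names):
    """Return the names that are not installed, in the order given."""
    missing = []
    for name in names:
        try:
            version(name)  # raises PackageNotFoundError if the distribution is absent
        except PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Missing packages:", missing_packages(REQUIRED_PACKAGES))
```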

Hosting the Model in PyTorch Lightning

The summarisation_lightning_model.py script subclasses PyTorch Lightning's LightningModule. The core of the class is five methods (__init__, forward, training_step, validation_step and configure_optimizers), which you can modify to handle the processes you require; more methods can be added:

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from transformers.optimization import get_linear_schedule_with_warmup, Adafactor
from rouge_score import rouge_scorer
import pytorch_lightning as pl
from summarisation_dataset import SummarizationDataset
from torch.nn.parallel import DistributedDataParallel


class LmForSummarisation(pl.LightningModule):

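    def __init__(self, params):
        super().__init__()
        self.args = params
        # Store `params` in the checkpoint so load_from_checkpoint can rebuild the model
        self.save_hyperparameters()
        self.tokenizer = AutoTokenizer.from_pretrained(self.args['tokenizer'], use_fast=True)

        # Load the model config and apply the attention dropout from the training arguments
        config = AutoConfig.from_pretrained(self.args['model_path'])
        config.attention_dropout = self.args['attention_dropout']
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.args['model_path'], config=config)
        self.hf_datasets = None  # assigned from main.py before training
        self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
        self.ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)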

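    def _prepare_input(self, input_ids):
        # Attention mask: 1 for real tokens, 0 for padding
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
        attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
        return input_ids, attention_mask

    def forward(self, input_ids, output_ids):
        input_ids, attention_mask = self._prepare_input(input_ids)
        # Teacher forcing: the decoder reads the target minus its last token
        # and is trained to predict the target shifted one position left
        decoder_input_ids = output_ids[:, :-1]
        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
        labels = output_ids[:, 1:].clone()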
        outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                use_cache=False,)
        lm_logits = outputs[0]
        
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        loss = self.ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        return [loss]

    def training_step(self, batch, batch_nb):
        output = self.forward(*batch)
        loss = output[0]
        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        # Log with self.log (a 'log' key in the returned dict is no longer used by Lightning)
        self.log('train_loss', loss)
        self.log('lr', lr)
        self.log('input_size', float(batch[0].numel()))
        self.log('output_size', float(batch[1].numel()))
        mem = torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0.0
        self.log('mem', mem)
        return loss

    def validation_step(self, batch, batch_nb):
        for p in self.model.parameters():
            p.requires_grad = False

        outputs = self.forward(*batch)
        vloss = outputs[0]
        input_ids, output_ids = batch
        input_ids, attention_mask = self._prepare_input(input_ids)
        generated_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                            use_cache=True, max_length=self.args['max_output_len'],
                                            num_beams=1)
        generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
        gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)
        scorer = rouge_scorer.RougeScorer(rouge_types=['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)
        rouge1 = rouge2 = rougel = 0.0
        for ref, pred in zip(gold_str, generated_str):
            score = scorer.score(ref, pred)
            rouge1 += score['rouge1'].fmeasure
            rouge2 += score['rouge2'].fmeasure
            rougel += score['rougeL'].fmeasure
        rouge1 /= len(generated_str)
        rouge2 /= len(generated_str)
        rougel /= len(generated_str)

        return {'vloss': vloss,
                'rouge1': vloss.new_zeros(1) + rouge1,
                'rouge2': vloss.new_zeros(1) + rouge2,
                'rougeL': vloss.new_zeros(1) + rougel}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args['lr'])
        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        num_steps = self.args['dataset_size'] * self.args['epochs'] / num_gpus / self.args['grad_accum'] / \
                    self.args['batch_size']
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args['warmup'], num_training_steps=num_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
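The `forward` method trains with teacher forcing: the decoder sees each target sequence without its last token and is scored against the same sequence shifted one position left, while padding is masked out of attention and ignored by the loss. A dependency-free sketch of that slicing on plain lists; the token ids are made up and follow BART's convention of `0` for `<s>`, `2` for `</s>` and `1` for padding, and the helper names are hypothetical:

```python
def shift_for_teacher_forcing(output_ids, pad_id):
    """Mirror the slicing in `forward`: inputs drop the last token, labels drop the first."""
    decoder_input_ids = [row[:-1] for row in output_ids]
    labels = [row[1:] for row in output_ids]
    # 1 where the decoder may attend, 0 on padding (like decoder_attention_mask)
    decoder_attention_mask = [[int(tok != pad_id) for tok in row] for row in decoder_input_ids]
    return decoder_input_ids, labels, decoder_attention_mask


def count_scored_labels(labels, pad_id):
    """Positions whose label is pad_id are skipped, as CrossEntropyLoss(ignore_index=pad_id) does."""
    return sum(1 for row in labels for tok in row if tok != pad_id)


# Two padded target summaries with hypothetical word ids
batch = [[0, 100, 200, 300, 2], [0, 400, 2, 1, 1]]
dec_in, labels, mask = shift_for_teacher_forcing(batch, pad_id=1)
print(dec_in)  # [[0, 100, 200, 300], [0, 400, 2, 1]]
print(labels)  # [[100, 200, 300, 2], [400, 2, 1, 1]]
print(mask)    # [[1, 1, 1, 1], [1, 1, 1, 0]]
print(count_scored_labels(labels, pad_id=1))  # 6
```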

There are also several additional methods you can add to the LmForSummarisation class (indented inside it, like the methods above) to handle evaluation, data loading and distributed training:

# Additional useful functions
def validation_epoch_end(self, outputs):
    # Re-enable the gradients that validation_step switched off
    for p in self.model.parameters():
        p.requires_grad = True

    names = ['vloss', 'rouge1', 'rouge2', 'rougeL']
    metrics = []
    for name in names:
        # Mean of the per-batch values returned by validation_step
        metric = torch.stack([x[name] for x in outputs]).mean()
        # Under DDP, average the metric across processes
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
            metric /= torch.distributed.get_world_size()
        metrics.append(metric)
    logs = dict(zip(names, metrics))
    self.log("validation_loss", logs['vloss'], prog_bar=True)
    for name in names[1:]:
        self.log(name, logs[name])
    print(logs)
    return logs

def test_step(self, batch, batch_nb):
    return self.validation_step(batch, batch_nb)

def test_epoch_end(self, outputs):
    result = self.validation_epoch_end(outputs)
    print(result)

def _get_dataloader(self, current_dataloader, split_name, is_train):
    # Reuse the DataLoader if it has already been built for this split
    if current_dataloader is not None:
        return current_dataloader
    dataset = SummarizationDataset(hf_dataset=self.hf_datasets[split_name], tokenizer=self.tokenizer,
                                   max_input_len=self.args['max_input_len'],
                                   max_output_len=self.args['max_output_len'])
    # Under DDP, give each process its own shard of the data
    use_ddp = torch.distributed.is_available() and torch.distributed.is_initialized()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if use_ddp else None
    # Only shuffle the training split; shuffle=True cannot be combined with a sampler
    return DataLoader(dataset, batch_size=self.args['batch_size'], shuffle=(is_train and sampler is None),
                      num_workers=self.args['num_workers'], sampler=sampler,
                      collate_fn=SummarizationDataset.collate_fn)

def train_dataloader(self):
    self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
    return self.train_dataloader_object

def val_dataloader(self):
    self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'validation', is_train=False)
    return self.val_dataloader_object

def test_dataloader(self):
    self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
    return self.test_dataloader_object

def configure_ddp(self, model, device_ids):
    model = DistributedDataParallel(
        model,
        device_ids=device_ids,
        find_unused_parameters=False
    )
    return model
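`validation_epoch_end` averages the per-batch values returned by `validation_step`, so each batch carries equal weight whatever its size (under DDP the per-process averages are then averaged again). When the final batch is smaller, this can differ slightly from averaging over individual examples. A plain-Python illustration with made-up ROUGE-1 scores; the helper names are hypothetical:

```python
def mean(values):
    return sum(values) / len(values)


def mean_of_batch_means(batches):
    """What validation_epoch_end computes: average each batch, then average the batch means."""
    return mean([mean(batch) for batch in batches])


def per_example_mean(batches):
    """Average over every example, ignoring batch boundaries."""
    return mean([score for batch in batches for score in batch])


# Hypothetical per-example ROUGE-1 F-scores; the final batch holds a single example
scores = [[0.2, 0.4], [1.0]]
print(mean_of_batch_means(scores))  # about 0.65
print(per_example_mean(scores))     # about 0.533
```

With equal-sized batches the two numbers coincide, so the difference only matters for small validation sets or large batch sizes.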

Creating a dataset loader

Writing the summarisation_dataset.py script separately keeps data handling (tokenisation and batching) out of the main PyTorch Lightning class, which de-clutters the model code. It will usually contain:

import torch
from torch.utils.data import DataLoader, Dataset


class SummarizationDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_input_len, max_output_len):
        self.hf_dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        entry = self.hf_dataset[idx]
        input_ids = self.tokenizer.encode(entry['document'], truncation=True, max_length=self.max_input_len)
        output_ids = self.tokenizer.encode(entry['summary'], truncation=True, max_length=self.max_output_len)

        if self.tokenizer.bos_token_id is None:  # pegasus
            output_ids = [self.tokenizer.pad_token_id] + output_ids
        return torch.tensor(input_ids), torch.tensor(output_ids)

    @staticmethod
    def collate_fn(batch):
        # A hack to tell BART from Pegasus by the final (EOS) token of the first input.
        # DDP doesn't like global variables nor class-level member variables.
        if batch[0][0][-1].item() == 2:
            pad_token_id = 1  # AutoTokenizer.from_pretrained('facebook/bart-base').pad_token_id
        elif batch[0][0][-1].item() == 1:
            pad_token_id = 0  # AutoTokenizer.from_pretrained('google/pegasus-large').pad_token_id
        else:
            raise ValueError('Unrecognised end-of-sequence token; expected a BART or Pegasus tokenizer')

        input_ids, output_ids = list(zip(*batch))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
        return input_ids, output_ids
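To see what `collate_fn` does without PyTorch, here is a plain-list sketch: it infers the pad id from the last token of the first input (restating the BART and Pegasus ids in the comments of the hack above) and right-pads every sequence to the longest in the batch, as `pad_sequence(..., batch_first=True)` does. The helper names are hypothetical, and the trick assumes truncation keeps the end-of-sequence token, which `tokenizer.encode(..., truncation=True)` does by adding special tokens after truncating.

```python
# End-of-sequence id -> pad id: BART uses </s> = 2 and pads with 1;
# Pegasus uses </s> = 1 and pads with 0.
EOS_TO_PAD = {2: 1, 1: 0}


def infer_pad_token_id(first_input_ids):
    last = first_input_ids[-1]
    if last not in EOS_TO_PAD:
        raise ValueError(f"Unrecognised end-of-sequence id {last}")
    return EOS_TO_PAD[last]


def pad_right(sequences, pad_id):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]


def collate(batch):
    """batch is a list of (input_ids, output_ids) pairs of Python lists."""
    pad_id = infer_pad_token_id(batch[0][0])
    input_ids, output_ids = zip(*batch)
    return pad_right(list(input_ids), pad_id), pad_right(list(output_ids), pad_id)


# Two BART-style examples with hypothetical word ids
inputs, outputs = collate([([0, 5, 6, 2], [0, 7, 2]), ([0, 8, 2], [0, 9, 10, 11, 2])])
print(inputs)   # [[0, 5, 6, 2], [0, 8, 2, 1]]
print(outputs)  # [[0, 7, 2, 1, 1], [0, 9, 10, 11, 2]]
```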

Tuning the model

Now that you have the basic structure to run the code, you can create a main.py script to tie everything together. Start by importing packages:

import os
import random
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TestTubeLogger  # needs the test-tube package; deprecated in later Lightning releases
from pytorch_lightning.callbacks import ModelCheckpoint
import nlp  # HuggingFace dataset library (later renamed to `datasets`)
from summarisation_lightning_model import LmForSummarisation

Next we need to define arguments to help train the model. Play with these values to see what works best for your Transformers model (these have worked well for BART-base):

args = {
    'max_input_len': 512,
    'max_output_len': 128,
    'save_dir': 'output_models',
    'tokenizer': 'facebook/bart-base',
    'model_path': 'facebook/bart-base',
    'epochs': 5,
    'batch_size': 8,
    'grad_accum': 1,
    'lr': 0.00003,
    'warmup': 1000,
    'gpus': 1,
    'precision': 16,
    'cache_dir': 'dataset_cache',
    'attention_dropout': 0.1,
    'debug': False,
    'num_workers': 0
}
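The scheduler settings interact: `configure_optimizers` turns the dataset size, `epochs`, `batch_size` and `grad_accum` into a total number of optimizer steps, and with `warmup=1000` the learning rate ramps linearly from 0 to `lr` over the first 1,000 steps, then decays linearly to 0 at the last step. The multiplier below follows the shape documented for `get_linear_schedule_with_warmup`; the training-set size of 40,000 is a hypothetical number, not the Multi-News count.

```python
def total_training_steps(dataset_size, epochs, batch_size, grad_accum=1, num_gpus=1):
    """Optimizer steps over the whole run, as computed in configure_optimizers."""
    return int(dataset_size * epochs / num_gpus / grad_accum / batch_size)


def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup from 0 to 1, then linear decay back to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Hypothetical training set of 40,000 examples with the args above
steps = total_training_steps(40_000, epochs=5, batch_size=8, grad_accum=1)
for step in (0, 500, 1000, steps // 2, steps):
    print(step, 3e-5 * linear_warmup_multiplier(step, warmup_steps=1000, total_steps=steps))
```

If the step count is overestimated, the schedule ends before the learning rate reaches zero; doubling `grad_accum` halves the number of steps.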

Initialise the Lightning module to prepare the model for training. If you have your own dataset you wish to train on, please review the following link here. In this example we use Multi-News, a multi-document summarisation dataset:

# Initialize with a seed
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Define PyTorch Lightning model
model = LmForSummarisation(args)
# Include datasets
model.hf_datasets = nlp.load_dataset('multi_news', cache_dir=args['cache_dir'])

# Dataset size (training split only): needed to compute the number of steps for the lr scheduler
args['dataset_size'] = model.hf_datasets['train'].num_rows

# Define logger
logger = TestTubeLogger(
    save_dir=args['save_dir'],
    name='training',
    version=0  # always use version=0
)

# Define checkpoint saver
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(args['save_dir'], "training", "checkpoints"),
    filename='check-{epoch:02d}-{validation_loss:.2f}',
    save_top_k=1,
    verbose=True,
    monitor='validation_loss',
    mode='min',
    period=1
)

print(args)

# Define lightning trainer
trainer = pl.Trainer(gpus=args['gpus'], distributed_backend='dp' if torch.cuda.is_available() else None,
                     track_grad_norm=-1,
                     max_epochs=args['epochs'],
                     max_steps=None,
                     replace_sampler_ddp=False,
                     accumulate_grad_batches=args['grad_accum'],
                     gradient_clip_val=1.0,
                     val_check_interval=1.0,
                     num_sanity_val_steps=2,
                     check_val_every_n_epoch=1,
                     logger=logger,
                     callbacks=[checkpoint_callback],
                     progress_bar_refresh_rate=10,
                     precision=args['precision'],
                     amp_backend='native'
                     )
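`gradient_clip_val=1.0` rescales the gradients whenever their combined L2 norm exceeds 1.0 (Lightning clips by norm unless told otherwise). The arithmetic, sketched on a flat list of numbers rather than parameter tensors and modelled on how `torch.nn.utils.clip_grad_norm_` computes its clipping coefficient:

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef >= 1.0:
        return list(grads), total_norm  # already within the limit: leave unchanged
    return [g * clip_coef for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before, clipped)  # norm 5.0; clipped values very close to [0.6, 0.8]
```

Because every gradient is scaled by the same factor, the update direction is preserved and only its size shrinks.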

Train and test the model

To train the model and then evaluate it on the test split, run:

trainer.fit(model)
trainer.test(model)

After training, the best checkpoint (lowest validation loss) is saved in your save directory. To see what the trained model produces, we can add a simple inference method to the LmForSummarisation class in our summarisation_lightning_model.py script:

def summarise_example(self, input_document):
    # Tokenise the document, truncating to the same length used in training
    input_ids = self.tokenizer.encode(input_document, truncation=True, max_length=self.args['max_input_len'])
    # Add a batch dimension and place the ids on the same device as the model
    doc_ids = torch.tensor([input_ids], dtype=torch.long, device=self.device)
    # A single unpadded document has nothing to mask, so every token is attended to
    doc_attention_mask = torch.ones_like(doc_ids)

    self.model.eval()
    with torch.no_grad():
        generated_ids = self.model.generate(input_ids=doc_ids, attention_mask=doc_attention_mask,
                                            use_cache=True, max_length=self.args['max_output_len'],
                                            num_beams=1)
    # Decode to string (a list holding one summary)
    generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
    return generated_str
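With `num_beams=1` (and sampling left off, which is assumed here), `generate` decodes greedily: at each step it appends the single highest-scoring token until it emits the end-of-sequence token or reaches `max_length`. A toy sketch of that loop, where `toy_scores` is an entirely hypothetical stand-in for the model over a five-token vocabulary:

```python
def greedy_decode(next_token_scores, start_id, eos_id, max_length):
    """Repeatedly append the argmax token; stop at EOS or once max_length tokens exist."""
    ids = [start_id]
    while len(ids) < max_length:
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


def toy_scores(prefix):
    """Hypothetical 'model': favours token 3, then token 4, then EOS (id 2)."""
    table = {1: [0.0, 0.0, 0.1, 0.9, 0.0],
             2: [0.0, 0.0, 0.2, 0.1, 0.7]}
    return table.get(len(prefix), [0.0, 0.0, 1.0, 0.0, 0.0])


print(greedy_decode(toy_scores, start_id=0, eos_id=2, max_length=10))  # [0, 3, 4, 2]
```

Raising `num_beams` would keep several candidate prefixes per step instead of one, trading speed for (often) better summaries.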

Once you have identified the checkpoint path of the trained model, you can load the model directly using load_from_checkpoint:

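# Load the trained PyTorch Lightning model from its checkpoint
bart_model = LmForSummarisation.load_from_checkpoint("path_to_trained_model.ckpt")
# `document` is the text to summarise; summarise_example returns a list with one summary
bart_summary = bart_model.summarise_example(document)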

If you found this useful, please check out the rest of the code at https://github.com/ijauregiCMCRC/ALTA2021_tutorial/tree/main/summarisation.
