Ala Alam Falaki

Summary

The article discusses a method for training a sequence-to-sequence (seq2seq) summarization model using BERT as both the encoder and decoder, leveraging the Huggingface library's flexibility to combine pre-trained models.

Abstract

The Transformer architecture's encoder-decoder structure is foundational for seq2seq models, with pre-trained models like BERT, GPT, and BART being pivotal in various language tasks. The article explores the innovative approach of repurposing BERT, traditionally an encoder-only model, to function as a decoder as well, creating a BERT2BERT model for summarization tasks. This is achieved by utilizing Huggingface's EncoderDecoderModel and Seq2SeqTrainer objects, which facilitate the mixing of pre-trained models and fine-tuning processes. The author demonstrates the feasibility of this approach by providing code snippets and discussing the necessary configurations, such as setting up the tokenizer with the appropriate beginning and end tokens. The results show that the BERT2BERT model performs competitively against the state-of-the-art BART model on the CNN/DM dataset, with only a slight difference in ROUGE scores. The article concludes by encouraging experimentation with different model combinations and highlighting the efficiency of using BERT for datasets with shorter input sequences.

Opinions

  • The author believes that using BERT as both an encoder and decoder can yield a powerful summarization model, despite BERT not being initially designed for text generation.
  • It is suggested that the mix and match approach with pre-trained models can lead to exciting experiments and potentially better-suited models for specific datasets.
  • The author emphasizes the importance of the Huggingface library for its ease of use and flexibility in combining different models, which simplifies the process of creating custom seq2seq architectures.
  • The author points out that BERT's input sequence length limitation of 512 tokens can be an advantage when working with datasets that have shorter input sequences, as it leads to more efficient training and resource usage.
  • The article promotes the idea of not always defaulting to the state-of-the-art model (in this case, BART) for all problems, advocating for a more tailored approach to model selection based on dataset characteristics and resource constraints.

How To Train a Seq2Seq Summarization Model Using “BERT” as Both Encoder and Decoder!! (BERT2BERT)

BERT is a well-known and powerful pre-trained “encoder” model. Let’s see how we can use it as a “decoder” to form an encoder-decoder architecture.


The Transformer architecture consists of two main building blocks (encoder and decoder components), which we stack on top of each other to form a seq2seq model. (You can read more about it in my previous story.) It is generally hard to train a transformer-based model from scratch, since it needs both large datasets and a lot of GPU memory. Fortunately, there are numerous pre-trained models, each trained with a different objective in mind.

First, there are the encoder models (e.g., BERT, RoBERTa, FNet, …) that learn to create a fixed-size representation of the text they read. This representation can be used to train networks for classification, translation, summarization, etc. Second, there are the decoder-based models (like the GPT family) with generation capability. Generation is made possible by adding a linear layer on top (also known as the “language model head”), which enables them to predict the next token. Lastly, there are the encoder-decoder models (BART, Pegasus, MASS, …), which can condition the decoder’s output on the encoder’s representation. They can be used for tasks such as summarization and translation. The conditioning is done through a cross-attention connection from the encoder to the decoder.

In this story, I want to show how it is possible to use an encoder-only model’s pre-trained weights to give us a head start for fine-tuning. In this example, we will train a summarization model with BERT as both encoder and decoder.

The Huggingface library introduced a new API some time ago that makes it possible to mix and match different pre-trained models. It is highly flexible and makes our job super easy! But let’s look at the concept before jumping into the code. What should be done to make BERT (an encoder model) work in a seq2seq setting?

Figure 2. A high-level scheme of a (left) encoder-only model like BERT vs. (right) a network with both encoder and decoder like BART.

Keep in mind that other elements of the networks shown in Fig. 2 are removed for simplicity! To make a simple comparison, each block (layer) of the encoder-only model (left) consists of a self-attention layer followed by a linear layer. In the encoder-decoder network (right), each decoder layer additionally has a cross-attention connection. The cross-attention layer enables the model to condition its predictions on the input.
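To make the difference between the two attention types concrete, here is a minimal pure-Python sketch of single-head scaled dot-product attention. It is an illustration only and not how BERT or BART implement it: the learned projection matrices, multiple heads, and masking are left out, and the toy vectors are made up. The point is simply where the queries, keys, and values come from in each case.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Scaled dot-product weights: one row per query, one column per key."""
    scale = math.sqrt(len(keys[0]))
    return [
        softmax([sum(q_i * k_i for q_i, k_i in zip(q, k)) / scale for k in keys])
        for q in queries
    ]


def attend(queries, keys, values):
    """Mix the value vectors using the attention weights of each query."""
    weights = attention_weights(queries, keys)
    dim = len(values[0])
    return [
        [sum(w * v[d] for w, v in zip(row, values)) for d in range(dim)]
        for row in weights
    ]


# Toy hidden states (made-up numbers): 3 encoder positions, 2 decoder positions.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_states = [[0.5, 0.5], [1.0, -1.0]]

# Self-attention: queries, keys, and values all come from the same sequence.
self_attended = attend(decoder_states, decoder_states, decoder_states)

# Cross-attention: queries come from the decoder, keys/values from the encoder.
cross_attended = attend(decoder_states, encoder_states, encoder_states)
cross_weights = attention_weights(decoder_states, encoder_states)
```

Each row of cross_weights tells us how much one decoder position looks at each encoder position, which is exactly the connection an encoder-only model lacks.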

As may be evident, it is impossible to use the BERT model directly as a decoder because the building blocks are not the same! In theory, it is easy to add the extra connection and initialize the applicable parts of the decoder with BERT’s weights. Then, we need to fine-tune the model to train these new connections and the language model head weights. (Note: The language model head sits between the last linear layer and the output; it is not included in Fig. 2.)

We can use Huggingface’s EncoderDecoderModel object to mix and match different pre-trained models. It will take care of adding the needed connections and weights by calling the .from_encoder_decoder_pretrained() method and specifying the encoder/decoder models. In the following example, we use BERT-base as both encoder and decoder.
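A minimal sketch of what this could look like is shown here. The first function makes the call described above; the checkpoint name bert-base-uncased is my assumption, since the text only says BERT-base. The second function is a hypothetical helper that builds a tiny, randomly initialized model with made-up sizes, purely so you can inspect the wiring without downloading any weights.

```python
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel


def build_bert2bert(checkpoint: str = "bert-base-uncased") -> EncoderDecoderModel:
    """Warm-start encoder and decoder from the same BERT checkpoint (downloads weights)."""
    return EncoderDecoderModel.from_encoder_decoder_pretrained(checkpoint, checkpoint)


def build_tiny_bert2bert() -> EncoderDecoderModel:
    """Toy model with random weights and made-up sizes, for offline inspection only."""

    def tiny_config() -> BertConfig:
        return BertConfig(
            vocab_size=100,
            hidden_size=32,
            num_hidden_layers=2,
            num_attention_heads=2,
            intermediate_size=64,
        )

    # This helper marks the second config as a decoder with cross-attention.
    config = EncoderDecoderConfig.from_encoder_decoder_configs(tiny_config(), tiny_config())
    return EncoderDecoderModel(config=config)


def cross_attention_parameter_names(model: EncoderDecoderModel) -> list:
    """Names of the decoder weights that belong to the added cross-attention layers."""
    return [name for name, _ in model.decoder.named_parameters() if "crossattention" in name]
```

Calling cross_attention_parameter_names() on either model lists the decoder weights that have no counterpart in a plain BERT checkpoint. With the pre-trained variant, these are the weights that start from a random initialization and have to be learned during fine-tuning.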

Since the BERT model is not designed for text generation, we need to configure a few extra things. The next step is to set up the tokenizer and specify the beginning-of-sentence and end-of-sentence tokens so that training and inference are guided correctly. These tokens should be defined in both the model’s config and its tokenizer object.
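One way to do this, sketched under the assumption that BERT’s existing [CLS] and [SEP] tokens are reused as the beginning and end markers (BERT has no dedicated tokens for these roles). The function names are hypothetical, and the checkpoint name is an assumption as well.

```python
def configure_special_tokens(model, tokenizer):
    """Reuse BERT's [CLS]/[SEP] as beginning/end-of-sentence markers."""
    # Tokenizer side: expose [CLS] and [SEP] under the BOS/EOS names.
    tokenizer.bos_token = tokenizer.cls_token
    tokenizer.eos_token = tokenizer.sep_token

    # Model side: the decoder starts from [CLS] and stops once it emits [SEP].
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.eos_token_id = tokenizer.sep_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer


def load_configured_bert2bert(checkpoint="bert-base-uncased"):
    """Download the checkpoint and apply the settings above (needs network access)."""
    # Imported here so configure_special_tokens() can be read and tried on its own.
    from transformers import BertTokenizerFast, EncoderDecoderModel

    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(checkpoint, checkpoint)
    return configure_special_tokens(model, tokenizer)
```

Without decoder_start_token_id and pad_token_id the model cannot build decoder inputs from the labels during training, and without eos_token_id generation has no signal for when to stop.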

Now we can use Huggingface’s Seq2SeqTrainer object to fine-tune the model, configured through Seq2SeqTrainingArguments(). There are numerous configs you can change and experiment with to get the perfect combination for your model. Note that the following values are not the optimum choices and are only used for testing! The fp16 value is an important one if you are short on GPU memory: it switches to half-precision numbers, which reduces memory usage. Other useful variables to study are learning_rate and batch_size.
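As a rough sketch, the arguments could be set up like this. Every value is a placeholder I picked for illustration (not a tuned setting and not taken from the notebook), the output directory is hypothetical, and the helper names are mine. Note that the batch size is exposed as per_device_train_batch_size and per_device_eval_batch_size in Seq2SeqTrainingArguments.

```python
def build_training_kwargs(output_dir="./bert2bert-summarization", use_fp16=False):
    """Illustrative (untuned) keyword arguments for Seq2SeqTrainingArguments."""
    return {
        "output_dir": output_dir,
        "per_device_train_batch_size": 4,  # the batch size, counted per device
        "per_device_eval_batch_size": 4,
        "learning_rate": 5e-5,
        "num_train_epochs": 1,
        "predict_with_generate": True,  # call generate() when evaluating
        "fp16": use_fp16,  # half precision; needs a supported accelerator, typically a CUDA GPU
        "logging_steps": 100,
    }


def make_training_args(use_fp16=False):
    """Turn the keyword arguments into a Seq2SeqTrainingArguments object."""
    # Imported here so the kwargs above can be inspected without transformers installed.
    from transformers import Seq2SeqTrainingArguments

    return Seq2SeqTrainingArguments(**build_training_kwargs(use_fp16=use_fp16))
```

The resulting object is what you pass as the args parameter of Seq2SeqTrainer, together with the model and the tokenized train and evaluation datasets.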

I am not going to go through the whole fine-tuning process, since I already covered how to use the datasets library to load the data. Here is a link to the Google Colab notebook (I’ve copied/pasted the code to my GitHub account to make sure the link stays valid), which I’ve taken from the patrickvonplaten/bert2bert_cnn_daily_mail hub checkpoint. The notebook goes over the whole fine-tuning process to train the model for summarization.

Results

Here is how the BERT-to-BERT model performs after being fine-tuned on the CNN/DM dataset. I used the checkpoint available on the Huggingface hub and decoded with the Beam Search method. The results are calculated using the ROUGE score metric.

Figure 3. The results comparing BERT-base and BART-base.

The BART model is the SOTA model in text summarization, and the BERT seq2seq network holds up pretty well! There is only a 1% difference in the scores, which usually does not translate to a huge change in sentence quality.
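As a reminder of what ROUGE scores measure, here is a simplified sketch of ROUGE-N as n-gram overlap between a reference summary and a generated one. It is not the official scorer and not what produced the numbers in this story: it uses plain whitespace tokenization, no stemming, and made-up example sentences.

```python
from collections import Counter


def ngrams(tokens, n):
    """Count all contiguous n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(reference, candidate, n=1):
    """Precision, recall, and F1 of the n-gram overlap (simplified ROUGE-N)."""
    ref = ngrams(reference.lower().split(), n)
    cand = ngrams(candidate.lower().split(), n)
    # Counter intersection keeps the minimum count of every shared n-gram.
    overlap = sum((ref & cand).values())
    precision = overlap / max(sum(cand.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    f1 = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


reference = "the cat sat on the mat"
candidate = "the cat lay on the mat"

rouge1 = rouge_n(reference, candidate, n=1)  # 5 of 6 unigrams match
rouge2 = rouge_n(reference, candidate, n=2)  # 3 of 5 bigrams match
```

A single changed word costs one unigram but two bigrams, which is why ROUGE-2 is usually noticeably lower than ROUGE-1 for the same pair of summaries.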

Final Words

The mix and match approach can result in exciting experiments. For example, it is possible to connect BERT to GPT-2, combining BERT’s ability to build a powerful representation of the text with GPT’s ability to generate quality sentences. It is good practice to try different networks on your custom datasets instead of picking the SOTA model for every problem. The main difference when using BERT instead of BART is the input sequence length limit of 512 tokens (compared to 1024). This makes the BERT-to-BERT model a good choice if your dataset’s input sequences are short. Smaller models are more efficient to train and require fewer resources such as data and GPU memory.
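A quick way to judge whether the 512-token limit matters for your data is to check how many documents would be truncated under each limit. The sketch below assumes you already have the token count of every input; the sample numbers and the helper name are made up for illustration.

```python
BERT_MAX_TOKENS = 512
BART_MAX_TOKENS = 1024


def fraction_truncated(token_lengths, max_tokens):
    """Share of documents that are longer than the model's input limit."""
    if not token_lengths:
        return 0.0
    return sum(length > max_tokens for length in token_lengths) / len(token_lengths)


# Hypothetical token counts per document; in practice you could collect
# len(tokenizer(text)["input_ids"]) for every input in your dataset.
sample_lengths = [120, 340, 480, 510, 700, 1500]

bert_truncated = fraction_truncated(sample_lengths, BERT_MAX_TOKENS)  # 2 of 6 documents
bart_truncated = fraction_truncated(sample_lengths, BART_MAX_TOKENS)  # 1 of 6 documents
```

If the two fractions are close to zero for your dataset, the longer input window buys you little, and the smaller BERT-to-BERT setup is worth trying first.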

