How To Train a Seq2Seq Summarization Model Using “BERT” as Both Encoder and Decoder!! (BERT2BERT)
BERT is a well-known and powerful pre-trained “encoder” model. Let’s see how we can use it as a “decoder” to form an encoder-decoder architecture.
The Transformer architecture consists of two main building blocks, the encoder and the decoder, each built from a stack of identical layers, which together form a seq2seq model. (You can read more about it in my previous story.) It is generally hard to train a transformer-based model from scratch, since it requires both large datasets and a lot of GPU memory. As a result, numerous pre-trained models are available, each designed with different objectives in mind.
First, there are encoder models (e.g., BERT, RoBERTa, FNet, …) that learn how to create fixed-size representations of the text they read. These representations can be used to train networks for classification, translation, summarization, etc. Second, there are decoder-based models (like the GPT family) that can generate text. This is made possible by a linear layer on top (also known as the “language model head”), which lets them predict the next token. Lastly, there are encoder-decoder models (BART, Pegasus, MASS, …) that can condition the decoder’s output on the encoder’s representation, which makes them suitable for tasks such as summarization and translation. This conditioning is done through a cross-attention connection from the encoder to the decoder.
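To make the “language model head” a bit more concrete, here is a minimal pure-Python sketch, assuming a toy four-word vocabulary and hand-picked weights (nothing here comes from a real model). The head is just a linear layer that maps the last hidden state to one score (logit) per vocabulary token, and greedy decoding picks the highest-scoring token as the next one.

```python
from typing import List


def lm_head(hidden: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """Linear layer: one logit per vocabulary entry (weight has one row per token)."""
    return [sum(h * w for h, w in zip(hidden, row)) + b for row, b in zip(weight, bias)]


def greedy_next_token(hidden: List[float], weight: List[List[float]],
                      bias: List[float], vocab: List[str]) -> str:
    """Pick the vocabulary entry with the largest logit."""
    logits = lm_head(hidden, weight, bias)
    best = max(range(len(logits)), key=lambda i: logits[i])
    return vocab[best]


# Toy, made-up numbers for illustration only (not real model weights).
vocab = ["the", "cat", "sat", "[SEP]"]
weight = [[0.1, 0.0, 0.2], [0.5, 0.3, -0.1], [0.0, 0.9, 0.4], [-0.2, 0.1, 0.0]]
bias = [0.0, 0.0, 0.0, 0.0]
hidden = [1.0, 2.0, 0.5]  # pretend this is the last hidden state of the decoder

print(greedy_next_token(hidden, weight, bias, vocab))  # sat
```

A real head projects hidden states of size 768 (for BERT-base) onto the whole vocabulary, but the operation is the same matrix-vector product.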
In this story, I want to show how an encoder-only model’s pre-trained weights can give us a head start for fine-tuning. In this example, we will train a summarization model with BERT as both the encoder and the decoder.
The Huggingface library introduced an API some time ago that makes it possible to mix and match different pre-trained models. It is highly flexible and makes our job super easy! But let’s look at the concept before jumping into the code: what needs to be done to make BERT (an encoder model) work in a seq2seq setting?

Keep in mind that the other elements of the networks shown in Fig. 2 are omitted for simplicity! To make a simple comparison, each block (layer) of the encoder-only model (left) consists of a self-attention layer followed by a linear layer, while each decoder layer of the encoder-decoder network (right) also has a cross-attention connection. The cross-attention layer enables the model to condition its predictions on the input.
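The difference between the two kinds of attention can be sketched in a few lines of pure Python. This is a simplified, single-head version that assumes no learned query/key/value projections, no residual connections, and no feed-forward layer. In decoder self-attention, queries, keys, and values all come from the decoder’s own tokens (with a causal mask, so a position cannot peek at later tokens); in cross-attention, the queries come from the decoder while the keys and values come from the encoder’s output.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries: Matrix, keys: Matrix, values: Matrix, causal: bool = False) -> Matrix:
    """Single-head scaled dot-product attention without learned projections."""
    d = len(keys[0])
    out: Matrix = []
    for i, q in enumerate(queries):
        # With a causal mask, position i may only look at positions 0..i.
        visible = range(i + 1) if causal else range(len(keys))
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d) for j in visible]
        weights = softmax(scores)
        out.append([sum(w * values[j][c] for w, j in zip(weights, visible))
                    for c in range(len(values[0]))])
    return out


def decoder_self_attention(x: Matrix) -> Matrix:
    # Queries, keys and values all come from the decoder's own tokens.
    return attention(x, x, x, causal=True)


def cross_attention(decoder_states: Matrix, encoder_states: Matrix) -> Matrix:
    # Queries come from the decoder; keys and values come from the encoder output,
    # which is how the decoder "looks at" the source text.
    return attention(decoder_states, encoder_states, encoder_states)


encoder_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 source tokens, hidden size 2
decoder_in = [[0.5, 0.5], [1.0, 0.0]]               # 2 summary tokens so far
h = decoder_self_attention(decoder_in)
h = cross_attention(h, encoder_out)
print(len(h), len(h[0]))  # 2 2
```

An encoder-only block such as BERT’s has only the first kind (and without the causal mask); the cross-attention call is the piece that has to be added to turn it into a decoder block.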
As is evident, BERT cannot be used directly as a decoder, because the building blocks are not the same! In theory, the fix is simple: add the extra cross-attention connections, switch the decoder’s self-attention to a causal (left-to-right) mask, and initialize the applicable parts of the decoder with BERT’s weights. Then, we fine-tune the model to train these new connections along with the language model head weights. (Note: the language model head sits between the last linear layer and the output; it is not included in Fig. 2.)
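The “initialize the applicable parts with BERT’s weights” idea can be illustrated with a small sketch. It assumes a hypothetical flat checkpoint format (parameter name mapped to a list of floats) and made-up parameter names; it is not how any library stores weights. Every decoder parameter whose name and size match the pre-trained checkpoint is copied over, and everything else (here, the cross-attention weights) gets a fresh random initialization that fine-tuning has to learn.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical flat "checkpoint": parameter name -> flattened weights.
Params = Dict[str, List[float]]


def warm_start_decoder(
    pretrained: Params, decoder_sizes: Dict[str, int], seed: int = 0
) -> Tuple[Params, List[str]]:
    """Build decoder parameters, reusing pretrained weights wherever name and size match.

    Returns the decoder parameters and the names that had to be freshly initialized.
    """
    rng = random.Random(seed)
    decoder: Params = {}
    fresh: List[str] = []
    for name, size in decoder_sizes.items():
        if name in pretrained and len(pretrained[name]) == size:
            decoder[name] = list(pretrained[name])  # copy the encoder checkpoint's weights
        else:
            # No pretrained counterpart (e.g. cross-attention): small random init.
            # std 0.02 mirrors BERT's default initializer_range.
            decoder[name] = [rng.gauss(0.0, 0.02) for _ in range(size)]
            fresh.append(name)
    return decoder, fresh


# Toy example with made-up names and numbers (not a real BERT state dict).
bert_weights = {
    "layer0.self_attention": [0.1, 0.2, 0.3],
    "layer0.feed_forward": [0.4, 0.5],
}
decoder_sizes = {
    "layer0.self_attention": 3,
    "layer0.cross_attention": 3,  # does not exist in an encoder-only checkpoint
    "layer0.feed_forward": 2,
}
decoder, fresh = warm_start_decoder(bert_weights, decoder_sizes)
print(fresh)  # ['layer0.cross_attention']
```

The design point is that only the genuinely new connections start from scratch; the rest of the decoder begins from weights that already “understand” language.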
We can use Huggingface’s EncoderDecoderModel class to mix and match different pre-trained models. Its .from_encoder_decoder_pretrained() method takes the names of the encoder and decoder checkpoints and takes care of adding the needed connections and weights. In the following example, we use BERT-base as both encoder and decoder.







