Revisiting Skip Connections in Transformers
This article is part of a series about skip connections and attention layers in Transformers. If you haven’t read the others, refer to the introductory article here.
I recently had a chat with one of my best friends, who happens to be a great machine learning scientist working at a very big company (Spire, Luxembourg). Our friendship goes way back: we did our Master’s at the same university, and I learned a lot from him. Long story short, we were talking about Transformers, and I realized that my friend believes that what makes Transformers good is not attention but rather the skip connections. He argued that the creators of the architecture wanted it to sound ‘fancy’, so of course they couldn’t just say it was the skip connections, which had been invented long before. Hence, it was the attention block that got all the interest and the spotlight. My friend believes that attention is useful but not really needed in Transformers, since it does the same job as the skip connections.
Bottom line: my friend is trying to prove this by running experiments on Transformers with attention and skip connections. If the results validate his hypothesis, we will publish them, most probably in a journal.
For now, I wanted to share this idea here and run some simple experiments of my own on the topic. In this first post, I will study the effect of skip connections on Transformer performance. The setup is simple: we will create two Transformers, one with skip connections and one without.
We will then train both and measure the effect of the skip connections, which I expect to be large.
Skip Connections In Transformers
The Transformer architecture, introduced by Vaswani et al. in 2017, has revolutionized the field of natural language processing (NLP) by achieving state-of-the-art results in various tasks, including machine translation, text summarization, and question answering. One of the key factors contributing to the Transformer’s success is its unique combination of skip connections and attention mechanisms.
Skip Connections: Preserving Information Flow
Skip connections, also known as shortcut or residual connections, are a common technique used in deep learning architectures to address the vanishing gradient problem. This problem occurs when the gradients of the loss function with respect to the weights of the network become extremely small, making it difficult for the network to learn effectively.
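The chain rule makes both the problem and the fix concrete. The notation here is generic and introduced for illustration; it does not come from the original article. Let x_{l+1} = F_l(x_l) be the output of layer l in a network with L layers and loss ℒ. In a plain stack, the gradient that reaches layer l is a product of per-layer Jacobians. If each Jacobian has norm at most γ < 1, that product shrinks geometrically with depth. A residual layer x_{l+1} = x_l + F_l(x_l) adds an identity term instead. This sketch ignores layer normalization, which the original Transformer applies after each addition.

```latex
% Plain stack: x_{l+1} = F_l(x_l). The chain rule multiplies L - l Jacobians,
% in reverse layer order (from k = L-1 down to k = l):
\frac{\partial \mathcal{L}}{\partial x_l}
  = \frac{\partial \mathcal{L}}{\partial x_L}\,
    \frac{\partial F_{L-1}}{\partial x_{L-1}} \cdots \frac{\partial F_l}{\partial x_l},
\qquad
\Big\lVert \frac{\partial F_k}{\partial x_k} \Big\rVert \le \gamma < 1 \;\; \forall k
\;\Longrightarrow\;
\Big\lVert \frac{\partial \mathcal{L}}{\partial x_l} \Big\rVert
  \le \gamma^{\,L-l} \Big\lVert \frac{\partial \mathcal{L}}{\partial x_L} \Big\rVert

% Residual stack: x_{l+1} = x_l + F_l(x_l), so x_L = x_l + \sum_{k=l}^{L-1} F_k(x_k) and
\frac{\partial \mathcal{L}}{\partial x_l}
  = \frac{\partial \mathcal{L}}{\partial x_L}
    \Big( I + \frac{\partial}{\partial x_l} \sum_{k=l}^{L-1} F_k(x_k) \Big)
```

The identity matrix I carries ∂ℒ/∂x_L straight back to layer l. As a result, the gradient does not have to decay geometrically with depth, as the plain product can.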
In the Transformer, every sub-layer (multi-head attention or feed-forward) is wrapped in a skip connection. The sub-layer’s input is added directly to its output, and the sum is then layer-normalized, giving LayerNorm(x + Sublayer(x)). This gives the gradient a direct path through the network that bypasses the sub-layers, which has two main benefits:
Preserves Information Flow: Skip connections ensure that the original input representation is not lost as the network becomes deeper. This allows the network to retain information about the input even after multiple transformations, which is crucial for tasks like machine translation where the meaning of the input sentence needs to be preserved throughout the processing pipeline.
Addresses Vanishing Gradients: The identity path carries the gradient straight past the intermediate layers, so the gradient reaching the early layers is no longer just a product of many small per-layer factors. It therefore does not shrink toward zero, and the vanishing gradient problem is largely mitigated.
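The vanishing-gradient point in the list above can be checked with a toy calculation. This sketch is an illustration under assumed settings and is not part of the experiment that follows. The helper `gradient_through_stack` is hypothetical. Each of its 20 layers is a scalar F(x) = tanh(w·x) with a small assumed weight w = 0.1, and the starting input 0.5 is arbitrary. The helper multiplies the per-layer derivatives via the chain rule, once for a plain stack and once with a skip connection around every layer.

```python
import math


def gradient_through_stack(num_layers, weight, residual, x=0.5):
    """Return d(output)/d(input) for a toy stack of scalar tanh layers.

    Hypothetical illustration: each layer computes F(x) = tanh(weight * x).
    With residual=True a layer outputs x + F(x) (a skip connection);
    otherwise it outputs F(x) alone.
    """
    grad = 1.0  # running chain-rule product of per-layer derivatives
    for _ in range(num_layers):
        # dF/dx, evaluated at this layer's input
        local = weight * (1.0 - math.tanh(weight * x) ** 2)
        if residual:
            local += 1.0  # the identity path contributes 1 to the derivative
            x = x + math.tanh(weight * x)
        else:
            x = math.tanh(weight * x)
        grad *= local
    return grad


plain = gradient_through_stack(num_layers=20, weight=0.1, residual=False)
resid = gradient_through_stack(num_layers=20, weight=0.1, residual=True)
print(f"without skip connections: {plain:.3e}")
print(f"with skip connections:    {resid:.3e}")
```

Every factor in the plain stack is below 0.1, so the product over 20 layers falls below 1e-20. Every factor in the residual stack exceeds 1, so that product stays above 1. A real Transformer layer is far more complex, but the identity term plays the same role.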

Experiment 1: The Effects of Skip Connections on Transformers
Transformer 1: With Skip Connections
First, let's build a Transformer with skip connections in TensorFlow. The code snippet below implements a Transformer-style model with skip connections and multi-head attention (8 heads). We will train it for sentiment classification on the IMDb dataset.
```python
import tensorflow as tf
from tensorflow.keras import layers, models
# Legacy Keras 2 text utilities (deprecated in recent TensorFlow/Keras releases)
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# Load the IMDb dataset (reviews arrive as lists of word indices)
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.imdb.load_data()

# Convert the indices back to words (load_data offsets word indices by 3 by default)
word_index = tf.keras.datasets.imdb.get_word_index()
reverse_word_index = {value: key for key, value in word_index.items()}
train_texts = [' '.join(reverse_word_index.get(i - 3, '?') for i in sequence) for sequence in train_data]
test_texts = [' '.join(reverse_word_index.get(i - 3, '?') for i in sequence) for sequence in test_data]

# Tokenize the text data, keeping only the 5,000 most frequent words
tokenizer = Tokenizer(num_words=5000)
tokenizer.fit_on_texts(train_texts)

# Convert texts to sequences and pad/truncate them to a fixed length of 100 tokens
train_sequences = pad_sequences(tokenizer.texts_to_sequences(train_texts), maxlen=100, padding='post')
test_sequences = pad_sequences(tokenizer.texts_to_sequences(test_texts), maxlen=100, padding='post')

# Split the training data into train and validation sets (80/20)
train_sequences, val_sequences, train_labels, val_labels = train_test_split(
    train_sequences, train_labels, test_size=0.2, random_state=42)


class TransformerWithSkipConnections(models.Model):
    """Two self-attention sub-layers, each wrapped in a skip (residual) connection.

    Deliberately minimal: no positional encoding, feed-forward sub-layers or layer
    normalization. `num_layers` is accepted but unused in this two-layer version.
    """

    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        # Despite their names, both layers are self-attention over the same sequence
        self.encoder = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
        self.decoder = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
        self.add_layer = layers.Add()
        self.flatten = layers.Flatten()
        self.output_layer = layers.Dense(1, activation='sigmoid')  # Binary classification output

    def call(self, inputs):
        x = self.embedding(inputs)  # (batch, 100, embedding_dim)
        # Skip connections: add each sub-layer's input to its output
        x = self.add_layer([x, self.encoder(query=x, value=x)])
        x = self.add_layer([x, self.decoder(query=x, value=x)])
        return self.output_layer(self.flatten(x))


# Create the Transformer model with skip connections
model_with_skip_connections = TransformerWithSkipConnections(vocab_size=5000, embedding_dim=128, num_heads=8, num_layers=3)

# Compile and train the model
model_with_skip_connections.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model_with_skip_connections.fit(train_sequences, train_labels, epochs=5, batch_size=16,
                                validation_data=(val_sequences, val_labels))
```
Transformer 2: Without Skip Connections
Everything else stays the same. Only the model architecture changes, with the skip connections removed:
```python
class TransformerNoSkipConnections(models.Model):
    """Same architecture as above, but each attention output replaces its input
    instead of being added to it (no skip connections)."""

    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        self.encoder = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
        self.decoder = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
        self.flatten = layers.Flatten()  # Same classification head as the model with skip connections
        self.output_layer = layers.Dense(1, activation='sigmoid')  # Binary classification output

    def call(self, inputs):
        x = self.embedding(inputs)
        # No skip connections: each sub-layer's output replaces its input
        x = self.encoder(query=x, value=x)
        x = self.decoder(query=x, value=x)
        return self.output_layer(self.flatten(x))


# Create the Transformer model without skip connections
model_no_skip_connections = TransformerNoSkipConnections(vocab_size=5000, embedding_dim=128, num_heads=8, num_layers=3)

# Compile and train the model
model_no_skip_connections.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model_no_skip_connections.fit(train_sequences, train_labels, epochs=5, batch_size=16,
                              validation_data=(val_sequences, val_labels))
```
Results
We ran the code above, training both models for 5 epochs. Here are the results:

Even after only 5 epochs, the Transformer with skip connections outperforms the one without.
This illustrates the large impact of skip connections, both in general and in Transformers in particular, at least in this small-scale setup.
Conclusion
Skip connections and attention mechanisms are both essential components of the Transformer architecture. Skip connections help to preserve information flow and address the vanishing gradient problem, while the attention mechanism allows the network to capture long-range dependencies and understand the context of the input text. These two techniques work in conjunction to enable the Transformer to achieve remarkable performance in various NLP tasks.
This concludes the first part of the story. In the next one, we will build Transformers with and without attention layers and compare the results.
If you like the article and would like to support me, make sure to: 📰 View more content on my Medium profile 🔔 Follow me: LinkedIn | Medium | GitHub | Facebook
📰 View more content on the AI-ContentLab Blog 🚀👉 Read more related articles on Medium and AI-ContentLab




