Abdulkader Helwan

Summary

This article discusses the significance of skip connections in the Transformer architecture, suggesting that they may be more important than attention mechanisms, and presents an experiment comparing the performance of Transformers with and without skip connections.

Abstract

The article begins with a conversation between the author and a friend, who believes that skip connections are the key factor behind the success of Transformers, rather than the attention mechanism. The friend plans to conduct experiments to prove this claim. The author then explains the purpose of skip connections in addressing the vanishing gradient problem and preserving information flow in deep learning architectures. The article proceeds to describe an experiment where two Transformers are built, one with skip connections and one without, and trained on the IMDb dataset for sentiment classification. The results show that the Transformer with skip connections outperforms the one without, demonstrating the impact of skip connections in Transformers. The article concludes by emphasizing the importance of both skip connections and attention mechanisms in the Transformer architecture.

Bullet points

  • The author's friend believes that skip connections, not attention, are the key factor behind the success of Transformers.
  • Skip connections are used to address the vanishing gradient problem and preserve information flow in deep learning architectures.
  • An experiment is conducted to compare the performance of Transformers with and without skip connections.
  • The Transformer with skip connections outperforms the one without, demonstrating the impact of skip connections.
  • Both skip connections and attention mechanisms are essential components of the Transformer architecture.

Revisiting Skip Connections in Transformers

This article is part of a series about skip connections and attention layers in Transformers. If you haven’t read the others, refer to the introductory article here.

I recently had a chat with one of my best friends, a great machine learning scientist who works at a very big company (Spire, Luxembourg). Our friendship goes way back: we did our master’s degrees at the same university, and I learned a lot from him. Long story short, we were talking about Transformers, and I realized that my friend believes that what makes Transformers good is not attention but the skip connections. He said the authors wanted the architecture to look ’fancy’, and they could not credit skip connections, which had been invented long before. Hence, the attention block got all the interest and the spotlight. My friend believes that attention is useful but not really needed in Transformers, because it does the same job as the skip connections.

In short, my friend is trying to prove this through experiments on the attention and skip connections in Transformers. If the results support his claim, we will most likely publish them in a journal.

For now, I wanted to share this idea here and run some simple experiments of my own on the topic. In this first post, I study the effect of skip connections on Transformer performance. The setup is simple: we create two Transformers, one with skip connections and one without.

We then train both and compare them to measure the effect of the skip connections, which I expect to be large.

Skip Connections In Transformers

The Transformer architecture, introduced by Vaswani et al. in 2017, has revolutionized the field of natural language processing (NLP) by achieving state-of-the-art results in various tasks, including machine translation, text summarization, and question answering. One of the key factors contributing to the Transformer’s success is its unique combination of skip connections and attention mechanisms.

Skip Connections: Preserving Information Flow

Skip connections, also known as shortcut or residual connections, are a common technique used in deep learning architectures to address the vanishing gradient problem. This problem occurs when the gradients of the loss function with respect to the weights of the network become extremely small, making it difficult for the network to learn effectively.
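To see the vanishing gradient problem, and what a skip connection does to it, in actual numbers, the sketch below uses a deliberately tiny toy model rather than a Transformer: a stack of scalar layers F(x) = w * tanh(x). The helper input_gradient and the choice of tanh are assumptions made for illustration only. It applies the chain rule layer by layer, once for plain layers (output F(x)) and once for residual layers (output x + F(x)).

```python
import math


def input_gradient(num_layers, weight, use_skip, x0=0.5):
    """Return d(output)/d(input) for a toy stack of scalar layers (hypothetical helper).

    Each layer computes F(x) = weight * tanh(x). With use_skip=True a layer
    outputs x + F(x) (a residual/skip connection); otherwise it outputs F(x).
    """
    x = x0
    grad = 1.0  # derivative of the current activation with respect to x0
    for _ in range(num_layers):
        local = weight * (1.0 - math.tanh(x) ** 2)  # dF/dx at the current x
        if use_skip:
            grad *= 1.0 + local  # d(x + F(x))/dx = 1 + dF/dx
            x = x + weight * math.tanh(x)
        else:
            grad *= local  # dF(x)/dx
            x = weight * math.tanh(x)
    return grad


for skip in (False, True):
    print(f"use_skip={skip}: input gradient after 20 layers = {input_gradient(20, 0.5, skip):.3e}")
```

With weight 0.5, every plain layer contributes a factor smaller than 0.5, so after 20 layers the gradient is below 0.5^20 (about 1e-6). Every residual layer contributes 1 + dF/dx, which is at least 1 here because dF/dx is positive, so the gradient does not vanish.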

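In the Transformer, each sub-layer (multi-head attention or the position-wise feed-forward network) is wrapped in a skip connection: the sub-layer’s input is added to its output, and the sum is then layer-normalized, giving LayerNorm(x + Sublayer(x)). This gives the gradient a direct path through the network that bypasses the sub-layers, which has several beneficial effects: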

Preserves Information Flow: Skip connections ensure that the original input representation is not lost as the network becomes deeper. This allows the network to retain information about the input even after multiple transformations, which is crucial for tasks like machine translation where the meaning of the input sentence needs to be preserved throughout the processing pipeline.

Addresses Vanishing Gradients: Because the identity path passes the upstream gradient straight back to each sub-layer’s input, the gradient that reaches early layers is not merely a product of many small per-layer derivatives. This mitigates the vanishing gradient problem that would otherwise hinder learning in deep stacks.
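The gradient argument can be written out for a single residual block y = x + F(x), where F is the sub-layer (attention or feed-forward) and L is the loss. This is the standard chain-rule derivation, shown without the layer normalization that the original Transformer applies after the addition:

```latex
% Residual block: y = x + F(x)
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)
  = \frac{\partial \mathcal{L}}{\partial y}
  + \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial F}{\partial x}

% Plain stack without skips: x_l = F_l(x_{l-1}), for l = 1, ..., N
\frac{\partial \mathcal{L}}{\partial x_0}
  = \frac{\partial \mathcal{L}}{\partial x_N}\,
    \frac{\partial F_N}{\partial x_{N-1}} \cdots \frac{\partial F_1}{\partial x_0}
```

In the residual case, the identity term carries the upstream gradient back unattenuated, however small the Jacobian of F is. In the plain stack, every Jacobian with a small norm shrinks the product further as N grows.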

Figure: Skip connections

Experiment 1: The Effects of Skip Connections on Transformers

Transformer 1: With Skip Connections

First, let's build a Transformer with skip connections in TensorFlow. The code below implements a small Transformer-style classifier that combines skip connections with multi-head attention (8 heads). We train it for sentiment classification on the IMDb dataset.

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# Load the IMDb dataset
(train_data, train_labels), (test_data, test_labels) = tf.keras.datasets.imdb.load_data()

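# Convert the indices back to words. load_data() shifts word indices by 3 (index_from=3)
# because 0, 1 and 2 are reserved for padding, the start token and out-of-vocabulary words.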
word_index = tf.keras.datasets.imdb.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
train_texts = [' '.join([reverse_word_index.get(i - 3, '?') for i in sequence]) for sequence in train_data]
test_texts = [' '.join([reverse_word_index.get(i - 3, '?') for i in sequence]) for sequence in test_data]

# Tokenize the text data
tokenizer = Tokenizer(num_words=5000)  # Considering the top 5000 words
tokenizer.fit_on_texts(train_texts)

# Convert texts to sequences and pad them to a fixed length
train_sequences = pad_sequences(tokenizer.texts_to_sequences(train_texts), maxlen=100, padding='post')
test_sequences = pad_sequences(tokenizer.texts_to_sequences(test_texts), maxlen=100, padding='post')

# Split the data into train and validation sets
train_sequences, val_sequences, train_labels, val_labels = train_test_split(train_sequences, train_labels, test_size=0.2, random_state=42)

class TransformerWithSkipConnections(models.Model):
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers):
        super(TransformerWithSkipConnections, self).__init__()

        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        # A stack of num_layers self-attention sub-layers, each followed by layer normalization
        self.attention_layers = [layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
                                 for _ in range(num_layers)]
        self.norm_layers = [layers.LayerNormalization() for _ in range(num_layers)]
        self.add_layer = layers.Add()
        self.flatten = layers.Flatten()
        self.output_layer = layers.Dense(1, activation='sigmoid')  # Binary classification output

    def call(self, inputs):
        x = self.embedding(inputs)  # (batch, seq_len, embedding_dim)
        for attention, norm in zip(self.attention_layers, self.norm_layers):
            attention_output = attention(query=x, value=x)  # self-attention, same shape as x
            # Skip connection: add the sub-layer's input to its output, then normalize
            x = norm(self.add_layer([x, attention_output]))

        flattened_output = self.flatten(x)
        output = self.output_layer(flattened_output)  # (batch, 1)

        return output

# Create the Transformer model with skip connections
model_with_skip_connections = TransformerWithSkipConnections(vocab_size=5000, embedding_dim=128, num_heads=8, num_layers=3)

# Compile and train the model
model_with_skip_connections.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model_with_skip_connections.fit(train_sequences, train_labels, epochs=5, batch_size=16, validation_data=(val_sequences, val_labels))

Transformer 2: Without Skip Connections

Here, everything remains the same except the model architecture, which is modified to remove the skip connections:

class TransformerNoSkipConnections(models.Model):
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers):
        super(TransformerNoSkipConnections, self).__init__()

        self.embedding = layers.Embedding(vocab_size, embedding_dim)
        # Same attention and normalization stack as the model with skip connections
        self.attention_layers = [layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)
                                 for _ in range(num_layers)]
        self.norm_layers = [layers.LayerNormalization() for _ in range(num_layers)]
        self.flatten = layers.Flatten()
        self.output_layer = layers.Dense(1, activation='sigmoid')  # Binary classification output

    def call(self, inputs):
        x = self.embedding(inputs)  # (batch, seq_len, embedding_dim)
        for attention, norm in zip(self.attention_layers, self.norm_layers):
            # No skip connection: the sub-layer's input is not added back
            x = norm(attention(query=x, value=x))

        # Flatten so the Dense layer emits one prediction per review, shape (batch, 1),
        # matching the labels (without it the output would be (batch, seq_len, 1))
        flattened_output = self.flatten(x)
        output = self.output_layer(flattened_output)

        return output

# Create the Transformer model without skip connections
model_no_skip_connections = TransformerNoSkipConnections(vocab_size=5000, embedding_dim=128, num_heads=8, num_layers=3)

# Compile and train the model
model_no_skip_connections.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model_no_skip_connections.fit(train_sequences, train_labels, epochs=5, batch_size=16, validation_data=(val_sequences, val_labels))

Results

We ran the code above, training both models for 5 epochs. Here are the results:

Figure: Accuracy and loss over 5 epochs for both Transformers

Even with only 5 training epochs, the Transformer with skip connections performs better than the one without.

This suggests that skip connections have a large impact, in deep networks in general and in Transformers in particular.
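To compare the two runs epoch by epoch in numbers rather than only visually, one option is to keep the History objects returned by fit(). In Keras, History.history is a dict mapping metric names (here 'loss', 'accuracy', 'val_loss' and 'val_accuracy') to per-epoch lists. The helper compare_histories below is a hypothetical sketch, not part of the original experiment.

```python
def compare_histories(history_a, history_b, metric="val_accuracy",
                      names=("with skip", "without skip")):
    """Print one metric per epoch from two Keras History.history dicts.

    Hypothetical helper: returns a list of (epoch, value_a, value_b) tuples.
    """
    values_a, values_b = history_a[metric], history_b[metric]
    print(f"{'epoch':>5}  {names[0]:>14}  {names[1]:>14}")
    rows = []
    for epoch, (a, b) in enumerate(zip(values_a, values_b), start=1):
        print(f"{epoch:>5}  {a:>14.4f}  {b:>14.4f}")
        rows.append((epoch, a, b))
    return rows
```

To use it, assign the return value of each training call, for example history_skip = model_with_skip_connections.fit(...) and history_no_skip = model_no_skip_connections.fit(...), then call compare_histories(history_skip.history, history_no_skip.history). Pass metric="val_loss" to compare validation loss instead.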

Conclusion

Skip connections and attention mechanisms are both essential components of the Transformer architecture. Skip connections help to preserve information flow and address the vanishing gradient problem, while the attention mechanism allows the network to capture long-range dependencies and understand the context of the input text. These two techniques work in conjunction to enable the Transformer to achieve remarkable performance in various NLP tasks.

This is the first part of this story. We will post another story in which we build Transformers with and without attention layers and compare the results.


Transformers
NLP
Skip Connection
Multihead Attention
TensorFlow