Zahra Ahmad

Summary

This article provides a step-by-step guide to implementing a Transformer model from scratch in PyTorch, covering data generation, model definition, and training with PyTorch Lightning.

Abstract

The article begins by discussing the significance of Transformer models in deep learning: they have superseded RNNs in NLP tasks and show promise of doing the same to CNNs in vision tasks. It then generates simple input and output data for training, where each output is a random sequence of integers and the corresponding input is the same sequence with each element repeated twice. The data is wrapped so that PyTorch's data-loading utilities can fetch it. The model class is defined using the Transformer architecture, and PyTorch Lightning is used for training. The article concludes by demonstrating the decoding process and providing references for further reading.

Bullet points

  • Transformer models are a game-changing innovation in deep learning: they have superseded RNNs in NLP tasks and show promise of doing the same to CNNs in vision tasks.
  • The article provides a simple, self-contained example of a Transformer model that uses both the encoder and decoder parts, with greedy decoding at inference time.
  • In the synthetic data, each output is a sequence of random integers and each input is the same sequence with every element repeated twice.
  • The data is wrapped so that PyTorch can fetch it in batches, and a PyTorch Lightning model is used for training.
  • The article demonstrates the decoding process and provides references for further reading.

Implementing a Transformer from Scratch in PyTorch


Transformers are a game-changing innovation in deep learning.

This model architecture has superseded all variants of RNNs in NLP tasks, and is showing promise of doing the same to CNNs in vision tasks.

However, the PyTorch Transformer docs make it a bit difficult to get started.

  • There is no explanation of how to do inference
  • The tutorial shows an encoder-only transformer

This notebook provides a simple, self-contained example of a Transformer:

  • using both the encoder and decoder parts
  • greedy decoding at inference time

We train on a simple synthetic task and use PyTorch Lightning for the training loop.

!pip install pytorch_lightning
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

Data

First, we generate simple input and output data.

Output: random number sequences like [1, 5, 3]

Input: same as output, but with each element repeated twice, e.g. [1, 1, 5, 5, 3, 3]
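A toy version of this transformation, using torch.repeat_interleave plus the start (0) and end (1) tokens that the data code below adds, might look like the following sketch. The sequence [7, 5, 3] is our own illustrative choice; it avoids the values 0 and 1, which the real data reserves for the special tokens.

```python
import torch

START, END = 0, 1  # special tokens, matching the article's data code

# Toy output sequence; the real data only uses values >= 2 so they never clash with START/END
y = torch.tensor([[7, 5, 3]])
x = torch.repeat_interleave(y, 2, dim=1)  # repeat each element twice along the sequence axis

# Wrap both sequences with the start and end tokens
start = torch.full((1, 1), START, dtype=torch.long)
end = torch.full((1, 1), END, dtype=torch.long)
x = torch.cat([start, x, end], dim=1)
y = torch.cat([start, y, end], dim=1)

print(x.tolist())  # [[0, 7, 7, 5, 5, 3, 3, 1]]
print(y.tolist())  # [[0, 7, 5, 3, 1]]
```

In general, an output of L values plus two special tokens gives an input of 2 * L + 2 tokens, which is why S = 32 targets pair with 62-token inputs.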

N = 10000
S = 32  # target sequence length, including start and end tokens; inputs have 2 * (S - 2) + 2 = 62 tokens
C = 128  # number of "classes", including 0, the "start token", and 1, the "end token"
Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2  # Only generate ints in the [2, C - 1] = [2, 127] range
# Make sure we only have unique rows (np.unique also sorts the rows lexicographically)
Y = torch.tensor(np.unique(Y.numpy(), axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)
# Add special 0 "start" and 1 "end" tokens to beginning and end
start = torch.zeros((N, 1), dtype=torch.long)
end = torch.ones((N, 1), dtype=torch.long)
Y = torch.cat([start, Y, end], dim=1)
X = torch.cat([start, X, end], dim=1)
# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

The output will be:

tensor([[  0,   2,   2,  ...,  48,  48,   1],
        [  0,   2,   2,  ..., 105, 105,   1],
        [  0,   2,   2,  ...,   6,   6,   1],
        ...,
        [  0,  14,  14,  ...,  47,  47,   1],
        [  0,  14,  14,  ..., 106, 106,   1],
        [  0,  14,  14,  ...,  85,  85,   1]]) torch.Size([10000, 62])
tensor([[  0,   2,   2,  ...,  72,  48,   1],
        [  0,   2,   2,  ...,  48, 105,   1],
        [  0,   2,   2,  ...,  65,   6,   1],
        ...,
        [  0,  14,  58,  ...,  60,  47,   1],
        [  0,  14,  59,  ...,  78, 106,   1],
        [  0,  14,  59,  ...,  29,  85,   1]]) torch.Size([10000, 32])
tensor(0) tensor(127)
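One detail of the generation code is worth knowing: np.unique(..., axis=0) returns the unique rows sorted lexicographically, so [:N] keeps the N lexicographically smallest sequences rather than a random subset. That is why the first value after the start token is small in every printed row; since the last printed rows begin with 14, no row in the dataset has a first value above 14. If a uniform first token matters, one option (our suggestion, not the article's) is to shuffle the unique rows before truncating, sketched here on toy sizes:

```python
import numpy as np
import torch

torch.manual_seed(0)
# Toy rows: 1000 sequences of length 3 with values in [0, 9]
rows = (torch.rand((1000, 3)) * 10).long()

# np.unique(..., axis=0) removes duplicate rows AND sorts them lexicographically
unique_rows = torch.tensor(np.unique(rows.numpy(), axis=0))

# Truncating the sorted result would keep only the "smallest" sequences,
# so shuffle first and then keep the desired number of rows (100 here)
perm = torch.randperm(len(unique_rows))
sample = unique_rows[perm][:100]

print(unique_rows[:3].tolist())  # the lexicographically smallest rows
print(sample[:3].tolist())  # a random selection of unique rows
```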

Next, we wrap the data so that PyTorch can fetch it in batches:

# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html
BATCH_SIZE = 128
TRAIN_FRAC = 0.8
# A list of (x, y) pairs supports len() and indexing, which is all a map-style
# torch.utils.data.Dataset needs
dataset = list(zip(X, Y))
# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))
# Reshuffle the training set every epoch; validation order does not matter
dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)
# Sample batch
x, y = next(iter(dataloader_train))
x, y

Output:

(tensor([[  0,  13,  13,  ..., 100, 100,   1],
         [  0,   6,   6,  ...,  20,  20,   1],
         [  0,   2,   2,  ...,  87,  87,   1],
         ...,
         [  0,   7,   7,  ..., 102, 102,   1],
         [  0,   2,   2,  ...,  19,  19,   1],
         [  0,   4,   4,  ..., 117, 117,   1]]),
 tensor([[  0,  13,  44,  ..., 125, 100,   1],
         [  0,   6, 106,  ...,  74,  20,   1],
         [  0,   2,  87,  ...,  81,  87,   1],
         ...,
         [  0,   7, 101,  ..., 111, 102,   1],
         [  0,   2, 111,  ...,  42,  19,   1],
         [  0,   4,  54,  ..., 120, 117,   1]]))
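To make the wrapping concrete, here is a scaled-down sketch on toy tensors (the sizes and the fixed generator seed are our own choices, not the article's). A plain list of (x, y) pairs works as a map-style dataset, random_split accepts an optional torch.Generator for a reproducible split, and the DataLoader's default collate function stacks each element of the tuples into one batch tensor.

```python
import torch
from torch.utils.data import DataLoader, random_split

X = torch.arange(20).reshape(10, 2)  # 10 toy "input" rows: row i is [2i, 2i + 1]
Y = torch.arange(10).reshape(10, 1)  # 10 toy "target" rows: row i is [i]
dataset = list(zip(X, Y))  # len() and [] indexing are all a map-style Dataset needs

gen = torch.Generator().manual_seed(0)  # fixed seed -> reproducible split
train, val = random_split(dataset, [8, 2], generator=gen)

loader = DataLoader(train, batch_size=4, shuffle=True)
xb, yb = next(iter(loader))  # default collate stacks the x's and the y's separately
print(xb.shape, yb.shape)  # torch.Size([4, 2]) torch.Size([4, 1])
```

Because each dataset item keeps its x and y together, shuffling and splitting never break the pairing between inputs and targets.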

Model

Now we define the model class:
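A minimal sketch of such a class follows. It is our own hypothetical version, not the article's code: it wraps torch's nn.Transformer with batch_first=True, adds sinusoidal positional encodings, applies a causal mask to the decoder input, returns logits of shape (B, C, Sy), and decodes greedily in predict. The names Seq2SeqTransformer and causal_mask and the hyperparameters (dim=128, 4 heads, 4 layers) are assumptions chosen so that the logits have the same layout as the shapes printed below.

```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to (B, S, E) embeddings (assumes an even dim)."""

    def __init__(self, dim: int, max_len: int = 512):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[: x.size(1)]  # broadcasts over the batch dimension


def causal_mask(size: int, device=None) -> torch.Tensor:
    """Float mask with -inf above the diagonal: position i may attend only to positions <= i."""
    return torch.triu(torch.full((size, size), float("-inf"), device=device), diagonal=1)


class Seq2SeqTransformer(nn.Module):
    """Hypothetical encoder-decoder model; not the article's original class."""

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 128,
                 nhead: int = 4, num_layers: int = 4):
        super().__init__()
        self.max_output_length = max_output_length
        self.embedding = nn.Embedding(num_classes, dim)  # shared by encoder and decoder inputs
        self.pos = PositionalEncoding(dim)
        self.transformer = nn.Transformer(
            d_model=dim, nhead=nhead, num_encoder_layers=num_layers,
            num_decoder_layers=num_layers, dim_feedforward=dim, batch_first=True,
        )
        self.fc = nn.Linear(dim, num_classes)

    def embed(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.pos(self.embedding(tokens))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """x: (B, Sx) input tokens, y: (B, Sy) decoder input tokens -> (B, C, Sy) logits."""
        mask = causal_mask(y.size(1), device=y.device)
        out = self.transformer(self.embed(x), self.embed(y), tgt_mask=mask)  # (B, Sy, E)
        return self.fc(out).permute(0, 2, 1)  # class dimension second, as F.cross_entropy expects

    @torch.no_grad()
    def predict(self, x: torch.Tensor, start_token: int = 0) -> torch.Tensor:
        """Greedy decoding: x (B, Sx) -> (B, max_output_length) tokens. Call .eval() first."""
        memory = self.transformer.encoder(self.embed(x))  # encode the input once
        y = torch.full((x.size(0), 1), start_token, dtype=torch.long, device=x.device)
        for _ in range(self.max_output_length - 1):
            mask = causal_mask(y.size(1), device=x.device)
            out = self.transformer.decoder(self.embed(y), memory, tgt_mask=mask)
            next_token = self.fc(out[:, -1]).argmax(dim=-1, keepdim=True)  # most likely next token
            y = torch.cat([y, next_token], dim=1)
        return y


# Toy usage with the article's sizes: 62-token inputs, 32-token targets, 128 classes
model = Seq2SeqTransformer(num_classes=128, max_output_length=32)
x = torch.randint(0, 128, (4, 62))
y = torch.randint(0, 128, (4, 32))
logits = model(x, y[:, :-1])  # teacher forcing: the decoder sees y without its last token
print(logits.shape)  # torch.Size([4, 128, 31])
model.eval()
print(model.predict(x[:1]).shape)  # torch.Size([1, 32])
```

During training the decoder always sees the true previous tokens (teacher forcing); predict instead feeds back its own argmax at every step, which is what makes it greedy decoding.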

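Output (the shapes of the sample batch x, the targets y, and the model's logits, followed by the first input sequence and the untrained model's greedy prediction for it):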

torch.Size([128, 62]) torch.Size([128, 32]) torch.Size([128, 128, 31])
tensor([[  0,  13,  13,  44,  44,  13,  13,  61,  61, 104, 104,   2,   2,  62,
          62,  97,  97,  98,  98,  89,  89,  97,  97,  62,  62, 119, 119,  56,
          56, 119, 119,  69,  69,  12,  12,  58,  58,  96,  96, 121, 121,  21,
          21, 109, 109,   3,   3,  73,  73,  65,  65,   2,   2,  69,  69,  84,
          84, 125, 125, 100, 100,   1]])
tensor([[  0,  75, 114, 106, 114, 114, 114, 108,  14, 114, 108,  14,  75,   0,
         106, 114, 114, 108, 114, 104, 114, 114, 114, 114, 114, 114, 114, 108,
          75,  75,   0,  11]])
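The logits shape printed above, torch.Size([128, 128, 31]), puts the 128 classes in dimension 1 and the 31 predicted positions last; 31 is one less than the 32 target tokens, presumably because the decoder is fed y[:, :-1]. This (B, C, T) layout is the one torch.nn.functional.cross_entropy accepts for sequence targets. A small sketch with toy sizes (our own choice) shows that computing the loss directly on (B, C, T) logits is equivalent to flattening every position into its own row:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T = 2, 5, 3  # toy batch size, number of classes, target length
logits = torch.randn(B, C, T)  # (B, C, T), the same layout as the printed logits shape
target = torch.randint(0, C, (B, T))  # (B, T) class indices

# F.cross_entropy accepts the class dimension in position 1 for sequence data...
loss_bct = F.cross_entropy(logits, target)
# ...which matches flattening every (batch, position) pair into one row of C scores
loss_flat = F.cross_entropy(logits.permute(0, 2, 1).reshape(-1, C), target.reshape(-1))
print(torch.allclose(loss_bct, loss_flat))  # True
```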

Now we define a PyTorch Lightning model to train our Transformer:

# Check greedy decoding on one validation example
x, y = next(iter(dataloader_val))
print('Input:', x[:1])
lit_model.eval()  # disable dropout for inference
pred = lit_model.model.predict(x[:1])
print('Truth/Pred:')
print(torch.cat((y[:1], pred)))

Output (in Truth/Pred, the first row is the ground truth and the second is the greedy prediction):

Input: tensor([[  0,   2,   2,  52,  52,  51,  51,  20,  20, 122, 122,  39,  39,  12,
          12,  11,  11,  41,  41,  23,  23,  30,  30,  13,  13,  52,  52, 106,
         106,  38,  38,  46,  46,  78,  78,  64,  64, 107, 107,  90,  90,  60,
          60,  55,  55,  61,  61,   8,   8,  59,  59,  67,  67,  83,  83,  44,
          44,  81,  81,  82,  82,   1]])
Truth/Pred:
tensor([[  0,   2,  52,  51,  20, 122,  39,  12,  11,  41,  23,  30,  13,  52,
         106,  38,  46,  78,  64, 107,  90,  60,  55,  61,   8,  59,  67,  83,
          44,  81,  82,   1],
        [  0,   2,  52,  51,  20, 122,  39,  12,  11,  41,  30,  13,  30,  52,
         106,  38,  46,  78,  64, 107,  90,  60,  55,  61,   8,  59,  67,  83,
          44,  81,  82,   1]])
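Token-level accuracy and exact-match accuracy are two simple ways to score a prediction like this one. A small sketch, applying them to the two rows printed above (the choice of metrics is our own, not the article's):

```python
import torch

# Ground truth and greedy prediction, copied from the printed output above
truth = torch.tensor([0, 2, 52, 51, 20, 122, 39, 12, 11, 41, 23, 30, 13, 52, 106, 38,
                      46, 78, 64, 107, 90, 60, 55, 61, 8, 59, 67, 83, 44, 81, 82, 1])
pred = torch.tensor([0, 2, 52, 51, 20, 122, 39, 12, 11, 41, 30, 13, 30, 52, 106, 38,
                     46, 78, 64, 107, 90, 60, 55, 61, 8, 59, 67, 83, 44, 81, 82, 1])

matches = truth == pred
token_accuracy = matches.float().mean().item()  # fraction of positions that agree
exact_match = bool(matches.all())  # True only if the whole sequence is right
mismatch_positions = torch.nonzero(~matches).flatten().tolist()

print(token_accuracy)  # 0.90625 (29 of 32 positions)
print(exact_match)  # False
print(mismatch_positions)  # [10, 11, 12]
```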

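The prediction agrees with the ground truth everywhere except positions 10 to 12 (counting the start token as position 0), where the model outputs 30, 13, 30 instead of 23, 30, 13.

That is it :)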

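For completeness, training is started with the following call, which must run before the prediction check above; here trainer is a pl.Trainer instance and lit_model is the Lightning model:

trainer.fit(lit_model, dataloader_train, dataloader_val)

Reference here and Google Colab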
