Implementing a Transformer from Scratch in PyTorch
Transformers are a game-changing innovation in deep learning.
This architecture has largely superseded RNN variants in NLP tasks and is showing promise to do the same to CNNs in vision tasks.
However, the PyTorch Transformer docs make it a bit difficult to get started.
- There is no explanation of how to do inference
- The tutorial shows an encoder-only transformer
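For contrast with an encoder-only setup, PyTorch's built-in `nn.Transformer` already bundles both an encoder stack and a decoder stack. Below is a minimal forward pass through both. The sizes are arbitrary, and random tensors stand in for already-embedded tokens; these are assumptions for illustration, not values from this notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small sizes, chosen only for illustration.
d_model, nhead = 32, 4
batch, src_len, tgt_len = 2, 10, 6

# Both halves are present: 2 encoder layers and 2 decoder layers.
model = nn.Transformer(
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=64,
    batch_first=True,  # tensors are laid out as (batch, seq, feature)
)

src = torch.randn(batch, src_len, d_model)  # stand-in for embedded source tokens
tgt = torch.randn(batch, tgt_len, d_model)  # stand-in for embedded (shifted) target tokens

# Causal mask for decoder self-attention: position i may only attend to positions <= i.
# Entries above the diagonal are -inf, everything else is 0.
tgt_mask = model.generate_square_subsequent_mask(tgt_len)

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 6, 32]): one d_model vector per target position
```

The encoder consumes the whole source sequence at once. The decoder attends both to its own (masked) prefix and to the encoder output, which is what makes sequence-to-sequence generation possible.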
This notebook provides a simple, self-contained example of a Transformer:
- using both the encoder and decoder parts
- greedy decoding at inference time
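Greedy decoding generates the output one token at a time:

1. Start from the start token.
2. Run the decoder on the prefix generated so far.
3. Append the highest-scoring next token.
4. Stop at the end token or at a length limit.

The sketch below illustrates this loop. It assumes a hypothetical `TinySeq2Seq` model (embeddings, `nn.Transformer` and a linear output layer, with positional encodings omitted). It also assumes the 0 = start / 1 = end token convention used in the data section of this notebook. The model is untrained, so the generated tokens are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinySeq2Seq(nn.Module):
    """Hypothetical toy model: embeddings -> nn.Transformer -> per-position class logits.
    Positional encodings are omitted to keep the sketch short."""

    def __init__(self, num_classes=128, d_model=32):
        super().__init__()
        self.embed = nn.Embedding(num_classes, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=4, num_encoder_layers=1, num_decoder_layers=1,
            dim_feedforward=64, dropout=0.0, batch_first=True,
        )
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, src, tgt):
        # src: (B, S_src) and tgt: (B, S_tgt) integer token ids
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1])
        out = self.transformer(self.embed(src), self.embed(tgt), tgt_mask=tgt_mask)
        return self.fc(out)  # (B, S_tgt, num_classes)


@torch.no_grad()
def greedy_decode(model, src, max_len, start_token=0, end_token=1):
    """Generate up to max_len tokens, always taking the argmax at the last position."""
    model.eval()
    ys = torch.full((src.shape[0], 1), start_token, dtype=torch.long)
    finished = torch.zeros(src.shape[0], dtype=torch.bool)
    for _ in range(max_len - 1):
        logits = model(src, ys)                      # re-run on the whole prefix
        next_token = logits[:, -1].argmax(dim=-1)    # (B,) best class at the last position
        ys = torch.cat([ys, next_token.unsqueeze(1)], dim=1)
        finished |= next_token == end_token
        if finished.all():                           # every sequence has emitted "end"
            break
    return ys


model = TinySeq2Seq()
src = torch.tensor([[0, 2, 2, 5, 5, 3, 3, 1]])
print(greedy_decode(model, src, max_len=5))  # untrained, so the tokens are arbitrary
```

The sketch re-encodes `src` at every step to stay short. A more efficient version would typically run `model.transformer.encoder` once and reuse its output as the decoder's memory.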
We train on a simple synthetic example and use PyTorch Lightning for the training loop.
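Lightning handles the loop mechanics. The loss being minimized for an encoder-decoder model like this one is commonly a teacher-forced cross-entropy: the decoder receives the target without its last token, `y[:, :-1]`, and is scored on predicting the target without its first token, `y[:, 1:]`. The sketch below assumes that setup rather than reproducing this notebook's training step. Random logits stand in for a real model's output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

C = 128  # number of token classes (the data section uses the same value)
# A hypothetical batch of 2 target sequences: start token 0, content tokens, end token 1.
y = torch.tensor([[0, 2, 5, 3, 1],
                  [0, 7, 7, 4, 1]])

# Teacher forcing: feed everything except the last token, predict everything except the first.
decoder_input = y[:, :-1]  # (B, S - 1), begins with the start token
targets = y[:, 1:]         # (B, S - 1), ends with the end token

# Stand-in for model(x, decoder_input): one row of class logits per decoder position.
logits = torch.randn(y.shape[0], decoder_input.shape[1], C)  # (B, S - 1, C)

# F.cross_entropy expects the class dimension second: (B, C, S - 1) against targets (B, S - 1).
loss = F.cross_entropy(logits.permute(0, 2, 1), targets)
print(decoder_input.tolist())  # [[0, 2, 5, 3], [0, 7, 7, 4]]
print(targets.tolist())        # [[2, 5, 3, 1], [7, 7, 4, 1]]
print(loss.shape)              # torch.Size([])
```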
!pip install pytorch_lightning
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

Data
First, we generate simple input and output data.
Output: random integer sequences like [2, 5, 3] (0 and 1 are reserved for the start and end tokens)
Input: same as output, but with each element repeated twice, e.g. [2, 2, 5, 5, 3, 3]
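The same construction on a single hand-written sequence makes the input/output relationship concrete. This is a toy illustration with arbitrary values. Like the data code in this section, it uses `torch.repeat_interleave` for the doubling and tokens 0 and 1 as the start and end markers.

```python
import torch

# Toy target: three content tokens (all >= 2, since 0 and 1 are reserved).
y = torch.tensor([[2, 5, 3]])

# Input: each element repeated twice along the sequence dimension.
x = torch.repeat_interleave(y, 2, dim=1)
print(x.tolist())  # [[2, 2, 5, 5, 3, 3]]

# Wrap both sequences in the start (0) and end (1) tokens.
start = torch.zeros((1, 1), dtype=torch.long)
end = torch.ones((1, 1), dtype=torch.long)
y_full = torch.cat([start, y, end], dim=1)
x_full = torch.cat([start, x, end], dim=1)
print(y_full.tolist())  # [[0, 2, 5, 3, 1]]
print(x_full.tolist())  # [[0, 2, 2, 5, 5, 3, 3, 1]]
```

With this toy input, a target of length S yields an input of length 2(S - 2) + 2, which is why S = 32 below gives inputs of length 62.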
N = 10000
S = 32  # target sequence length; the input is 2 * (S - 2) + 2 = 62 tokens long
C = 128  # number of "classes", including 0, the "start token", and 1, the "end token"

# Oversample (N * 10 rows) so that at least N unique rows survive de-duplication
Y = (torch.rand((N * 10, S - 2)) * (C - 2)).long() + 2  # Only generate ints in the [2, C - 1] = [2, 127] range
# Make sure we only have unique rows (np.unique also sorts them lexicographically)
Y = torch.tensor(np.unique(Y, axis=0)[:N])
X = torch.repeat_interleave(Y, 2, dim=1)

# Add special 0 "start" and 1 "end" tokens to beginning and end
Y = torch.cat([torch.zeros((N, 1), dtype=torch.long), Y, torch.ones((N, 1), dtype=torch.long)], dim=1)
X = torch.cat([torch.zeros((N, 1), dtype=torch.long), X, torch.ones((N, 1), dtype=torch.long)], dim=1)

# Look at the data
print(X, X.shape)
print(Y, Y.shape)
print(Y.min(), Y.max())

The output will be:
tensor([[ 0, 2, 2, ..., 48, 48, 1],
[ 0, 2, 2, ..., 105, 105, 1],
[ 0, 2, 2, ..., 6, 6, 1],
...,
[ 0, 14, 14, ..., 47, 47, 1],
[ 0, 14, 14, ..., 106, 106, 1],
[ 0, 14, 14, ..., 85, 85, 1]]) torch.Size([10000, 62])
tensor([[ 0, 2, 2, ..., 72, 48, 1],
[ 0, 2, 2, ..., 48, 105, 1],
[ 0, 2, 2, ..., 65, 6, 1],
...,
[ 0, 14, 58, ..., 60, 47, 1],
[ 0, 14, 59, ..., 78, 106, 1],
[ 0, 14, 59, ..., 29, 85, 1]]) torch.Size([10000, 32])
tensor(0) tensor(127)

Let’s move on:
# Wrap data in the simplest possible way to enable PyTorch data fetching
# https://pytorch.org/docs/stable/data.html
BATCH_SIZE = 128
TRAIN_FRAC = 0.8

dataset = list(zip(X, Y))  # A list of (x, y) pairs fulfills the torch.utils.data.Dataset interface

# Split into train and val
num_train = int(N * TRAIN_FRAC)
num_val = N - num_train
data_train, data_val = torch.utils.data.random_split(dataset, (num_train, num_val))

dataloader_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=True)  # reshuffle every epoch
dataloader_val = torch.utils.data.DataLoader(data_val, batch_size=BATCH_SIZE)

# Sample batch
x, y = next(iter(dataloader_train))
x, y

Output:
(tensor([[ 0, 13, 13, ..., 100, 100, 1],
[ 0, 6, 6, ..., 20, 20, 1],
[ 0, 2, 2, ..., 87, 87, 1],
...,
[ 0, 7, 7, ..., 102, 102, 1],
[ 0, 2, 2, ..., 19, 19, 1],
[ 0, 4, 4, ..., 117, 117, 1]]),
tensor([[ 0, 13, 44, ..., 125, 100, 1],
[ 0, 6, 106, ..., 74, 20, 1],
[ 0, 2, 87, ..., 81, 87, 1],
...,
[ 0, 7, 101, ..., 111, 102, 1],
[ 0, 2, 111, ..., 42, 19, 1],
[ 0, 4, 54, ..., 120, 117, 1]]))

Model
Now we define the model class:




