Transforming Next-Token Prediction into Classification with LLMs
From tokens to labels: Performing classification with large language models
Large Language Models (LLMs), trained on vast amounts of internet data, are versatile and can perform a wide range of natural language tasks. One common application is classification, a supervised learning task that assigns inputs to pre-defined labels. Zero-shot and few-shot classification have become popular techniques, enabling LLMs to perform classification tasks with no training data or only a handful of examples. For better accuracy, however, instruction fine-tuning has been shown to improve performance by tuning LLMs on curated datasets.
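For instance, zero-shot classification can be as simple as handing a model the text and the candidate labels. Here is a minimal sketch using the zero-shot pipeline from transformers (the model choice, facebook/bart-large-mnli, is illustrative and not part of this post's experiments):

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
result = classifier(
    "The battery drains within two hours of normal use.",
    candidate_labels=["positive", "negative"],
)
print(result["labels"][0])  # Highest-scoring label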
Instruction Fine-Tuning LLMs
The common practice for instruction fine-tuning involves constructing a dataset consisting of question-and-answer pairs. Pre-trained LLMs are then further fine-tuned using these pairs in a supervised manner.
You can check my previous post for this approach.
Optionally, one can further improve performance with Direct Preference Optimization (DPO), where the dataset is organized into pairs of preferred and less desirable responses. It’s no surprise that the top-ranked open-source LLMs are trained using one of these approaches.
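To make the preference structure concrete, a single DPO record is typically a prompt paired with a preferred and a rejected response. The field names below follow a common convention and are illustrative; exact schemas vary by training library:

preference_example = {
    "prompt": "Summarize the article in one sentence.",
    "chosen": "A concise, faithful one-sentence summary.",  # Preferred response
    "rejected": "A rambling reply that misses the point.",  # Less desirable response
}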
LLMs are Deep Neural Networks
Intuitively, these approaches exploit the fact that LLMs operate on a tokens-in, tokens-out architecture. In both instruction and preference datasets, text pairs are converted into tokens. The weights of the LLM are then updated using a cross-entropy loss and the auto-regressive nature of decoder-only transformers, where the label sequence is simply the input sequence shifted by one position.
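A minimal sketch of this shift-by-one objective (dummy tensors stand in for real model outputs):

import torch

batch, seq_len, vocab = 2, 8, 100
logits = torch.randn(batch, seq_len, vocab)            # Model outputs (dummy here)
input_ids = torch.randint(0, vocab, (batch, seq_len))  # Tokenized input

shift_logits = logits[:, :-1, :]  # Predictions at positions 0..n-2
shift_labels = input_ids[:, 1:]   # Targets: the inputs shifted left by one

loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, vocab), shift_labels.reshape(-1)
)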
Fundamentally, LLMs are deep neural networks, and it’s possible to modify their architecture for classification tasks. This may seem daunting due to the complexity of LLMs, but let’s break it down from first principles. LLMs perform next-token prediction based on the preceding context tokens. The generated token is one of the tokens in the model’s vocabulary, selected as the highest-probability token from the logits produced by the model’s output head.
Given this, can we replace the vocabulary with the classification labels? Yes we can! Instead of predicting the next token from the vocabulary, we are interested in predicting the category of a classification task given the preceding context tokens. This is feasible by changing the LLM’s head, commonly implemented as lm_head. In text generation, the lm_head has a shape of (embedding dimension, vocabulary size). For classification, we modify it to (embedding dimension, number of classes).
Train LLMs for Classification Tasks
Load model and tokenizer
Let’s conduct an experiment to verify this approach works. First, we will start by using a pre-trained LLM and its tokenizer from HuggingFace. In this study, I choose a lightweight 3.8-billion-parameter model, microsoft/Phi-3-mini-4k-instruct.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "microsoft/Phi-3-mini-4k-instruct"
device_map = "auto"  # Let accelerate place layers across available GPUs
trust_remote_code = True

model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=device_map, trust_remote_code=trust_remote_code
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token  # Use EOS as the padding token
tokenizer.add_eos_token = True
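Before making any changes, it is worth inspecting the original head. The exact output depends on the model, but for Phi-3-mini it should show a linear layer mapping the 3072-dimensional hidden state to the vocabulary size:

print(model.lm_head)
# Something along the lines of:
# Linear(in_features=3072, out_features=32064, bias=False)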
To simplify our classification experiment, we’ll choose a binary classification task, knowing that we can extend it to multiple classes later. For binary classification, the number of classes is two, so the new head is torch.nn.Linear(hidden_size, 2). The to("cuda:<number>") call specifies which GPU this layer is placed on. Please ensure that this layer is assigned to the same device where the model's final layers landed when it was loaded with device_map = "auto".
Modify LLMs Head
import torch

# Replace the language-modeling head with a two-class classification head
hidden_size = 3072  # Hidden dimension of Phi-3-mini
model.lm_head = torch.nn.Linear(hidden_size, 2).to("cuda:3")
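Since device_map="auto" may shard the model across GPUs, a quick check (my addition, not from the original post) confirms the new head landed next to the final layers:

print(model.hf_device_map)                      # Module-to-device placement from accelerate
print(next(model.lm_head.parameters()).device)  # Should match the last decoder layers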
Another interesting question is which layers to fine-tune. In this post, I will focus on the final normalization layer that follows the last decoder block, based on the assumption that the model has already learned the contextual meaning of the tokens in all the other layers. This layer sits immediately before the lm_head.
First, I freeze all the weights in the model by setting param.requires_grad = False. Second, I locate the final normalization layer and make its weights tunable by setting param.requires_grad = True; the newly initialized classification head must stay trainable as well, since it starts from random weights. The model class in HuggingFace allows one to navigate to any layer using dot notation, as demonstrated in the code snippet below.
# Freeze all model weights
for param in model.parameters():
    param.requires_grad = False

# Fine-tune only the final normalization layer and the new head
last_block = model.model.layers[-1].to("cuda:3")
final_norm = model.model.norm.to("cuda:3")
for param in final_norm.parameters():
    param.requires_grad = True
for param in model.lm_head.parameters():
    param.requires_grad = True  # The freshly initialized head must be trainable
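A quick sanity check, not in the original snippet, is to count trainable parameters and confirm that only the final norm and the new head remain tunable:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,}")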
Define Cross Entropy Loss
In my previous post, I described that the most important logits of the lm_head are associated with the last token of a sentence. This is due to the self-attention mechanism, where the last token has attention scores over all preceding context tokens. Hence, we use logits[:, -1, :].
import torch
def calculate_loss_batch(input_batch, target_batch, model):
    input_batch, target_batch = input_batch.to("cuda"), target_batch.to("cuda")
    logits = model(input_batch).logits[:, -1, :]  # Logits of the last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
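As a quick sanity check on the loss scale (my addition): with two classes, uniform predictions yield a cross-entropy of ln(2) ≈ 0.693, a useful reference point when reading training losses.

dummy_logits = torch.zeros(4, 2)            # Uniform (uninformative) predictions
dummy_targets = torch.tensor([0, 1, 0, 1])
print(torch.nn.functional.cross_entropy(dummy_logits, dummy_targets))  # ~0.6931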
Prepare Dataset
The dataset consists of two lists: texts and labels. First, I use the tokenizer to convert the strings in texts into token IDs. The tokenizer truncates sequences that are longer than the defined max length and pads those that are shorter.
encoding = tokenizer(
    text,
    add_special_tokens=True,
    max_length=self.max_length,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)
Second, I use the DataLoader object from torch to iteratively feed the data for model tuning. The BinaryClassificationDataset object inherits from torch’s Dataset object, and we load this dataset into a DataLoader object.
class BinaryClassificationDataset(Dataset):
    pass

dataset = BinaryClassificationDataset(texts, labels, tokenizer, max_length)

# DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
Putting steps 1 and 2 together, here is the complete implementation. I provide an example in the main function to showcase how to instantiate a BinaryClassificationDataset object, create a DataLoader object, and iterate over it to generate data.
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

class BinaryClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze(0)  # Remove batch dimension
        attention_mask = encoding['attention_mask'].squeeze(0)  # Remove batch dimension
        return {
            'input_ids': input_ids.long(),
            'attention_mask': attention_mask.long(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Example usage
if __name__ == "__main__":
    # Sample data
    texts = ["Hello, this is a sample sentence.", "Another sample text for classification."]
    labels = [0, 1]

    # Load tokenizer
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Define dataset
    max_length = 128
    dataset = BinaryClassificationDataset(texts, labels, tokenizer, max_length)

    # DataLoader
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    # Iterate through the dataloader
    for batch in dataloader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        print("Input IDs:", input_ids)
        print("Attention Mask:", attention_mask)
        print("Labels:", labels)
Train the Model
After defining the loss function and preparing the dataset, we can start training the model. calculate_loss_loader initializes losses to 0. Given a specified number of batches, num_batches, we compute the loss for each batch using the calculate_loss_batch function. The batched data is provided iteratively by the DataLoader object to the calculate_loss_batch function, and the loss is accumulated at each iteration. For reproducibility, I use torch.manual_seed(1234). In this experiment, I achieved a loss of 0.68.
def calculate_loss_loader(data_loader, model, num_batches=None):
    losses = 0.
    if len(data_loader) == 0:
        return "Please provide dataset, data loader is empty."
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for index, batch in enumerate(data_loader):
        if index < num_batches:
            loss = calculate_loss_batch(
                batch["input_ids"], batch["labels"], model
            )
            losses += loss.item()
        else:
            break
    return losses / num_batches  # Average loss over the evaluated batches
torch.manual_seed(1234)  # For reproducibility

with torch.no_grad():
    train_loss = calculate_loss_loader(dataloader, model, num_batches=2)
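The snippet above only measures the loss. A minimal sketch of the actual update loop, assuming AdamW over the trainable parameters (the optimizer choice and hyperparameters here are my illustrative assumptions, not from the original experiment):

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)
model.train()
for epoch in range(3):  # Number of epochs is illustrative
    for batch in dataloader:
        optimizer.zero_grad()
        loss = calculate_loss_batch(batch["input_ids"], batch["labels"], model)
        loss.backward()
        optimizer.step()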
Trade-offs
It is thrilling that we can repurpose transformers from next-token prediction to classification by reasoning from first principles. On the plus side, the pre-trained transformer model captures rich representations of language patterns. However, though lightweight, the model still has 3.8 billion parameters. Readers should be aware of the trade-offs between accuracy and training and inference throughput (e.g., examples processed per second). I suggest starting with a smaller model as a baseline. For example, you can use the distilbert-base-uncased model, which has 67 million parameters. Alternatively, you can try more traditional machine learning models, such as XGBoost.
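A minimal sketch of that smaller baseline, using the off-the-shelf sequence-classification head from transformers (the training loop is omitted here):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

baseline_name = "distilbert-base-uncased"
baseline_tokenizer = AutoTokenizer.from_pretrained(baseline_name)
baseline_model = AutoModelForSequenceClassification.from_pretrained(
    baseline_name, num_labels=2  # Two classes, matching the experiment above
)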
Conclusion
In this blog post, I demonstrated how, using first principles, we can transform the problem of next-token prediction into classification label prediction. I hope this demo equips you with the knowledge to unpack the complex architecture of LLMs and apply the same concept to your domain-specific problems. It’s important to note that transformers may not be the optimal solution for every classification task, and you should train and evaluate your data against smaller models as well. Thanks for reading, and happy learning!