Skanda Vivek

Summary

The article provides a tutorial on creating a text classification model using the DistilBERT transformer model, fine-tuned to distinguish between spam and ham (non-spam) SMS messages.

Abstract

The article outlines a method for fine-tuning the DistilBERT model to build a custom spam classifier for SMS texts. It details the process of loading and preprocessing data from the SMS spam collection dataset from the UCI Machine Learning Repository, which contains 5,574 labeled SMS messages. The tutorial explains how to tokenize the text data, map label IDs, set up training parameters, and use the Hugging Face Transformers library to fine-tune the model. The author demonstrates the model's effectiveness, achieving an accuracy of 0.9885, and provides instructions for performing inference both locally and on the Hugging Face model hub. The article concludes by emphasizing the ease and accessibility of adapting state-of-the-art transformer models for custom language tasks using Hugging Face's tools and provides additional resources for further exploration in fine-tuning transformer models.

Opinions

  • The author praises the DistilBERT model for its efficiency, noting it retains 97% of BERT's language understanding capabilities while being faster and smaller.
  • The author expresses that the training approach of using softmax outputs from the teacher model is a "brilliantly simple idea."
  • The article conveys the effectiveness of the fine-tuned model, as evidenced by the high accuracy achieved on the spam classification task.
  • The author encourages readers to use Hugging Face's tools, implying they simplify the process of fine-tuning and deploying transformer models.
  • The author suggests that with the right data, adapting state-of-the-art transformer models to custom tasks is both easy and accessible.
  • The author invites readers to support their work by becoming Medium members and subscribing to their newsletter for weekly data-based perspectives.

Transformer Models For Custom Text Classification Through Fine-Tuning

A tutorial on how to build a spam classifier (or any other classifier) by fine-tuning the DistilBERT model

Fine-Tuned SMS Spam Classifier Model Output | Skanda Vivek

The DistilBERT model was released by the folks at Hugging Face as a cheaper, faster alternative to large transformer models like BERT. It was originally introduced in a blog post. The model is trained with a teacher-student approach, where the “student” model is a smaller version of the “teacher” model. Then, instead of training the student only on the ultimate target outputs (basically one-hot encodings of the label class), the student is trained on the softmax outputs of the original teacher model. This is a brilliantly simple idea, and the authors show that:

“it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster.”
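
To make the soft-target idea concrete, here is a minimal sketch in plain Python. It is not the DistilBERT training code: the function names and the temperature value are assumptions for illustration, and the DistilBERT paper combines a loss of this kind with a masked language modeling loss and a cosine embedding loss.

```python
import math

def softmax(logits, temperature=1.0):
    """Turn raw logits into probabilities, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
    """Cross-entropy of the student's distribution against the teacher's.

    The teacher's softmax output replaces the one-hot label as the target.
    The temperature of 2.0 is an illustrative choice, not a value from the article.
    """
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))

teacher = [2.0, 0.5, -1.0]
print(soft_target_loss([1.8, 0.6, -0.9], teacher))  # student close to the teacher
print(soft_target_loss([-1.0, 0.5, 2.0], teacher))  # student far from the teacher
```

The loss is smallest when the student reproduces the teacher's full distribution, so the student also learns how the teacher ranks the wrong classes, which a one-hot label cannot convey.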

Loading and Preprocessing the Data For Classification

In this example, I use the SMS spam collection dataset in the UCI Machine Learning Repository and build a classifier that detects SPAM vs HAM (not SPAM). The data contains 5,574 rows of SMS texts that are labeled as SPAM or HAM.

First, I make train and test files from the original CSV and load them with the load_dataset function from the Hugging Face datasets library.

from datasets import load_dataset
import pandas as pd

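# Read the raw CSV with ISO-8859-1 (Latin-1) encoding
df = pd.read_csv('/content/spam.csv', encoding='ISO-8859-1')
# Keep the label and message columns and give them descriptive names
df = df[['v1', 'v2']]
df.columns = ['label', 'text']
# Encode the labels as integers: ham -> 0, spam -> 1
df.loc[df['label'] == 'ham', 'label'] = 0
df.loc[df['label'] == 'spam', 'label'] = 1
# First 4,179 rows go to training, the remaining rows to testing
df[:4179].reset_index(drop=True).to_csv('df_train.csv', index=False)
df[4179:].reset_index(drop=True).to_csv('df_test.csv', index=False)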

dataset = load_dataset('csv', data_files={'train': '/content/df_train.csv',
                                              'test': '/content/df_test.csv'})
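
The split above slices rows in file order. If the file's order is not random, shuffling before slicing gives a more representative split. The helper below is a hypothetical, standard-library sketch of that idea; the function name, the 0.75 fraction, and the seed are assumptions for illustration.

```python
import random

def shuffled_split(rows, train_fraction=0.75, seed=42):
    """Shuffle rows reproducibly, then cut them into train and test parts."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # fixed seed keeps the split repeatable
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]

# Toy (label, text) rows standing in for the SMS data
rows = [(i % 2, f"message {i}") for i in range(100)]
train_rows, test_rows = shuffled_split(rows)
print(len(train_rows), len(test_rows))
```

With pandas, `df.sample(frac=1, random_state=42)` before the slicing step has the same effect.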

The next step is to load in the DistilBERT tokenizer to preprocess the text data.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True)

tokenized_data = dataset.map(preprocess_function, batched=True)

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
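
The data collator pads each batch to the length of its longest sequence at training time. The sketch below shows the idea in plain Python. It is not the library's implementation: `pad_batch` is a hypothetical helper, the pad ID of 0 is an assumption, and the token IDs are made up.

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Pad every sequence to the longest one in the batch and build a mask."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = pad_batch([[101, 2489, 4248, 102], [101, 7592, 102]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

Since preprocess_function already passes padding=True, the texts are padded during the map step as well. Dropping that argument would leave all padding to the collator.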

Training the model

Prior to training, you need to map IDs to labels. After this, you need to specify the training hyperparameters, call trainer.train() to begin fine-tuning, and push the trained model to the Hugging Face hub using trainer.push_to_hub().

id2label = {0: "HAM", 1: "SPAM"}
label2id = {"HAM": 0, "SPAM": 1}

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

training_args = TrainingArguments(
    output_dir="spam-classifier",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)
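
The Trainer call that follows passes compute_metrics, which is not defined in the snippets shown here. Below is a minimal sketch, assuming accuracy is the only metric (the figure reported later is an accuracy) and that the Trainer hands the function a (logits, labels) pair of NumPy arrays. The author's actual function may differ.

```python
import numpy as np

def compute_metrics(eval_pred):
    """Return accuracy from the (logits, labels) pair produced at evaluation."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # highest-scoring class per example
    accuracy = float(np.mean(predictions == labels))
    return {"accuracy": accuracy}

# Toy check with made-up logits for four messages
logits = np.array([[2.0, -1.0], [0.1, 0.9], [1.5, 0.2], [-0.3, 0.4]])
labels = np.array([0, 1, 1, 1])
print(compute_metrics((logits, labels)))
```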

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.push_to_hub()

That’s it! As you can see on the Hugging Face hub, the fine-tuned model reaches an accuracy of 0.9885 on the test split.

Model Inference

Inference is also relatively straightforward. You can run the model locally with a short Python script:

text = "Email AlertFrom: Ash Kopatz. Click here to get a free prescription refill!"

from transformers import pipeline

# "sentiment-analysis" is the pipeline alias for text classification
classifier = pipeline("sentiment-analysis", model="skandavivek2/spam-classifier")
classifier(text)

Sample Fine-Tuned SMS Spam Classifier Model Output | Skanda Vivek
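
The classifier returns a label and a score. The sketch below shows how a pair of raw logits can be turned into that kind of output using the id2label mapping from training. It is a simplified illustration under the assumption of a softmax over two classes, not the pipeline's actual code, and the logit values are made up.

```python
import math

ID2LABEL = {0: "HAM", 1: "SPAM"}  # same mapping as used for training

def logits_to_prediction(logits, id2label=ID2LABEL):
    """Apply softmax to the logits and report the most probable label."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}

print(logits_to_prediction([-2.0, 3.0]))
```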

Or run on the Hugging Face hub:

Takeaways

And that’s it! Hugging Face makes it easy and accessible to adapt state-of-the-art transformer models to custom language tasks, as long as you have the data!

Here is the GitHub link to the code:

If you liked this blog, check out my other blog on fine-tuning Transformers for Question Answering!

References:

  1. https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset
  2. Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.
  3. Almeida, T.A., Gómez Hidalgo, J.M., Yamakami, A. Contributions to the Study of SMS Spam Filtering: New Collection and Results. Proceedings of the 2011 ACM Symposium on Document Engineering (DOCENG’11), Mountain View, CA, USA, 2011.
  4. https://huggingface.co/docs/transformers/training

If you are not yet a Medium member and want to support writers like me, feel free to sign-up through my referral link: https://skanda-vivek.medium.com/membership

For weekly data-based perspectives subscribe here!
