Shivanand Roy


Fine Tuning XLNet Model for Text Classification in 3 Lines of Code

Introduction

XLNet is powerful! It beats BERT and its variants on 20 different tasks.

The XLNet model was proposed in XLNet: Generalized Autoregressive Pretraining for Language Understanding by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. XLNet is an extension of the Transformer-XL model, pretrained using an autoregressive method to learn bidirectional contexts by maximizing the expected likelihood over all permutations of the input sequence factorization order.

In simple words — XLNet is a generalized autoregressive model.

An autoregressive model uses the preceding context words to predict the next word, so each token depends on all previous tokens.

XLNet is generalized because it captures bi-directional context by means of a mechanism called “permutation language modeling”.

It integrates the ideas of autoregressive modeling and bi-directional context modeling while overcoming the disadvantages of BERT, and thus outperforms BERT on 20 tasks, often by a large margin, in areas such as question answering, natural language inference, sentiment analysis, and document ranking.

In this article, we will take a pretrained XLNet model and fine-tune it on our dataset.

So, let’s talk about the dataset.

Data

We will take a dataset from Kaggle’s text classification challenge (Ongoing as of now) — Real or Not? NLP with Disaster Tweets.

In this competition, we have to build a machine learning model that predicts which tweets are about real disasters and which ones aren’t. It’s a small dataset of 10,000 tweets that were hand-classified.

We will use this data to fine-tune a pretrained XLNet model.

Let’s Code

Installing Dependencies

  • First, let’s spin up a Colab notebook.
  • Download the data from Real or Not? NLP with Disaster Tweets. You will have 3 files: train.csv, test.csv and sample_submission.csv.
  • Upload them to your Colab notebook session.
  • Install the latest stable pytorch 1.6, transformers and simpletransformers, as sketched below.
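The exact install commands aren’t shown in the original post; a minimal Colab cell would look something like this (the version pin is indicative only):

```python
# Run in a Colab cell; the "!" prefix executes shell commands
!pip install torch==1.6.0 transformers simpletransformers
```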

Now we’re good to go.

Preprocessing

First, let’s load the dataset

We have 5 columns in our data:

  • id: a unique identifier for each tweet.
  • keyword: a particular keyword from the tweet (may be blank).
  • location: the location the tweet was sent from (may be blank).
  • text: the actual text of the tweet.
  • target: whether a given tweet is about a real disaster (1) or not (0).

Let’s look at the distribution of target class

0 4342
1 3271
Name: target, dtype: int64

The dataset is reasonably balanced: 3271 tweets are about real disasters and 4342 are not.

Let’s have a look at the keyword and location columns

Keyword column has 0.80% null values
Location column has 33.27% null values

location has 33% missing values while keyword has 0.8% null values. We will not delve into filling up missing values and will leave these columns as they are.

The text and target columns are the ones of interest to us.

Let’s have a look at the text column

['Two giant cranes holding a bridge collapse into nearby homes http://t.co/jBJRg3eP1Q',
"Apollo Brown - 'Detonate' f. M.O.P. | http://t.co/H1xiGcEn7F",
'Listening to Blowers and Tuffers on the Aussie batting collapse at Trent Bridge reminds me why I love @bbctms! Wonderful stuff! #ENGvAUS',
'Downtown Emergency Service Center is hiring! #Chemical #Dependency Counselor or Intern in #Seattle apply now! #jobs http://t.co/HhTwAyT4yo', 
'Car engulfed in flames backs up traffic at Parley\x89Ûªs Summit http://t.co/RmucfjCaZr', 
'After death of Palestinian toddler in arson\nattack Israel cracks down on Jewish',
'Students at Sutherland remember Australian casualties at Lone Pine Gallipoli\n http://t.co/d50oRfXoFB via @theleadernews',
'FedEx no longer to transport bioterror germs in wake of anthrax lab mishaps http://t.co/hrqCJdovJZ',
'@newyorkcity for the #international emergency medicine conference w/ Lennox Hill hospital and #drjustinmazur', 
'My back is so sunburned :(']

We see that the text column contains hashtags, mentions, and links, which need to be cleaned.

Let’s write a simple function to clean up:

We will use tweet-preprocessor to do this.

The clean() function from tweet-preprocessor can get rid of irrelevant tokens such as hashtags, @usernames, and links, and make the tweet clean enough to feed into the XLNet model.
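The cleaning cell isn’t shown in the extract; a sketch using tweet-preprocessor (imported as preprocessor) together with tqdm’s progress_apply, which produces the progress bars below, might look like this:

```python
# !pip install tweet-preprocessor
import preprocessor as p   # the tweet-preprocessor package
from tqdm import tqdm

tqdm.pandas()  # registers .progress_apply() so we get a progress bar per column

# Strip hashtags, mentions, and URLs from the train and test tweets
data["clean_text"] = data["text"].progress_apply(p.clean)
test_data["clean_text"] = test_data["text"].progress_apply(p.clean)
```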

100% 7613/7613 [00:49<00:00, 154.19it/s]
100% 3263/3263 [00:48<00:00, 67.34it/s]

Now we have clean text in the clean_text column.

Now, let’s split our data into train and eval set

((6090, 2), (1523, 2))

We divided our data into train_df and eval_df with an 80:20 stratified split. We have 6090 tweets for training and 1523 tweets for evaluation.

Now, we are all set for training XLNet.

XLNet Training

For training XLNet, we will use simpletransformers, a super easy-to-use library built on top of our beloved transformers.

simpletransformers provides a unified API to train any SOTA pretrained NLP model available in transformers. So you get the power of SOTA pretrained language models like BERT and its variants, XLNet, ELECTRA, T5, etc., wrapped in easy-to-use functions.

As you see below, it takes just 3 lines of code to train an XLNet model. And the same holds true whether you train it from scratch or just fine-tune the model on a custom dataset.

I have kept num_train_epochs: 4, train_batch_size: 32 and max_seq_length: 128 so that it fits into Colab’s compute limits. Feel free to play with the many parameters available in args in the code below.
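The training cell isn’t included in the extract; based on the simpletransformers ClassificationModel API and the args mentioned above, a sketch would be the following (the extra overwrite_output_dir arg and the accuracy_score metric are my additions to match the logged output):

```python
from sklearn.metrics import accuracy_score
from simpletransformers.classification import ClassificationModel

# simpletransformers expects the DataFrame columns to be named "text" and "labels"
train_df = train_df.rename(columns={"clean_text": "text", "target": "labels"})
eval_df = eval_df.rename(columns={"clean_text": "text", "target": "labels"})

model_args = {
    "num_train_epochs": 4,
    "train_batch_size": 32,
    "max_seq_length": 128,
    "overwrite_output_dir": True,
}

# The promised 3 lines: create the model, train it, evaluate it
model = ClassificationModel("xlnet", "xlnet-base-cased", args=model_args)
model.train_model(train_df)
result, model_outputs, wrong_predictions = model.eval_model(eval_df, acc=accuracy_score)
```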

Downloading: 100% 760/760 [00:10<00:00, 71.0B/s] 
Downloading: 100% 467M/467M [00:10<00:00, 45.2MB/s] 
Downloading: 100% 798k/798k [00:14<00:00, 56.1kB/s] 
100% 6090/6090 [08:15<00:00, 12.29it/s] 
Epoch 4 of 4: 100% 4/4 [08:12<00:00, 123.24s/it] 
Epochs 0/4. Running Loss: 0.4059: 100% 191/191 [08:12<00:00, 2.58s/it] 
Epochs 1/4. Running Loss: 0.2305: 100% 191/191 [02:01<00:00, 1.57it/s] 
Epochs 2/4. Running Loss: 0.4360: 100% 191/191 [04:24<00:00, 1.38s/it] 
Epochs 3/4. Running Loss: 0.0260: 100% 191/191 [02:28<00:00, 1.28it/s] 
100% 1523/1523 [00:23<00:00, 65.14it/s]
Running Evaluation: 100% 191/191 [00:20<00:00, 9.17it/s]
INFO:simpletransformers.classification.classification_model:{'mcc': 0.6457675302369492, 'tp': 518, 'tn': 741, 'fp': 128, 'fn': 136, 'acc': 0.8266579120157583, 'eval_loss': 0.5341164009543184}

We have achieved a decent accuracy of 82.6% on our eval set. This accuracy is just out of the box, meaning no feature engineering and no hyperparameter tuning. Just out of the box!

🥳 We’re Done!

Let’s submit the predictions to Kaggle and see where we stand.

INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.
100% 3263/3263 [00:01<00:00, 3216.81it/s]
100% 408/408 [00:38<00:00, 10.68it/s]

We’re in top 18%. It’s a good start considering XLNet out of the box performance - with no feature engineering at all.

Now, we have a decent baseline to improve our model upon.

Notebooks


Originally published at https://shivanandroy.com on September 8, 2020.
