Rohan Jagtap

Summary

XLNet is an advanced language model that improves upon BERT by using an autoregressive pre-training approach with a novel permutation language modeling objective, enabling it to outperform BERT on various language understanding tasks.

Abstract

XLNet is a state-of-the-art language model that addresses limitations of BERT such as the pretrain-finetune discrepancy and the independence assumption in masked token reconstruction. Unlike BERT, which models bidirectional context through masked language modeling (MLM), XLNet is autoregressive and therefore needs no input denoising. It introduces a permutation language modeling objective that lets the model learn dependencies on all positions in a sequence without being restricted to a unidirectional context. It also uses a two-stream self-attention mechanism for target-aware representations and incorporates ideas from Transformer XL, namely relative positional encoding and segment recurrence, to handle long sequences efficiently. XLNet's code and pre-trained weights are publicly available, which facilitates its adoption across natural language processing tasks.

Opinions

  • The author suggests that XLNet's ability to outperform BERT in 20 tasks is a significant achievement, indicating its superiority in language understanding.
  • There is an emphasis on the drawbacks of BERT's pre-training objective, particularly the use of artificial symbols like [MASK] that do not appear in real data during fine-tuning.
  • The permutation language modeling objective of XLNet is highlighted as a key innovation that combines the strengths of autoencoding and autoregressive models while overcoming their individual limitations.
  • The two-stream self-attention mechanism is presented as a necessary adaptation for the permutation language model to function effectively with the Transformer architecture.
  • The partial prediction strategy in XLNet is seen as a practical solution to the computational challenges posed by permutation language modeling, allowing for efficient training and inference.
  • The integration of concepts from Transformer XL into XLNet is viewed positively, leveraging relative positional encoding and segment recurrence to process longer sequences.
  • The open-sourcing of XLNet's code and the availability of pre-trained weights through Hugging Face's transformers library are considered beneficial for the machine learning community, encouraging further research and application.

XLNet: Autoregressive Pre-Training for Language Understanding

Understanding Transformer-Based Self-Supervised Architectures

Photo by Tim Mossholder on Unsplash

State-of-the-art language models such as BERT and OpenAI GPT have achieved stellar results in Natural Language Processing in recent times. These models are based on the Transformer architecture, which has largely driven RNN-based and convolution-based models out of business.

In this article, we’ll be discussing the XLNet model, which was proposed in a recent paper: XLNet: Generalized Autoregressive Pretraining for Language Understanding. This model addresses certain drawbacks of BERT and outperforms BERT on 20 tasks.

If you’re interested in knowing the concept behind BERT or Transformers, consider giving this (BERT) and this (Transformer) a read.

So, What’s Wrong With BERT?

Demo of Masked LM by AllenNLP

Input Noise

One major issue with BERT lies in its pre-training objective on masked sequences, i.e., the denoising autoencoding objective. Masking helps the model learn the patterns in the language corpus; however, during fine-tuning, the input sequences are not masked.

However, the artificial symbols like [MASK] used by BERT during pre-training are absent from real data at fine-tuning time, resulting in a pretrain-finetune discrepancy.

— XLNet Paper

Independence Assumption

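BERT maximizes the conditional probability p(x_bar | x_hat) of the masked tokens x_bar given the corrupted (masked) sequence x_hat, and approximates it as a sum over the masked positions: log p(x_bar | x_hat) ≈ Σ_t m_t log p(x_t | x_hat), where m_t = 1 if x_t is masked. Each term p(x_t | x_hat) is read as the probability of the masked token x_t occurring at the t-th position, given the whole corrupted sequence x_hat.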

This reflects an independence assumption: each masked token is reconstructed separately, conditioned only on x_hat and not on the other masked tokens. We’ll make this concrete with an example in a later section.
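A minimal Python sketch of this factorization for a toy sentence. The helper name `mlm_terms` is hypothetical, and the sketch ignores BERT's random-replacement details (it always substitutes [MASK]). It lists one term per masked position and shows that every masked target is conditioned on the same corrupted input, so the masked tokens never condition on one another.

```python
MASK = "[MASK]"


def mlm_terms(tokens, masked_positions):
    """Return one (target, conditioning_input) pair per masked position.

    Every target is conditioned on the SAME corrupted sequence x_hat,
    which is exactly the independence assumption between masked tokens.
    """
    corrupted = tuple(MASK if i in masked_positions else tok for i, tok in enumerate(tokens))
    return [(tokens[i], corrupted) for i in sorted(masked_positions)]


terms = mlm_terms(["Deep", "Learning", "is", "great"], {0, 1})
for target, context in terms:
    print(f"log p({target} | {' '.join(context)})")
```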

XLNET

Unlike BERT, XLNet is an autoregressive model, which removes the need to denoise a corrupted input.

However, autoregressive models are often criticized for being unidirectional. To overcome this, XLNet proposes a novel Permutation Language Modeling objective.

Permutation Language Modeling

As mentioned previously, XLNet proposes a mechanism that takes the best of both worlds (i.e., autoencoding and autoregressive). It avoids the input denoising of the autoencoding objective and removes the unidirectionality of a traditional autoregressive objective.

To achieve this, instead of factorizing the joint probability p(x) into the conditionals p(x_t | x_(i < t)) in a fixed forward or backward order, as traditional autoregressive models do, XLNet maximizes the expected log-likelihood of a sequence over all possible permutations of the factorization order.

Specifically, for a sequence x of length T, there are T! different orders to perform a valid autoregressive factorization. Intuitively, if model parameters are shared across all factorization orders, in expectation, the model will learn to gather information from all positions on both sides.

— XLNet Paper
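To see what "in expectation" means, the following sketch (our own illustration, not code from the paper) enumerates all T! = 24 factorization orders for T = 4 and, assuming every order is weighted equally, measures how often x_3 gets to condition on its left neighbor x_1 and on its right neighbor x_4.

```python
from itertools import permutations
from math import factorial

T = 4
orders = list(permutations(range(1, T + 1)))  # every valid factorization order of x_1..x_4
print(len(orders), factorial(T))


def visible_fraction(target, other):
    """Fraction of factorization orders in which `other` precedes `target` (so target can condition on it)."""
    hits = sum(order.index(other) < order.index(target) for order in orders)
    return hits / len(orders)


# x_3 sees a token on its left (x_1) and one on its right (x_4) equally often.
print(visible_fraction(3, 1), visible_fraction(3, 4))
```

Because the parameters are shared across all orders, each position is trained, on average, to use context from both sides.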

Permutation Language Modeling from XLNet Paper

To elaborate on this objective, let’s take an example. Consider the figure above, with a sequence x of 4 tokens. For simplicity, we consider only the attention computation for x_3. Note the factorization order stated under each panel of the figure.

  • In the order 3 -> 2 -> 4 -> 1, token 3 comes first. Hence, none of the other tokens contribute to its attention computation, because none of them precede 3 in this permutation.
  • In the order 2 -> 4 -> 3 -> 1, 3 is preceded by 2 and 4, hence they contribute to its attention computation.
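  • Similarly, in the order 1 -> 4 -> 2 -> 3, tokens 1, 4 and 2 contribute to the attention computation of x_3, while in 4 -> 3 -> 1 -> 2 only token 4 does. In general, the tokens x_(i < t) that precede x_t in the order contribute to its attention computation.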
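The four cases above can be reproduced with a few lines of Python. `attended_positions` is a hypothetical helper; positions are 1-indexed to match the figure, and the Transformer XL memory shown in the figure is ignored.

```python
def attended_positions(order, target):
    """Positions x_target may attend to: only those that come before it in the factorization order."""
    return order[: order.index(target)]


for order in ([3, 2, 4, 1], [2, 4, 3, 1], [1, 4, 2, 3], [4, 3, 1, 2]):
    print(" -> ".join(map(str, order)), "| x_3 attends to:", attended_positions(order, 3))
```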

More formally:

Objective Function from XLNet Paper
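Written out, the objective in the figure is the following, using the paper's notation: Z_T is the set of all permutations of the index sequence [1, 2, ..., T], and z_t and z_(<t) denote the t-th element and the first t-1 elements of a permutation z.

```latex
\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}
\left[ \sum_{t=1}^{T} \log p_{\theta}\left( x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}} \right) \right]
```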

Note that during training the sequence itself is not actually permuted, since sequences cannot be permuted while fine-tuning on a downstream task or during inference. Instead, the attention mask in the Transformer is manipulated so that each token only attends to the tokens preceding it in the sampled factorization order. This also makes sense, because the proposed objective permutes the factorization order, not the sequence order.
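A minimal sketch of such a mask, assuming the convention that 1 means "may attend" (implementations differ on this convention, and `permutation_mask` is a hypothetical name). The tokens stay in their original positions; only the mask encodes the factorization order. The example uses the order 2 -> 4 -> 3 -> 1 from the figure, shifted to 0-indexed positions.

```python
def permutation_mask(order):
    """T x T mask in ORIGINAL positions: mask[i][j] == 1 iff position j comes strictly before i in `order`."""
    rank = {pos: r for r, pos in enumerate(order)}  # where each position sits in the factorization order
    n = len(order)
    return [[1 if rank[j] < rank[i] else 0 for j in range(n)] for i in range(n)]


mask = permutation_mask([1, 3, 2, 0])  # 2 -> 4 -> 3 -> 1, 0-indexed
for row in mask:
    print(row)
```

Row 2 (that is, x_3) allows attention to positions 1 and 3, i.e. x_2 and x_4, matching the second panel of the figure.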

Two-Stream Self-Attention for Target-Aware Representations

Two-Stream Self-Attention for Target-Aware Representations from XLNet Paper

The standard Transformer parametrization may not work with the permutation language model. To see why, consider the standard softmax formulation of the next-token distribution, which is given by:

Permutation LM with Standard Transformer Parametrization
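In the paper's notation, where e(x) denotes the embedding of token x, this standard formulation reads:

```latex
p_{\theta}\left( X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}} \right)
= \frac{\exp\left( e(x)^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}) \right)}
       {\sum_{x'} \exp\left( e(x')^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}) \right)}
```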

Here, h_θ(x_(z_(<t))) is the hidden representation of x_(z_(<t)) produced by the Transformer. This term does not depend on the position it predicts, i.e., z_t. As a result, the predicted distribution is the same whichever position is being predicted, which prevents the model from learning useful representations.
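To make the problem concrete, here is a toy sketch with made-up 2-d embeddings. `h` and `g` are simplistic, hypothetical stand-ins (averages of embeddings, not a Transformer). Two factorization orders reach the same context {x_1, x_3} but must predict different positions. The context-only `h` necessarily yields one distribution for both, while a target-aware `g` that also receives the target position, the kind of fix the paper proposes, can tell them apart.

```python
import math

# Hypothetical toy values (made up): 2-d embeddings for a 3-word vocabulary and 4 positions.
TOKEN_EMB = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
POS_EMB = {1: [0.1, 0.0], 2: [0.0, 0.5], 3: [0.3, 0.3], 4: [-0.5, 0.2]}


def next_token_dist(hidden):
    """Softmax over the vocabulary: exp(e(x)^T hidden) / sum_x' exp(e(x')^T hidden), rounded."""
    scores = {w: sum(e * v for e, v in zip(emb, hidden)) for w, emb in TOKEN_EMB.items()}
    total = sum(math.exp(s) for s in scores.values())
    return {w: round(math.exp(s) / total, 4) for w, s in scores.items()}


def h(context):
    """Toy stand-in for h_theta(x_{z<t}): depends on the context only, never on the target."""
    vecs = [[a + b for a, b in zip(TOKEN_EMB[tok], POS_EMB[pos])] for pos, tok in context]
    return [sum(col) / len(vecs) for col in zip(*vecs)]


def g(context, target_pos):
    """Toy stand-in for a target-aware g_theta(x_{z<t}, z_t): also adds the target position."""
    return [a + b for a, b in zip(h(context), POS_EMB[target_pos])]


# Orders 1 -> 3 -> 2 -> 4 and 1 -> 3 -> 4 -> 2 both reach step t = 3 with context {x_1, x_3},
# but one must predict position 2 and the other position 4.
context = [(1, "a"), (3, "c")]
print("standard, any target: ", next_token_dist(h(context)))
print("target-aware, z_t = 2:", next_token_dist(g(context, 2)))
print("target-aware, z_t = 4:", next_token_dist(g(context, 4)))
```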

Hence, to overcome this, the XLNet paper re-parametrizes the next-token distribution to be target-aware:

Permutation LM with Re-Parametrized Representation
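In the paper's notation, the re-parametrized distribution replaces h_θ with g_θ, which also receives the target position z_t:

```latex
p_{\theta}\left( X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}} \right)
= \frac{\exp\left( e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t) \right)}
       {\sum_{x'} \exp\left( e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t) \right)}
```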

Here, g_θ(x_(z_(<t)), z_t) is a new type of representation that additionally takes the target position z_t as input. So, two streams of hidden representations are used instead of one:

Content Stream Attention
  • The content representation h_θ(x_(z_(≤t))), which plays the same role as the standard Transformer hidden state. It encodes both the context x_(z_(<t)) and the token x_(z_t) itself.

Mathematically:

Content Representation
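In the paper's notation, with m the layer index, the content stream is updated as follows; its keys and values include position z_t itself:

```latex
h_{z_t}^{(m)} \leftarrow \operatorname{Attention}\left( Q = h_{z_t}^{(m-1)},\; KV = \mathbf{h}_{\mathbf{z}_{\le t}}^{(m-1)};\; \theta \right)
```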
Query Stream Attention
  • The query representation g_θ(x_(z_(<t)), z_t), which has access only to the contextual information x_(z_(<t)) and the target position z_t, but not to the content x_(z_t) itself.

Mathematically:

Query Representation
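The query stream is updated analogously, except that its keys and values come only from the content stream of the earlier positions z_(<t), so it never sees x_(z_t):

```latex
g_{z_t}^{(m)} \leftarrow \operatorname{Attention}\left( Q = g_{z_t}^{(m-1)},\; KV = \mathbf{h}_{\mathbf{z}_{<t}}^{(m-1)};\; \theta \right)
```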

Note that the content stream is initialized with the corresponding word embedding, h_i^(0) = e(x_i), and the query stream is initialized with a trainable vector, g_i^(0) = w. Both are then updated layer by layer using the above expressions.
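The only structural difference between the two streams is which positions they may attend to. A minimal sketch, assuming 1 means "may attend" and using the hypothetical helper `two_stream_masks`: the content mask includes the diagonal (z_(≤t)), the query mask excludes it (z_(<t)). The example order is 3 -> 2 -> 4 -> 1 from the figure, 0-indexed.

```python
def two_stream_masks(order):
    """Content- and query-stream masks for one factorization order (0-indexed, 1 = may attend).

    Content stream h: attends to z_{<=t}, i.e. earlier positions AND itself.
    Query stream g:   attends to z_{<t}, i.e. earlier positions only, so it
                      never sees the content of the token it has to predict.
    In both cases the attended keys/values come from the content stream.
    """
    rank = {pos: r for r, pos in enumerate(order)}
    n = len(order)
    content = [[1 if rank[j] <= rank[i] else 0 for j in range(n)] for i in range(n)]
    query = [[1 if rank[j] < rank[i] else 0 for j in range(n)] for i in range(n)]
    return content, query


content, query = two_stream_masks([2, 1, 3, 0])  # 3 -> 2 -> 4 -> 1, 0-indexed
print("content:", content)
print("query:  ", query)
```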

Partial Prediction

Keeping aside all the benefits of permutation LM, we have to accept that it is expensive: the permutations turn training into a challenging optimization problem that converges slowly.

Hence, to address this, for a given factorization order z, only the last tokens z_(>c) are selected as prediction targets, where c is called the cutting point. We choose z_(>c) because these tokens have the longest context in that factorization order.

Moreover, a hyperparameter K is used such that |z|/(|z|−c) ≈ K, so roughly 1/K of the tokens are selected for prediction. For unselected tokens, the query representations need not be computed, which saves computation and memory.
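A minimal sketch of the target selection, assuming the number of targets |z| − c is obtained by rounding |z|/K (the exact rounding rule is our assumption, and `split_targets` is a hypothetical name):

```python
def split_targets(order, K):
    """Split a factorization order into (context, targets) for partial prediction.

    Picks the cutting point c so that |z| / (|z| - c) is approximately K:
    the last ~|z|/K tokens of the order become prediction targets, and only
    they need a query representation g.
    """
    n = len(order)
    num_targets = max(1, round(n / K))  # |z| - c, with at least one target
    c = n - num_targets
    return order[:c], order[c:]


context, targets = split_targets([5, 0, 9, 3, 11, 7, 1, 10, 2, 8, 4, 6], K=6)
print("context:", context)
print("targets:", targets)
```

With 12 tokens and K = 6, only the last 2 tokens of the order are predicted, and each of them conditions on at least 10 earlier tokens.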

Let’s compare this partial prediction with BERT’s. BERT uses partial prediction because masking all the tokens would leave no context from which to make meaningful predictions; XLNet uses it because of the optimization difficulty. For example, take the sequence [Deep, Learning, is, great]. Say both BERT and XLNet choose to predict the tokens [Deep, Learning], and suppose that XLNet uses the factorization order [is, great, Deep, Learning]. In this case:

BERT maximizes:

  • L(BERT) = log p(Deep | is great) + log p(Learning | is great)

XLNet maximizes:

  • L(XLNet) = log p(Deep | is great) + log p(Learning | Deep is great)

This shows how XLNet captures an additional dependency, namely the one between Deep and Learning, which BERT does not model. BERT does learn many of the dependencies, but XLNet learns more. This is also an example of the independence assumption in BERT discussed in an earlier section.
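The two objectives in this example can be generated mechanically. In this sketch (hypothetical helpers `bert_terms` and `xlnet_terms`, which assume the tokens are unique strings), BERT conditions every target on the non-target tokens only, while XLNet conditions each target on everything before it in the factorization order.

```python
def bert_terms(tokens, targets):
    """BERT-style: every target is conditioned only on the non-target (unmasked) tokens."""
    context = [t for t in tokens if t not in targets]
    return [(t, context) for t in targets]


def xlnet_terms(order, targets):
    """XLNet-style: each target is conditioned on everything before it in the factorization order."""
    return [(t, order[: order.index(t)]) for t in targets]


tokens = ["Deep", "Learning", "is", "great"]
targets = ["Deep", "Learning"]
order = ["is", "great", "Deep", "Learning"]  # factorization order from the example

for name, terms in (("BERT", bert_terms(tokens, targets)), ("XLNet", xlnet_terms(order, targets))):
    print(name, " + ".join(f"log p({t} | {' '.join(ctx)})" for t, ctx in terms))
```

Only the XLNet term for Learning has Deep in its context, which is exactly the extra dependency described above.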

Taking Ideas from Transformer XL

Finally, a mention of the Transformer XL model, from which XLNet borrows the ideas of relative positional encoding and the segment recurrence mechanism that enable Transformer XL to operate on very long sequences.

Fun Fact: Transformer XL learns dependencies that are 80% longer than RNNs and 450% longer than vanilla Transformers, and it is up to 1,800+ times faster than vanilla Transformers during evaluation.

Conclusion

We’ve covered another state-of-the-art model, XLNet, and discussed the concepts behind it.

The authors have open-sourced XLNet’s code; you can find it here.

Pre-trained weights and an easy-to-use API for the model architecture are available through Hugging Face’s transformers library.

New: I have written a paper summary and critique for XLNet. You can take a look if interested: https://docs.google.com/document/d/1nePIW67OqW1HPrIkoXUB-N8hK2A07-3pnmtRVMQMKTA/edit?usp=sharing

