Paula Ceccon Ribeiro

Transformers from Scratch: Part 1

Key Concepts, Self-Attention, Multi-Head Attention

Photo by Jéan Béller on Unsplash

Attention is a mechanism that allows neural networks to focus on different parts of the input sequence when processing information. It is a crucial component of the transformer architecture, enabling the model to capture dependencies and relationships between different elements of the sequence. For text sequences, the elements are token embeddings.

In a transformer model, attention is computed through the self-attention mechanism.

Disclaimer: I’m purposely not covering masks here, as they are mainly relevant for the decoder part of the Transformer, which is going to be tackled in Part 2. My main goal here is to explain the basics of the attention mechanism.

But first things first…

Query, Key, and Value Vectors

In the context of attention mechanisms, each element in the input sequence is associated with a query, key, and value vector.

Imagine you’re attending a conference where multiple speakers give presentations. Each presentation corresponds to a token in the input sequence. Now, let’s break down the key, query, and value in this context:

  1. Key: The key represents the content or context of each presentation. It captures the main ideas, themes, or relevant information associated with each talk. Think of the key as a summary or representation of the key points of each presentation.
  2. Query: The query represents the specific topic or question you’re interested in or want to focus on during the conference. It could be a specific area of interest or a particular subject you’re curious about. The query reflects your current context or the aspect you want to explore further.
  3. Value: The value contains the detailed information, insights, or knowledge provided by each speaker during their presentation. It encompasses all the valuable content of each talk, including facts, examples, explanations, and ideas.

Hold on to this example as we explore the attention mechanism.

Self-Attention

The self-attention mechanism calculates attention weights that indicate the relevance of each element with respect to the other elements within the same sequence, i.e., how much attention each element should receive. The weights are computed from the similarity between the query vector of one element and the key vectors of all elements. The value vectors are not used to compute the weights, only to build the output.

The term “self” in self-attention emphasizes that attention is computed within the same sequence, without considering any external context or other sequences. It highlights the capability of the self-attention mechanism to capture dependencies and relationships between elements within the input sequence itself.

In the scenario previously described, the attention mechanism allows you to attend to relevant presentations and extract valuable information based on your query. The key vectors help determine which presentations are most relevant to your query, while the query vector represents your specific area of interest or focus. The value vectors contain the detailed content of each presentation.

The model identifies the most important presentations that align with your interests by calculating attention weights between the query and the keys. It then combines the values of these selected presentations using the attention weights, effectively capturing the relevant information from each presentation based on your query.
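To make the analogy concrete, here is a toy sketch in plain Python. The 2-D "topic" vectors and the talk contents are made up for illustration. For a single query, it uses a plain (unscaled) dot product. The scaled version used in practice is described in the next section.

```python
import math


def dot(a, b):
    """Dot product of two equal-length lists of floats."""
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Single-query, unscaled dot-product attention."""
    # 1. Score each key (talk summary) against the query (your interest)
    scores = [dot(query, k) for k in keys]
    # 2. Turn the scores into weights that are positive and sum to 1
    weights = softmax(scores)
    # 3. Blend the values (talk contents) using those weights
    dim = len(values[0])
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, output


# Made-up 2-D "topics": axis 0 = transformers, axis 1 = databases
query = [1.0, 0.0]                                # you care about transformers
keys = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]       # three talk summaries
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]   # what each talk actually delivers

weights, output = attend(query, keys, values)
print([round(w, 3) for w in weights])
print([round(o, 3) for o in output])
```

The transformers talk gets the largest weight and the mixed talk comes second. As a result, the blended output leans toward the transformers content.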

Scaled Dot-Product Attention

Scaled dot-product attention is the most common way to implement a self-attention layer. It computes attention weights from the dot products between a query vector and a set of key vectors, and then uses those weights to combine the corresponding value vectors. The key idea is to scale the dot products by the square root of the dimensionality of the query and key vectors, √dₖ. Without this scaling, large dot products push the softmax into regions where its gradients are extremely small, so the scaling helps stabilize training.
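The scaling argument can be checked numerically. This sketch assumes query and key entries are independent with zero mean and unit variance, which is the usual argument for the scaling. Under that assumption, the dot product of two dₖ-dimensional vectors has variance dₖ. Its standard deviation therefore grows like √dₖ, and dividing by √dₖ brings it back to roughly 1. The second part compares the largest softmax weight with and without scaling.

```python
import torch
from math import sqrt

torch.manual_seed(0)

# Spread of raw vs. scaled dot products for growing d_k
stats = {}
for d_k in (4, 64, 512):
    # 10,000 random query/key pairs with i.i.d. standard-normal entries
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)  # raw dot products q . k
    scaled = dots / sqrt(d_k)   # scaled dot products
    stats[d_k] = (dots.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std(raw)={stats[d_k][0]:6.2f}  std(scaled)={stats[d_k][1]:.2f}")

# Softmax saturation: one query against 10 keys with d_k = 512
d_k = 512
query = torch.randn(1, d_k)
keys = torch.randn(10, d_k)
raw_scores = query @ keys.T  # shape (1, 10)
raw_max = torch.softmax(raw_scores, dim=-1).max().item()
scaled_max = torch.softmax(raw_scores / sqrt(d_k), dim=-1).max().item()
print(f"largest weight: raw={raw_max:.3f}  scaled={scaled_max:.3f}")
```

With unscaled scores, the largest weight is typically close to 1. At that point the softmax is nearly one-hot, and its gradients with respect to the other scores are tiny.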

The computation of the scaled dot-product attention can be summarized by the following steps:

  1. Project each token embedding, with dimension dₘ, into three vectors: a query vector q and a key vector k, both with dimension dₖ, and a value vector v with dimension dᵥ. Stacking these vectors for all tokens gives the matrices Q, K, and V.
  2. Compute the attention scores using the dot product similarity. For a sequence of n input tokens, multiplying the query matrix Q (n × dₖ) by the transposed key matrix Kᵀ (dₖ × n) yields a similarity matrix of dimensions n × n.
  3. Scale the similarity matrix by dividing it by the square root of the dimensionality of the query/key vectors, √dₖ. This keeps the dot products from growing too large. Large values would push the softmax into regions with extremely small gradients and slow down training.
  4. Compute attention weights w. Apply the softmax function to each row of the scaled similarity matrix, so that the weights for each query are positive and sum to 1. The resulting attention weights represent the importance of each key with respect to each query.
  5. Update the token embeddings. Multiply the attention weights w by the value matrix V. Each output row is a weighted sum of the value vectors, weighted by the attention weights.

Self-Attention Computation

This can be translated into the following equation:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

These steps can be implemented in Python with the following code:

import torch
import torch.nn.functional as F
from math import sqrt

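def scaled_dot_product_attention(query, key, value):
    """
    Simplified scaled dot product attention (no masking).

    query, key: (batch, seq_len, d_k); value: (batch, seq_len, d_v).
    """
    d_k = query.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k): (batch, seq_len, seq_len)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(d_k)
    # Softmax over the keys, so each row of weights sums to 1
    weights = F.softmax(scores, dim=-1)
    # Weighted sum of the value vectors: (batch, seq_len, d_v)
    return torch.bmm(weights, value)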
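As a quick sanity check, the sketch below runs the same computation on random tensors and prints the shapes from steps 2 and 5. It also compares the result with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. This assumes PyTorch 2.0 or later, where that function exists and uses 1/√dₖ as its default scale. The batch size, sequence length, and dimensions are arbitrary.

```python
import torch
import torch.nn.functional as F
from math import sqrt

torch.manual_seed(0)
batch, seq_len, d_k, d_v = 2, 5, 16, 8  # arbitrary illustrative sizes

query = torch.randn(batch, seq_len, d_k)
key = torch.randn(batch, seq_len, d_k)
value = torch.randn(batch, seq_len, d_v)

# The equation written out by hand
scores = query @ key.transpose(-2, -1) / sqrt(d_k)  # (batch, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)             # each row sums to 1
manual = weights @ value                            # (batch, seq_len, d_v)

# PyTorch's built-in version (PyTorch >= 2.0)
builtin = F.scaled_dot_product_attention(query, key, value)

print(scores.shape, manual.shape)
print(torch.allclose(manual, builtin, atol=1e-5))
```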

To get a glimpse of how the attention weights are calculated, we can use the BertViz library, specifically its neuron view (the neuron_view module):

from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show

model_ckpt = "bert-base-uncased"
model = BertModel.from_pretrained(model_ckpt, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_ckpt, do_lower_case=True)

text = "The quick brown fox jumps over the lazy dog"
show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)

This visualization depicts the query vector, the key vectors, and their product as vertical bands. The color intensity of the bands indicates their magnitude, while the hue represents their sign (blue for positive, orange for negative). The thickness of the lines connecting the tokens reflects the attention weight between them.

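For more on visualizing attention in Transformers, I'd recommend checking this great article: "Explainable AI: Visualizing Attention in Transformers (And logging the results in an experiment tracking tool)", published on generativeai.pub.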

Multi-Head Attention

Multi-head attention is an extension of the self-attention mechanism. It enhances the modeling capability by performing multiple attention computations in parallel, each with its own learned linear projections.

The reasoning for having multi-head attention is that the softmax of a single head tends to focus mostly on one aspect of similarity. Multiple heads allow the model to capture different types of dependencies and relationships between the elements in the input sequence. Each headᵢ, for i = 1, …, h, can attend to different parts of the sequence, enabling the model to learn more nuanced patterns.

Multi-Head Attention Mechanism

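This can be translated to the following equation:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O$$

where:

$$\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

and the projections are parameter matrices:

$$W_i^Q \in \mathbb{R}^{d_m \times d_k}, \quad W_i^K \in \mathbb{R}^{d_m \times d_k}, \quad W_i^V \in \mathbb{R}^{d_m \times d_v}, \quad W^O \in \mathbb{R}^{h d_v \times d_m}$$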
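As a worked example, take the BERT-base sizes used by the bert-base-uncased checkpoint in the visualizations: dₘ = 768 and h = 12, so dₖ = dᵥ = dₘ / h = 64. Plugging these in gives the following shapes for the per-head projection matrices Wᵢ^Q, Wᵢ^K, Wᵢ^V and the output projection W^O, plus the parameter count of one multi-head attention layer, ignoring bias terms:

```latex
\begin{aligned}
d_k = d_v &= \frac{d_m}{h} = \frac{768}{12} = 64 \\
W_i^Q, W_i^K, W_i^V &\in \mathbb{R}^{768 \times 64}, \qquad
W^O \in \mathbb{R}^{(12 \cdot 64) \times 768} = \mathbb{R}^{768 \times 768} \\
\#\text{params} &= 3 \cdot h \cdot d_m d_k + h d_v \cdot d_m
= 3 \cdot 12 \cdot 768 \cdot 64 + 768 \cdot 768
= 4 \cdot 768^2 = 2{,}359{,}296
\end{aligned}
```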

As can be seen, each attention head has its own set of learnable parameters. Each head therefore projects the input sequence differently and produces a different representation. A single attention head can be implemented in Python as follows:

from torch import nn

class AttentionHead(nn.Module):
    """
    Self-attention head.

    Args:
        embed_dim: embedding dimension.
        head_dim: number of dimensions we are projecting into.
        mask: unused here (masking is left for Part 2).
    """
    def __init__(self, embed_dim, head_dim, mask=None):
        super().__init__()
        # Learnable projections of the input into query, key and value spaces
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def forward(self, hidden_state):
        # hidden_state: (batch, seq_len, embed_dim) -> output: (batch, seq_len, head_dim)
        attn_outputs = scaled_dot_product_attention(
            self.q(hidden_state), self.k(hidden_state), self.v(hidden_state))
        return attn_outputs

The AttentionHead class has three linear layers: self.q, self.k, and self.v. These layers perform linear projections of the initial input hidden_state to obtain the Q, K, and V vectors, respectively. The nn.Linear(embed_dim, head_dim) statements create these linear layers, where embed_dim represents the embedding dimension of the input and head_dim represents the number of dimensions the projections are reduced to.

The linear layers self.q, self.k, and self.v in the AttentionHead class use learnable weight matrices to perform the projections. These weight matrices are automatically learned during training. Additionally, the forward method in the AttentionHead class applies the scaled dot-product attention mechanism using the scaled_dot_product_attention function. It takes the linearly projected versions of Q, K, and V as inputs and produces the attention outputs (attn_outputs), which are then returned.
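One detail helps map the code to the math. `nn.Linear` stores its weight as an (out_features, in_features) matrix and computes `x @ weight.T + bias`. The projection matrix Wᵢ^Q of a head therefore corresponds to `self.q.weight.T`, plus a bias term that the paper's formulation leaves out. Here is a small check, with BERT-base-like sizes chosen only for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, head_dim = 768, 64            # illustrative, BERT-base-like sizes
q_proj = nn.Linear(embed_dim, head_dim)  # plays the role of W_i^Q (plus a bias)

hidden_state = torch.randn(1, 9, embed_dim)  # (batch, seq_len, embed_dim)

# nn.Linear stores its weight as (out_features, in_features)...
print(q_proj.weight.shape)  # torch.Size([64, 768])

# ...and computes x @ weight.T + bias, so W_i^Q corresponds to weight.T
manual = hidden_state @ q_proj.weight.T + q_proj.bias
print(torch.allclose(q_proj(hidden_state), manual, atol=1e-5))

# Learnable parameters of one projection: 768 * 64 weights + 64 biases = 49216
print(sum(p.numel() for p in q_proj.parameters()))
```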

The outputs of multiple attention heads can then be concatenated and passed through a final linear layer to define a multi-head attention layer:

import torch
from torch import nn

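class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Args:
        config: multi-head attention configuration with `hidden_size` and
            `num_attention_heads` attributes (e.g., a Hugging Face BertConfig).
        mask: unused here (masking is left for Part 2).
    """
    def __init__(self, config, mask=None):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        # Each head works in a smaller subspace; embed_dim must be divisible by num_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        # Final projection (W^O) that mixes the concatenated heads
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """
        Concatenate the outputs of all attention heads and project them back to embed_dim.
        """
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_linear(x)
        return x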
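For comparison, PyTorch ships a built-in `torch.nn.MultiheadAttention` module. It packs the same idea into a single layer: per-head projections, concatenation, and an output projection. Below is a usage sketch with sizes chosen to mirror BERT-base. Passing `average_attn_weights=False` (PyTorch 1.11+) returns one attention map per head instead of the head-averaged map.

```python
import torch
from torch import nn

torch.manual_seed(0)
embed_dim, num_heads = 768, 12  # BERT-base-like sizes: head_dim = 768 // 12 = 64
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

hidden_state = torch.randn(2, 9, embed_dim)  # (batch, seq_len, embed_dim)
# Self-attention: query, key and value all come from the same sequence
attn_output, attn_weights = mha(hidden_state, hidden_state, hidden_state,
                                average_attn_weights=False)

print(attn_output.shape)   # torch.Size([2, 9, 768])
print(attn_weights.shape)  # torch.Size([2, 12, 9, 9]): one 9x9 map per head
```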

The head_dim parameter represents the number of dimensions to which the input embeddings are projected within each attention head. Determining the appropriate value for head_dim depends on factors such as the size of the input embeddings and the desired level of granularity in capturing relationships and dependencies.

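In practice, head_dim is usually set to embed_dim // num_heads, which requires embed_dim to be divisible by the number of heads (bert-base-uncased, for example, uses 768 / 12 = 64). With this choice, the concatenated heads have exactly embed_dim dimensions, and the total cost is similar to that of a single head with full dimensionality. All heads can also be computed together with one batched matrix multiplication.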
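To illustrate why this choice of head_dim is convenient, here's a hypothetical vectorized variant (`FusedMultiHeadAttention` is not from the article). It uses one embed_dim × embed_dim projection for each of Q, K, and V. It then reshapes the result into num_heads slices of size head_dim and computes all heads with a single batched matrix multiplication. The `per_head_reference` helper, also hypothetical, recomputes the same thing head by head from slices of the same weights, as a check that both formulations agree.

```python
import torch
from torch import nn
from math import sqrt


class FusedMultiHeadAttention(nn.Module):
    """Hypothetical vectorized multi-head attention (not the article's code).

    Assumes embed_dim is divisible by num_heads, so head_dim = embed_dim // num_heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One embed_dim -> embed_dim projection per Q/K/V holds all heads side by side
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_state):
        q = self._split_heads(self.q(hidden_state))
        k = self._split_heads(self.k(hidden_state))
        v = self._split_heads(self.v(hidden_state))
        # Scores for all heads in one batched matmul: (batch, num_heads, seq_len, seq_len)
        weights = torch.softmax(q @ k.transpose(-2, -1) / sqrt(self.head_dim), dim=-1)
        heads = weights @ v  # (batch, num_heads, seq_len, head_dim)
        # Concatenate the heads back into (batch, seq_len, embed_dim)
        batch, _, seq_len, _ = heads.shape
        concat = heads.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.output_linear(concat)


def per_head_reference(m, x):
    """Same computation, one head at a time, reusing slices of m's weights."""
    outputs = []
    for i in range(m.num_heads):
        rows = slice(i * m.head_dim, (i + 1) * m.head_dim)  # rows belonging to head i
        q = x @ m.q.weight[rows].T + m.q.bias[rows]
        k = x @ m.k.weight[rows].T + m.k.bias[rows]
        v = x @ m.v.weight[rows].T + m.v.bias[rows]
        w = torch.softmax(q @ k.transpose(-2, -1) / sqrt(m.head_dim), dim=-1)
        outputs.append(w @ v)
    return m.output_linear(torch.cat(outputs, dim=-1))


torch.manual_seed(0)
mha = FusedMultiHeadAttention(embed_dim=768, num_heads=12)
x = torch.randn(2, 9, 768)  # (batch, seq_len, embed_dim)
with torch.no_grad():
    out = mha(x)
    reference = per_head_reference(mha, x)
print(out.shape)  # torch.Size([2, 9, 768])
print(torch.allclose(out, reference, atol=1e-5))
```

The loop-based `MultiHeadAttention` above is easier to read. The fused layout replaces num_heads small matrix multiplications with a few large ones, which usually maps better onto GPU hardware.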

We can once again use the BertViz library, this time its head view (the head_view function), to visualize the attention for one or more attention heads in the same layer.

from bertviz import head_view
from transformers import AutoModel, AutoTokenizer

model_ckpt = "bert-base-uncased"
model = AutoModel.from_pretrained(model_ckpt, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

sentence_a = "The quick brown fox jumps over the lazy dog"
sentence_b = "How quickly daft jumping zebras vex!"

viz_inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
attention = model(**viz_inputs).attentions
sentence_b_start = (viz_inputs.token_type_ids == 0).sum(dim=1)
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])

head_view(attention, tokens, sentence_b_start, heads=[8])

General Knowledge

All sentences used here are pangrams. Pangrams are sentences or phrases that contain every letter of the alphabet at least once. They are often used as a typing exercise or as a test to check the functionality of fonts or keyboards.
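A pangram check takes a couple of lines with Python sets. This sketch only considers the 26 ASCII letters and ignores case and punctuation:

```python
import string


def is_pangram(sentence):
    """True if the sentence uses every ASCII letter a-z at least once (case-insensitive)."""
    return set(string.ascii_lowercase) <= set(sentence.lower())


for s in ("The quick brown fox jumps over the lazy dog",
          "How quickly daft jumping zebras vex!"):
    print(is_pangram(s), s)

print(is_pangram("Attention is all you need"))
```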

This story is published on Generative AI.
