Florian June

Summary

Grouped-Query Attention (GQA) is an efficient attention mechanism that balances memory usage and inference quality, acting as an interpolation between Multi-Query Attention (MQA) and Multi-Head Attention (MHA).

Abstract

The Grouped-Query Attention (GQA) mechanism addresses the memory bottleneck in autoregressive decoding by reducing the size of the key-value cache while maintaining high inference quality. It serves as a middle ground between two approaches. Multi-Query Attention (MQA) saves memory but may compromise quality. Multi-Head Attention (MHA) offers superior quality at a higher memory cost. GQA shares each key and value head across a group of query heads, which reduces memory requirements without a significant drop in quality. This balance has made GQA increasingly popular among large language models, including Llama2 and Mistral 7B. In Llama2's implementation, the key and value projections produce fewer heads than the query projection. A repeat_kv function then duplicates the key/value heads so that their number matches the query heads during computation. The smaller key-value cache in turn lets a server run larger batches, which increases throughput.

Opinions

  • GQA is presented as a compromise between MQA and MHA. It gives up some of MQA's memory savings in exchange for quality closer to MHA.
  • Its adoption by prominent models such as Llama2 and Mistral 7B suggests it is effective and may become a standard attention design.
  • GQA's key advantage is a smaller KV cache. It does not significantly reduce the computational load.
  • The smaller KV cache also increases server throughput, because it allows more concurrent requests and larger batches.

Grouped-Query Attention (GQA) Explained

From Principles to Llama2 Code Explanation

The standard practice in autoregressive decoding is to cache the keys and values of previous tokens so they are not recomputed at every step. However, the memory consumed by this key-value cache (KV cache) in a multi-head attention (MHA) model grows in proportion to both the context length and the batch size. As either increases, the cache can become very large.
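The sketch below shows how quickly the cache grows. It assumes fp16/bf16 storage (2 bytes per element) and a hypothetical MHA model with 32 layers, 32 heads and head_dim 128, roughly the shape of Llama 2 7B. The helper name kv_cache_bytes is illustrative, not from any library.

```python
def kv_cache_bytes(batch_size, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # One key and one value vector (hence the factor 2) per token, per layer, per KV head.
    # bytes_per_elem=2 assumes fp16/bf16 cache entries.
    return 2 * batch_size * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem


GiB = 2**30
# Hypothetical MHA model: every attention head has its own K/V head (n_kv_heads == n_heads == 32).
for batch_size, seq_len in [(1, 4096), (8, 4096), (8, 16384)]:
    size = kv_cache_bytes(batch_size, seq_len, n_layers=32, n_kv_heads=32, head_dim=128)
    print(f"batch={batch_size:2d} seq_len={seq_len:5d}: {size / GiB:.0f} GiB")
```

This prints 2 GiB, 16 GiB and 64 GiB. The size scales linearly with batch size, sequence length and the number of key/value heads. The last factor is the one MQA and GQA reduce.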

Multi-query attention (MQA) shares a single key-value head across all query heads. This saves memory and greatly speeds up decoder inference.

However, MQA can degrade model quality. Ideally we want both fast inference and quality on par with MHA, which is where grouped-query attention (GQA)[1] comes in.

Grouped-query attention (GQA) interpolates between multi-query and multi-head attention. It achieves quality close to multi-head attention while keeping speed comparable to multi-query attention.

Because GQA is relatively new, many well-known large language models predate it. Since its proposal, however, it has been adopted by popular models such as Llama2[2] (in its 34B and 70B variants) and Mistral 7B[3].

GQA

GQA can be seen as an intermediate or generalized form of MQA and MHA:

  • With a single group, GQA is equivalent to MQA.
  • With as many groups as query heads, GQA is equivalent to MHA.
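The sketch below shows this relationship for 8 query heads. It assumes the query heads are split into contiguous, equally sized groups, which is the layout produced by the repeat_kv function later in this article. The helper name kv_head_for_query_head is hypothetical.

```python
def kv_head_for_query_head(q_head: int, n_heads: int, n_groups: int) -> int:
    """Index of the shared key/value head used by a given query head."""
    assert n_heads % n_groups == 0, "n_heads must be divisible by the number of groups"
    heads_per_group = n_heads // n_groups  # corresponds to n_rep in the Llama 2 code
    return q_head // heads_per_group


n_heads = 8
for n_groups, label in [(1, "MQA"), (4, "GQA"), (8, "MHA")]:
    mapping = [kv_head_for_query_head(h, n_heads, n_groups) for h in range(n_heads)]
    print(f"{label} (groups={n_groups}): {mapping}")
```

With one group, every query head maps to key/value head 0 (MQA). With four groups, query heads pair up as [0, 0, 1, 1, 2, 2, 3, 3]. With eight groups, each query head has its own key/value head (MHA).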

Figure 1 provides a clear visualization of this relationship.

Figure 1: Overview of grouped-query method. Multi-head attention has H query, key, and value heads. Multi-query attention shares single key and value heads across all query heads. Grouped-query attention instead shares single key and value heads for each group of query heads, interpolating between multi-head and multi-query attention. Source: [1]

From Figure 1, we can see that GQA is essentially a modified MQA aimed at better inference quality. It uses multiple key and value heads, but fewer than the number of query heads.
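To make Figure 1 concrete, here is a minimal single-device sketch in PyTorch. It is not Llama 2's code, and the class name SimpleGQA and its arguments are hypothetical. It uses plain nn.Linear layers instead of model-parallel ones. It omits the KV cache and rotary embeddings, and it expands the key/value heads with torch's repeat_interleave. Setting n_kv_heads=1 gives MQA, and n_kv_heads=n_heads gives MHA.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGQA(nn.Module):
    """Minimal grouped-query attention: no KV cache, no RoPE, no model parallelism."""

    def __init__(self, dim: int, n_heads: int, n_kv_heads: int):
        super().__init__()
        assert dim % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # query heads per key/value head
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        # Keys and values get only n_kv_heads heads: this is where the memory saving comes from.
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Duplicate each key/value head for the n_rep query heads in its group.
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        # Causal mask: position i may only attend to positions <= i.
        causal = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v  # (bsz, n_heads, seqlen, head_dim)
        return self.wo(out.transpose(1, 2).reshape(bsz, seqlen, -1))


attn = SimpleGQA(dim=64, n_heads=8, n_kv_heads=2)
y = attn(torch.randn(2, 5, 64))
print(y.shape)  # torch.Size([2, 5, 64])
```

Only wk and wv shrink (here to 2 heads × 8 dims = 16 outputs). The attention output still has all 8 query heads, which is why the computation itself costs about the same as MHA.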

How to incorporate GQA into LLM

Let’s look at how Llama 2 implements this. Below is its attention code:

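import math
from typing import Optional

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn

# ModelArgs, apply_rotary_emb and repeat_kv are defined elsewhere in Llama 2's model.py.


class Attention(nn.Module):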
    """Multi-head attention module."""
    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.

        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # ColumnParallelLinear and RowParallelLinear are two common strategies for implementing model parallelism.
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        # Under GQA, wk and wv project to n_kv_heads * head_dim, smaller than wq's n_heads * head_dim.
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # kv cache, used for caching keys and values
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        # xk and xv have n_kv_heads * head_dim features, fewer than xq when n_kv_heads < n_heads.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        
        # Incorporate rotary position embedding
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Write the current tokens' keys/values into the cache
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Read back all cached keys/values up to the current position (past and current tokens)
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        # make the number of heads in kv and q the same
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        # Self-attention
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

I have commented on the key steps of this code. A few points to note, assuming no model parallelism (a model-parallel world size of 1, so the "local" head counts equal the global ones):

  • self.n_local_heads is the number of query heads, which equals the number of heads in the original multi-head attention.
  • self.n_local_kv_heads is the number of key and value heads in GQA. Compared with MHA, the key and value caches therefore shrink by a factor of self.n_rep = self.n_local_heads // self.n_local_kv_heads.
  • GQA shrinks the KV cache, but the batched matrix multiplications (GEMM) in attention expect keys and values to have the same number of heads as the queries. They must therefore be expanded back before computation. The repeat_kv function does this by duplicating each key/value head n_rep times so that the head counts match.
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x       # MHA
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )                  # GQA or MQA

Let’s take a look at an example to see what repeat_kv does:

>>> import torch
>>> x = torch.rand(1, 1, 4, 6)
>>> x
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> print(x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])
>>>
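The docstring of repeat_kv states that it is equivalent to torch.repeat_interleave along the head dimension. The sketch below checks that claim on the same input shape. It also shows the two edge cases: n_rep == 1 (MHA) returns the input unchanged, and a single KV head (MQA) is copied to every query head. The repeat_kv body is copied from the Llama 2 snippet above so that the block runs on its own.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the Llama 2 helper:
    # (bs, slen, n_kv_heads, head_dim) -> (bs, slen, n_kv_heads * n_rep, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


x = torch.rand(1, 1, 4, 6)
out = repeat_kv(x, 2)
print(out.shape)  # torch.Size([1, 1, 8, 6])
# Output head i is a copy of KV head i // n_rep.
print(torch.equal(out, torch.repeat_interleave(x, repeats=2, dim=2)))  # True
# MQA: one KV head broadcast to all 8 query heads.
mqa = repeat_kv(torch.rand(1, 1, 1, 6), 8)
print(mqa.shape, torch.equal(mqa[:, :, 0], mqa[:, :, 7]))  # torch.Size([1, 1, 8, 6]) True
```

The expand step by itself is a view. The final reshape, however, has to copy, because the repeated axis has stride 0 and cannot be merged with the head axis as a view. So the cache stays small, but the expanded keys/values are materialized temporarily during computation.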

To further illustrate how GQA works, I have drawn two diagrams. Figure 2 shows a configuration with 4 key/value attention heads (self.n_local_kv_heads = 4) and 8 query attention heads (self.n_local_heads = 8):

Figure 2: Attention heads in GQA. Image by author.

As shown in Figure 3, after repeat_kv the number of key and value heads matches the number of query heads, so the matrix multiplication (GEMM) can be performed. Each color represents a group, whose key/value head has been duplicated for every query head in that group.

Figure 3: Each color represents a group, after repeat_kv, the number of heads in each group has been expanded. Image by author.
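Figure 3's grouping also suggests a way to avoid duplicating the keys at all. This is an alternative sketch, not what the Llama 2 code above does, and the function names are hypothetical. The idea is to stack the query heads of each group along the sequence axis, so that each shared key head multiplies all of its group's queries in a single GEMM. The example uses the Figure 2/3 setting (8 query heads, 4 KV heads) and compares only the attention scores against the repeat-based version.

```python
import math

import torch


def gqa_scores_repeat(q, k):
    # q: (bs, n_heads, seqlen, head_dim); k: (bs, n_kv_heads, kvlen, head_dim)
    n_rep = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(n_rep, dim=1)  # copy each KV head n_rep times (like repeat_kv)
    return q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])


def gqa_scores_grouped(q, k):
    bs, n_heads, seqlen, head_dim = q.shape
    n_kv_heads, kvlen = k.shape[1], k.shape[2]
    n_rep = n_heads // n_kv_heads
    # Stack each group's n_rep query heads along the sequence axis:
    # (bs, n_kv_heads, n_rep * seqlen, head_dim). This is a view when q is contiguous.
    qg = q.reshape(bs, n_kv_heads, n_rep * seqlen, head_dim)
    # One GEMM per KV head, shared by every query head in its group; K is never duplicated.
    s = qg @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (bs, n_kv_heads, n_rep * seqlen, kvlen)
    return s.reshape(bs, n_heads, seqlen, kvlen)


torch.manual_seed(0)
bs, n_heads, n_kv_heads, seqlen, head_dim = 2, 8, 4, 3, 16
q = torch.randn(bs, n_heads, seqlen, head_dim)
k = torch.randn(bs, n_kv_heads, seqlen, head_dim)
print(torch.allclose(gqa_scores_repeat(q, k), gqa_scores_grouped(q, k), atol=1e-5))  # True
```

The same reshaping applies to multiplying the softmax probabilities by V. Query head h ends up paired with KV head h // n_rep in both versions, which matches the color grouping in Figure 3.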

Conclusion

Neither GQA nor MQA significantly reduces the amount of computation. The attention FLOPs are essentially unchanged, and only the key/value projections shrink. Their main purpose is to reduce the memory needed for the KV cache. A smaller cache means less data to read at each decoding step. It also frees memory, allowing the LLM server to handle more requests with larger batch sizes and higher throughput.
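The throughput argument can be illustrated with a rough back-of-the-envelope calculation: with a fixed amount of GPU memory set aside for the KV cache, how many 4096-token sequences fit? Several inputs are assumptions. The 40 GiB budget is hypothetical. The model shape roughly follows Llama 2 70B (80 layers, 64 query heads, head_dim 128, fp16 cache). The calculation ignores activations, fragmentation and other serving overheads. The helper name max_batch_size is illustrative.

```python
def max_batch_size(kv_budget_bytes, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # KV cache per sequence: keys + values for every token, layer and KV head.
    per_seq = 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem
    return kv_budget_bytes // per_seq


GiB = 2**30
budget = 40 * GiB  # hypothetical memory left for the KV cache after loading the weights
for name, n_kv_heads in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    print(name, max_batch_size(budget, seq_len=4096, n_layers=80, n_kv_heads=n_kv_heads, head_dim=128))
```

Under these assumptions, MHA fits 4 sequences, GQA with 8 KV heads fits 32, and MQA fits 256. The batch size that fits grows by the same factor, n_heads // n_kv_heads, by which the cache shrinks.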

Finally, if there are any errors or omissions in this text, please feel free to point them out.

References

[1]: J. Ainslie, J. Lee-Thorp, M. Jong, Y. Zemlyanskiy, F. Lebrón, S. Sanghai. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (2023). arXiv preprint arXiv:2305.13245.

[2]: H. Touvron, L. Martin, K. Stone, P. Albert, A. Almahairi, Y. Babaei, N. Bashlykov, S. Batra, P. Bhargava, S. Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models (2023). arXiv preprint arXiv:2307.09288.

[3]: A. Jiang, A. Sablayrolles, A. Mensch, C. Bamford, D. Chaplot, et al. Mistral 7B (2023). arXiv preprint arXiv:2310.06825.
