Hao Zhuang, an engineer, Tesla AI, Ex-Googler, PhD

Summary

The blog post explores the implementation and benefits of the KV (key-value) cache mechanism in OpenAI's Whisper model, particularly focusing on its role in enhancing inference efficiency.

Abstract

The article delves into OpenAI's Whisper model, with a particular emphasis on the KV cache in the MultiHeadAttention module. The KV cache stores key and value tensors from previous positions, which reduces redundant computation during autoregressive decoding. The post describes the install_kv_cache_hooks method, which initializes the cache and sets up hooks that capture and reuse intermediate tensors, and details how PyTorch's register_forward_hook function manages the cache during inference, noting that the same hook mechanism is also used in activation checkpointing. The author acknowledges the efficiency gains achieved through the KV cache, noting its importance for the model's inference performance, and provides references for further reading on PyTorch hooks.

Opinions

  • The author suggests that the blog post might be of limited interest, potentially serving only as a personal record.
  • The use of KV cache is recognized as an efficient method to enhance model inference performance.
  • The author expresses admiration for the implementation of the KV cache in the Whisper model, particularly the use of PyTorch's forward hook functionality.
  • There is an acknowledgment of the broader community's familiarity with the benefits of KV caching for improving computational efficiency.
  • The author's tone implies a level of enthusiasm and appreciation for the technical sophistication of the Whisper model's inference code.

KV cache at Whisper inference — Exploring OpenAI’s Whisper Model

This blog post might be boring and short, probably only for my own record. It is about OpenAI's Whisper GitHub repository; the research page is linked below.

https://openai.com/research/whisper

While diving into OpenAI's Whisper GitHub repository, I came across some intriguing implementations, one of which is the KV (key-value) cache in the MultiHeadAttention module.

The Role of the KV Cache: it stores the key and value tensors computed for previous positions so they do not have to be recomputed at every decoding step. This caching mechanism improves efficiency, especially in the repetitive computations of autoregressive decoding.
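
To make the benefit concrete, here is a minimal, hypothetical sketch (my own toy code, not Whisper's): it contrasts re-projecting the whole sequence at every decoding step with projecting only the newest position and appending it to a cache.

import torch

torch.manual_seed(0)
d_model = 8
W_k = torch.nn.Linear(d_model, d_model)  # stand-in for the key projection (attn.key)
W_v = torch.nn.Linear(d_model, d_model)  # stand-in for the value projection (attn.value)

# Pretend we decode 4 tokens autoregressively; x_t is the embedding at step t.
steps = [torch.randn(1, 1, d_model) for _ in range(4)]

# Without a cache: every step would re-project all previous positions.
full = torch.cat(steps, dim=1)               # shape (1, 4, d_model)
k_naive, v_naive = W_k(full), W_v(full)      # keys/values for positions 0..3, recomputed

# With a KV cache: only the newest position is projected, then appended.
k_cache, v_cache = None, None
for x_t in steps:
    k_t, v_t = W_k(x_t), W_v(x_t)            # per-step work is O(1) in sequence length
    k_cache = k_t if k_cache is None else torch.cat([k_cache, k_t], dim=1)
    v_cache = v_t if v_cache is None else torch.cat([v_cache, v_t], dim=1)

print(torch.allclose(k_naive, k_cache, atol=1e-6))  # True: same keys, far less recomputation
print(torch.allclose(v_naive, v_cache, atol=1e-6))  # True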

How Does It Work?

In model code:

def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

  1. Setting Up KV Cache Hooks: The Whisper model optionally accepts a KV cache. The install_kv_cache_hooks method initializes a dictionary that holds all caches and registers the necessary hooks on the key and value projection modules. These hooks save intermediate tensors so they can be reused in later calculations.
  2. Saving to Cache: Based on whether the module is already in the cache and on the output shape, the model either stores the output as-is (first token or cross attention) or concatenates it with the existing cache, detaching it for future use (see the sketch after this list).
  3. Applying Hooks: The hooks are registered on the key and value projection modules of every MultiHeadAttention layer in the decoder, capturing their outputs and storing them in the cache.
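
As a simplified sketch of the same idea (hypothetical stand-in modules, not the Whisper class, and omitting the cross-attention special case): a forward hook on a single "key projection" linear layer can reproduce the save_to_cache behavior, appending each new output along the time axis and letting the hook's return value replace the module's output.

import torch
import torch.nn as nn

key_proj = nn.Linear(4, 4)                      # hypothetical stand-in for attn.key
cache = {}                                      # maps module -> cached tensor

def save_to_cache(module, _inputs, output):
    # First call: store the output as-is; later calls: append along the time axis.
    if module not in cache:
        cache[module] = output
    else:
        cache[module] = torch.cat([cache[module], output], dim=1).detach()
    return cache[module]                        # a non-None return replaces the module's output

handle = key_proj.register_forward_hook(save_to_cache)

key_proj(torch.randn(1, 1, 4))                  # "first token": cache holds 1 position
key_proj(torch.randn(1, 1, 4))                  # "next token": cache now holds 2 positions
print(cache[key_proj].shape)                    # torch.Size([1, 2, 4])

handle.remove()                                 # RemovableHandle stops the hook from firing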

The Forward Hook Mechanism: the register_forward_hook function, a handy piece of PyTorch, is used here. It hooks into a module and fires after the forward method has computed its output, which makes it the key element for managing the KV cache during inference. Reference: https://github.com/pytorch/pytorch/blob/fe01605830145b5aa204120b90361021a2952ac1/torch/nn/modules/module.py#L1421
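
As a quick illustration of the hook API itself (a generic sketch, unrelated to Whisper): a forward hook receives (module, input, output) after forward has run, a non-None return value replaces the output, and register_forward_hook returns a RemovableHandle that can detach the hook later.

import torch
import torch.nn as nn

layer = nn.Linear(3, 3)

def double_output(module, inputs, output):
    # Called after layer.forward(); returning a value replaces the original output.
    print(f"hook fired on {module.__class__.__name__}, output shape {tuple(output.shape)}")
    return output * 2

handle = layer.register_forward_hook(double_output)
y = layer(torch.randn(1, 3))    # prints once; y is the doubled output
handle.remove()                 # the hook is no longer called on later forward passes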


Note: this function is also used in the activation checkpointing technique.

In inference code:

class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

This code highlights the initialization step, which collects the key and value projection modules from the decoder blocks, and shows how the logits function uses the KV cache: the hooks are installed on the first call, and every later call passes only the last token to the decoder.
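
Putting it together, a hypothetical greedy-decoding loop might look roughly like the sketch below (names such as model, tokens, audio_features, and max_new_tokens are placeholders I am assuming exist; the real decoding loop in whisper/decoding.py is more involved). The first logits call installs the hooks and fills the cache with every prompt position; each later call feeds only the newest token, and the hooks keep appending to the cache.

import torch

# Assumed to already exist: a loaded Whisper `model`, encoder output `audio_features`,
# and an initial prompt `tokens` of shape (1, initial_token_length).
inference = PyTorchInference(model, initial_token_length=tokens.shape[-1])

for _ in range(max_new_tokens):                 # max_new_tokens: a hypothetical budget
    logits = inference.logits(tokens, audio_features)         # first call installs the hooks
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, sketch only
    tokens = torch.cat([tokens, next_token], dim=-1)

# Afterwards the hooks should be removed (via the stored RemovableHandles) and the
# cache cleared before starting a new decoding pass.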

Again, as we all know, the KV cache is an efficient way to speed up model inference, and the use of PyTorch's forward hook functionality stands out as a particularly neat part of this implementation. I just wanted to see how it is used in that popular codebase :) Cheers! My dumb blog.

References:

OpenAI
PyTorch
Whisper