KV cache in Whisper inference — Exploring OpenAI’s Whisper Model
This blog might be boring and short; it is mostly for my own record.
While diving into OpenAI’s whisper GitHub repository, I came across some intriguing implementations, one of them being the KV (key-value) cache in the MultiHeadAttention module.
The Role of KV Cache: it stores the key and value tensors computed for previous positions, so that each new decoding step only has to project the newest token instead of re-running the projections for the whole sequence.
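As a rough, hypothetical illustration (not Whisper’s code, just my own stand-in names), the difference looks like this: without a cache every step re-projects all previous tokens, with a cache only the newest token is projected and appended:

import torch
import torch.nn as nn

key_proj = nn.Linear(8, 8, bias=False)  # stand-in for an attention key projection

# without a cache: every decoding step projects the whole sequence again
def keys_without_cache(all_tokens):          # all_tokens: (batch, seq_len, d_model)
    return key_proj(all_tokens)

# with a cache: project only the newest token and append it along the sequence dim
def keys_with_cache(new_token, cached):      # new_token: (batch, 1, d_model)
    k_new = key_proj(new_token)
    return k_new if cached is None else torch.cat([cached, k_new], dim=1)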
How Does It Work?
In model code:
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
    """
    The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
    tensors calculated for the previous positions. This method returns a dictionary that stores
    all caches, and the necessary hooks for the key and value projection modules that save the
    intermediate tensors to be reused during later calculations.

    Returns
    -------
    cache : Dict[nn.Module, torch.Tensor]
        A dictionary object mapping the key/value projection modules to its cache
    hooks : List[RemovableHandle]
        List of PyTorch RemovableHandle objects to stop the hooks to be called
    """
    cache = {**cache} if cache is not None else {}
    hooks = []

    def save_to_cache(module, _, output):
        if module not in cache or output.shape[1] > self.dims.n_text_ctx:
            # save as-is, for the first token or cross attention
            cache[module] = output
        else:
            cache[module] = torch.cat([cache[module], output], dim=1).detach()
        return cache[module]

    def install_hooks(layer: nn.Module):
        if isinstance(layer, MultiHeadAttention):
            hooks.append(layer.key.register_forward_hook(save_to_cache))
            hooks.append(layer.value.register_forward_hook(save_to_cache))

    self.decoder.apply(install_hooks)
    return cache, hooks
- Setting Up KV Cache Hooks: The Whisper model allows optional use of a KV cache. The install_kv_cache_hooks method initializes a dictionary that stores all caches and registers the necessary hooks on the key and value projection modules. These hooks save the intermediate tensors so they can be reused in later calculations.
- Saving to Cache: save_to_cache decides whether to store the new output based on whether the module is already in the cache and on the output's sequence length. It either stores the output as-is (first token, or cross-attention over the audio features) or concatenates it with the existing cache along the sequence dimension, detaching it for future use.
- Applying Hooks: Hooks are applied to the MultiHeadAttention layers within the decoder. They capture the outputs of the key and value projection modules and store them in the cache; a sketch of how the attention forward pass consumes that cache follows below.
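To see where those cached tensors come back into play, here is a toy, paraphrased sketch of how an attention module like Whisper’s can consume kv_cache. The class ToyAttention and its details are my own simplification, not the repository code:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.query = nn.Linear(d, d)
        self.key = nn.Linear(d, d, bias=False)
        self.value = nn.Linear(d, d)

    def forward(self, x, xa=None, kv_cache=None):
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # self-attention, or the very first cross-attention call:
            # run the projections (hooks, if installed, extend the cache here)
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # warm cross-attention cache: the audio keys/values never change,
            # so reuse the tensors the hooks stored on the first call
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        w = (q @ k.transpose(-1, -2) / k.shape[-1] ** 0.5).softmax(dim=-1)
        return w @ v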
The Forward Hook Mechanism: the register_forward_hook function, a standard piece of PyTorch, is what makes this work. It registers a hook that fires after the module's forward method has computed its output, and if the hook returns a non-None value, that value replaces the module's output. This is why self.key(x) can transparently return the full concatenated keys during decoding, even though x only contains the newest token. Reference: https://github.com/pytorch/pytorch/blob/fe01605830145b5aa204120b90361021a2952ac1/torch/nn/modules/module.py#L1421
Note: this function is also used in the activation checkpointing technique.
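To convince myself of the hook behaviour, here is a minimal, standalone toy (nothing here comes from the Whisper repo): a hook that grows a cache along the sequence dimension and whose return value replaces the layer's output:

import torch
import torch.nn as nn

proj = nn.Linear(4, 4, bias=False)
cache = {}

def save_to_cache(module, _inputs, output):
    # grow the cache along the sequence dimension; returning a non-None value
    # makes the hook's result replace the module's original output
    if module not in cache:
        cache[module] = output
    else:
        cache[module] = torch.cat([cache[module], output], dim=1)
    return cache[module]

handle = proj.register_forward_hook(save_to_cache)

out1 = proj(torch.randn(1, 1, 4))   # first token  -> shape (1, 1, 4)
out2 = proj(torch.randn(1, 1, 4))   # second token -> shape (1, 2, 4), includes the cached row
print(out1.shape, out2.shape)

handle.remove()  # RemovableHandle: stop the hook from firing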
In inference code:
class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
This code shows the initialization, which collects the key and value projection modules of every decoder block, and how the logits function installs the KV cache hooks on the first call and afterwards feeds only the last token, letting the cache stand in for all earlier positions.
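For context, this is roughly how a caller could drive it. A hedged sketch of a greedy loop, assuming model, audio_features, initial_tokens, max_new_tokens, and eot are already defined; none of these names are taken from the actual decoding loop:

# hedged sketch: greedy decoding driven through PyTorchInference
inference = PyTorchInference(model, initial_token_length=len(initial_tokens))
tokens = torch.tensor([initial_tokens])                 # full prompt on the first pass

for _ in range(max_new_tokens):
    logits = inference.logits(tokens, audio_features)   # later passes slice to the last token internally
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=-1)
    if next_token.item() == eot:
        break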
Again, as we all know, KV caching is a standard way to speed up model inference; what stands out here is the use of PyTorch's forward hook functionality to implement it. I just wanted to see how it is done in that popular codebase :) Cheers! my dumb blog.