Chapter 18. The KV Cache, Why Generation Is Fast
Every time you chat with ChatGPT, Claude, or Gemini, the model generates its response one token at a time, as you learned in Chapter 17. At each step, the model computes attention over every previous token in the sequence. Without any optimization, this means the model would redo the same enormous matrix multiplications for every token it has already seen, over and over again, for every single new token it generates. The KV cache is the optimization that makes this process fast: it stores the intermediate results of attention computation so they never need to be recalculated. Understanding the KV cache is essential for understanding why language models use so much memory, why longer contexts are more expensive, and why first-token latency differs from per-token latency.
The Problem: Redundant Computation in Attention
To understand why the KV cache exists, let us revisit how attention works during text generation. In Chapter 7, you learned that attention computes three things for each token: a query vector (Q), a key vector (K), and a value vector (V). The query asks “what am I looking for?”, the key says “here is what I contain,” and the value says “here is my actual information.” Attention scores are computed by multiplying queries against keys, then using those scores to create a weighted sum of values.
During training, the model sees the entire sequence at once and computes Q, K, and V for all tokens in parallel. But during generation (inference), the model produces tokens one at a time. Here is what happens at each generation step without any caching:
1. The model has a sequence of tokens: the original prompt plus all tokens generated so far.
2. For every token in the sequence, the model computes K and V vectors by multiplying the token’s hidden state by the weight matrices W_K and W_V.
3. For the new token (the one being generated), the model computes a Q vector.
4. The Q vector is multiplied against all K vectors to get attention scores.
5. The scores are used to compute a weighted sum of all V vectors.
6. This produces the output for the new token.
The critical inefficiency is in step 2. Every time the model generates a new token, it recomputes the K and V vectors for every previous token, even though those tokens have not changed. The hidden states of previous tokens are the same as they were in the last step. The weight matrices W_K and W_V are the same. So the K and V vectors for previous tokens are identical to what they were before.
This is pure waste. If the model has generated 500 tokens so far and is producing token 501, it recomputes 500 K vectors and 500 V vectors that are exactly the same as they were when it generated token 500. And it did the same redundant work when generating token 500, and token 499, and every token before that.
The total computation without caching grows quadratically with sequence length. Generating token 1 requires computing K and V for 1 token. Generating token 2 requires computing K and V for 2 tokens. Generating token n requires computing K and V for n tokens. The total work across all steps is 1 + 2 + 3 + … + n = n(n+1)/2, which is O(n^2).
```python
import numpy as np

def attention_without_cache(hidden_states, W_Q, W_K, W_V, new_token_idx):
    """
    Attention WITHOUT KV cache: recompute K and V for ALL tokens every step.
    This is what we want to avoid.
    hidden_states: shape (seq_len, hidden_dim) - all tokens so far
    new_token_idx: index of the new token being generated
    """
    # Recompute K and V for EVERY token (wasteful!)
    K = hidden_states @ W_K  # shape: (seq_len, head_dim)
    V = hidden_states @ W_V  # shape: (seq_len, head_dim)
    # Compute Q only for the new token
    q = hidden_states[new_token_idx] @ W_Q  # shape: (head_dim,)
    # Attention scores: q dot each key
    scores = q @ K.T / np.sqrt(K.shape[1])  # shape: (seq_len,)
    weights = np.exp(scores - np.max(scores))
    weights = weights / weights.sum()
    # Weighted sum of values
    output = weights @ V  # shape: (head_dim,)
    return output

# For a 1000-token sequence, this recomputes 1000 K and 1000 V vectors
# at EVERY generation step, even though only 1 new token was added.
```

The Solution: Cache K and V Vectors
The KV cache is conceptually simple: after computing the K and V vectors for a token, store them in memory. On the next generation step, only compute K and V for the new token, then append them to the cached K and V vectors from previous steps.
```python
import numpy as np

def attention_with_kv_cache(new_hidden_state, W_Q, W_K, W_V, k_cache, v_cache):
    """
    Attention WITH KV cache: only compute K and V for the NEW token.
    Reuse cached K and V for all previous tokens.
    new_hidden_state: shape (hidden_dim,) - just the new token
    k_cache: shape (prev_seq_len, head_dim) - cached keys
    v_cache: shape (prev_seq_len, head_dim) - cached values
    Returns: output vector, updated k_cache, updated v_cache
    """
    # Compute K, V only for the new token
    new_k = new_hidden_state @ W_K  # shape: (head_dim,)
    new_v = new_hidden_state @ W_V  # shape: (head_dim,)
    # Append to cache
    k_cache = np.vstack([k_cache, new_k.reshape(1, -1)])
    v_cache = np.vstack([v_cache, new_v.reshape(1, -1)])
    # Compute Q for the new token
    q = new_hidden_state @ W_Q  # shape: (head_dim,)
    # Attention: q against ALL keys (cached + new)
    scores = q @ k_cache.T / np.sqrt(k_cache.shape[1])
    weights = np.exp(scores - np.max(scores))
    weights = weights / weights.sum()
    # Weighted sum of ALL values (cached + new)
    output = weights @ v_cache
    return output, k_cache, v_cache
```

With the KV cache, each generation step only computes one new K vector and one new V vector (for the new token), instead of recomputing K and V for the entire sequence. The attention computation still requires multiplying the query against all cached keys (which grows linearly with sequence length), but the expensive matrix multiplications to produce K and V vectors are done only once per token.
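As a sanity check that caching changes only the amount of work and not the result, the sketch below runs a toy single-head example both ways at every step and records the largest difference between the two outputs. The random weights and hidden states, the sizes, the seed, and the `attend` helper are all assumptions introduced for illustration.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def attend(q, K, V):
    """Scaled dot-product attention for a single query vector."""
    return softmax(q @ K.T / np.sqrt(K.shape[1])) @ V

rng = np.random.default_rng(0)
hidden_dim, head_dim, seq_len = 16, 8, 10  # toy sizes
W_Q, W_K, W_V = (rng.standard_normal((hidden_dim, head_dim)) for _ in range(3))
hidden_states = rng.standard_normal((seq_len, hidden_dim))  # toy per-token states

k_cache = np.empty((0, head_dim))
v_cache = np.empty((0, head_dim))
max_diff = 0.0
for t in range(seq_len):
    h = hidden_states[t]
    q = h @ W_Q
    # Uncached: recompute K and V for the whole prefix 0..t
    prefix = hidden_states[: t + 1]
    out_full = attend(q, prefix @ W_K, prefix @ W_V)
    # Cached: compute K and V for token t only, then append to the cache
    k_cache = np.vstack([k_cache, (h @ W_K)[None, :]])
    v_cache = np.vstack([v_cache, (h @ W_V)[None, :]])
    out_cached = attend(q, k_cache, v_cache)
    max_diff = max(max_diff, float(np.max(np.abs(out_full - out_cached))))

print(f"max |uncached - cached| = {max_diff:.2e}")
```

The difference is at the level of floating-point rounding: the cached keys and values are the same vectors, just computed once instead of at every step.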
The savings are dramatic. Without the KV cache, generating a 1,000-token response from a 1,000-token prompt means that the step producing the i-th response token recomputes K and V for 999 + i tokens (the prompt plus the i − 1 response tokens generated so far). Summed over all 1,000 steps, that is 1,000 + 1,001 + … + 1,999 = 1,499,500 token-level K/V computations. With the KV cache, the prompt’s 1,000 K/V vectors are computed once during the initial forward pass, and each later step computes K/V only for the single token it feeds in, for roughly 2,000 K/V computations in total. That is roughly a 750x reduction.
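The counting argument can be written out directly. This sketch assumes one layer and one head, treats the prefill pass as the step that produces the first response token, and, as in a typical decode loop, feeds each generated token back in except the last one.

```python
def kv_computations(prompt_len, new_tokens):
    """Count token-level K/V computations for one layer and head (illustrative)."""
    # Without a cache, the step producing response token i re-processes
    # the prompt plus the i - 1 response tokens generated so far.
    without_cache = sum(prompt_len + i - 1 for i in range(1, new_tokens + 1))
    # With a cache: prefill computes K/V once per prompt token, then each
    # later step computes K/V only for the previously generated token.
    with_cache = prompt_len + (new_tokens - 1)
    return without_cache, with_cache

no_cache, cached = kv_computations(prompt_len=1_000, new_tokens=1_000)
print(f"without cache: {no_cache:,}")               # 1,499,500
print(f"with cache:    {cached:,}")                 # 1,999
print(f"reduction:     {no_cache / cached:.0f}x")   # 750x
```

Both counts scale by the same factor (layers times KV heads) in a full model, so the ratio carries over unchanged.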
The Two Phases of Inference: Prefill and Decode
With the KV cache in place, language model inference naturally splits into two distinct phases with very different computational characteristics. Understanding these two phases is essential for understanding why “time to first token” and “tokens per second” are different metrics, and why they behave so differently.
Phase 1: Prefill (Processing the Prompt)
When you send a prompt to a language model, the model first processes the entire prompt in one forward pass. This is the prefill phase (also called the “prompt processing” phase). During prefill:
- All tokens in the prompt are processed in parallel through every transformer layer.
- K and V vectors are computed for every prompt token at every layer.
- These K and V vectors are stored in the KV cache.
- The model produces logits for the next token (the first token of the response).
The prefill phase is typically compute-bound: the GPU’s arithmetic units (tensor cores) are the bottleneck because the model is doing large matrix multiplications across all prompt tokens simultaneously, so the hardware’s compute capacity is used efficiently.
Phase 2: Decode (Generating the Response)
After the prefill phase produces the first response token, the model enters the decode phase. During decode:
1. Only the single new token is processed through the transformer layers.
2. K and V vectors are computed only for this one new token.
3. The new K and V vectors are appended to the KV cache.
4. The new token’s query attends to all cached keys and values.
5. The model produces logits for the next token.
6. Steps 1 through 5 repeat for each subsequent token.
The decode phase is memory-bandwidth-bound: the bottleneck is not computation but the speed at which the GPU can read the KV cache and model weights from memory (HBM, or High Bandwidth Memory). Each decode step does relatively little arithmetic (processing just one token), but it must read the entire KV cache from memory to compute attention. As the sequence grows longer, the KV cache grows larger, and reading it takes more time.
This fundamental difference between the two phases explains a phenomenon you may have noticed when using language models: there is often a noticeable pause before the first token appears (the prefill phase processing your prompt), and then tokens stream out at a steady rate (the decode phase generating one token at a time).
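One rough way to see why prefill is compute-bound and decode is bandwidth-bound is to compare arithmetic intensity (FLOPs performed per byte of weights read) against the GPU’s ratio of peak compute to memory bandwidth, often called the ridge point. The sketch below assumes about 2 FLOPs per parameter per token for a dense model, ignores attention and KV cache traffic, and uses hypothetical hardware figures (1,000 TFLOP/s and 3 TB/s) rather than any specific product’s spec sheet.

```python
def weight_arithmetic_intensity(tokens_per_pass, bytes_per_param=2):
    """FLOPs performed per byte of weights read in one forward pass.

    Assumes ~2 FLOPs (one multiply, one add) per parameter per token and
    that every weight is read from memory once per pass.
    """
    flops_per_param = 2 * tokens_per_pass
    return flops_per_param / bytes_per_param

# Hypothetical GPU figures (assumptions for illustration)
peak_flops = 1_000e12   # FLOP/s
mem_bandwidth = 3e12    # bytes/s
ridge_point = peak_flops / mem_bandwidth  # FLOPs/byte needed to saturate compute

for phase, tokens in [("decode (1 token)", 1), ("prefill (2,000 tokens)", 2_000)]:
    ai = weight_arithmetic_intensity(tokens)
    bound = "compute-bound" if ai > ridge_point else "memory-bandwidth-bound"
    print(f"{phase:<24} {ai:>8.0f} FLOPs/byte -> {bound} (ridge ~{ridge_point:.0f})")
```

Because the same weights serve every request in a batch, batching B decode requests together raises the intensity to roughly B FLOPs per byte, which is one reason serving systems batch decode steps aggressively.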
```python
import numpy as np

def sample(logits):
    """Greedy stand-in for a real sampler: pick the highest-scoring token."""
    return int(np.argmax(logits))

def inference_with_kv_cache(model, prompt_tokens, max_new_tokens=100):
    """
    Complete inference showing both phases.
    `model` is a hypothetical interface exposing embed(), num_layers,
    layers[i].W_K / W_V / forward() / forward_single(), output_head(),
    and stop_token_ids. A single attention head per layer is shown.
    """
    # ============================================
    # PHASE 1: PREFILL
    # Process the entire prompt in one parallel pass.
    # This populates the KV cache for all prompt tokens.
    # ============================================
    kv_cache = {}  # {layer_idx: {"k": array, "v": array}}
    # Forward pass on all prompt tokens at once (parallel)
    hidden_states = model.embed(prompt_tokens)  # (prompt_len, hidden_dim)
    for layer_idx in range(model.num_layers):
        layer = model.layers[layer_idx]
        # Compute K, V for ALL prompt tokens (parallel, efficient)
        K = hidden_states @ layer.W_K  # (prompt_len, head_dim)
        V = hidden_states @ layer.W_V  # (prompt_len, head_dim)
        # Store in KV cache
        kv_cache[layer_idx] = {"k": K, "v": V}
        # Compute attention and FFN (details omitted for clarity)
        hidden_states = layer.forward(hidden_states, K, V)
    # Get logits for the first response token
    logits = model.output_head(hidden_states[-1])
    first_token = sample(logits)
    if first_token in model.stop_token_ids:
        return []
    generated = [first_token]
    # ============================================
    # PHASE 2: DECODE
    # Generate tokens one at a time.
    # Only compute K, V for the new token; reuse cache for all others.
    # ============================================
    for _ in range(max_new_tokens - 1):
        # Process only the single new token
        hidden = model.embed([generated[-1]])  # (1, hidden_dim)
        for layer_idx in range(model.num_layers):
            layer = model.layers[layer_idx]
            cache = kv_cache[layer_idx]
            # Compute K, V for ONLY the new token
            new_k = hidden @ layer.W_K  # (1, head_dim)
            new_v = hidden @ layer.W_V  # (1, head_dim)
            # Append to cache
            cache["k"] = np.vstack([cache["k"], new_k])
            cache["v"] = np.vstack([cache["v"], new_v])
            # Attention: new token's Q against ALL cached K, V
            hidden = layer.forward_single(hidden, cache["k"], cache["v"])
        logits = model.output_head(hidden[0])
        next_token = sample(logits)
        if next_token in model.stop_token_ids:
            break
        generated.append(next_token)
    return generated
```

TTFT vs. TPS: Two Different Latency Metrics
The two-phase structure of inference gives rise to two distinct latency metrics that are commonly reported for language model APIs:
Time to First Token (TTFT) measures how long it takes from when you send your prompt until the first response token appears. This is dominated by the prefill phase: the model must process your entire prompt before it can generate the first token. A longer prompt means a longer TTFT. For a short prompt (a few hundred tokens), TTFT might be 100 to 300 milliseconds. For a very long prompt (100,000+ tokens), TTFT can be several seconds or even tens of seconds.
Tokens Per Second (TPS), sometimes loosely called output speed or throughput, measures how fast tokens are generated during the decode phase; its reciprocal, the time between consecutive tokens, is the inter-token latency. TPS depends much less on prompt length than TTFT does, though it slows as the total sequence (prompt plus generated tokens) grows, because each decode step must read a larger KV cache. Typical values for frontier models served via API range from 30 to 150 tokens per second, depending on the model size and hardware.
| Metric | Phase | Bottleneck | Affected By |
|---|---|---|---|
| TTFT | Prefill | Compute (GPU arithmetic) | Prompt length, model size |
| TPS | Decode | Memory bandwidth (reading KV cache) | KV cache size, model size, GPU memory bandwidth |
This is why you sometimes experience a long pause before the response starts (high TTFT due to a long prompt), followed by fast streaming (high TPS during decode). The two metrics are largely independent: you can have fast TPS with slow TTFT (long prompt, fast hardware) or slow TPS with fast TTFT (short prompt, memory-bandwidth-limited hardware).
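The two metrics combine into a simple model of how long a full response takes to stream. This sketch assumes a constant decode speed and ignores network and queueing delays; the TTFT and TPS values in the example are hypothetical.

```python
def response_latency(ttft_s, tps, num_output_tokens):
    """Approximate wall-clock time to stream a full response.

    The first token arrives after TTFT; each remaining token arrives one
    inter-token latency (1 / TPS) later.
    """
    inter_token_latency = 1.0 / tps
    return ttft_s + (num_output_tokens - 1) * inter_token_latency

# Hypothetical: same decode speed, short prompt vs. very long prompt
short_prompt = response_latency(ttft_s=0.2, tps=50, num_output_tokens=501)
long_prompt = response_latency(ttft_s=10.0, tps=50, num_output_tokens=501)
print(f"short prompt: {short_prompt:.1f} s")  # 10.2 s
print(f"long prompt:  {long_prompt:.1f} s")   # 20.0 s
```

For short prompts and long answers the decode term dominates; for very long prompts and short answers, TTFT can dominate instead.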
How Much Memory Does the KV Cache Use?
The KV cache stores two vectors (one key, one value) per token, per layer, per KV head. The formula for KV cache memory per token is:
bytes_per_token = 2 (K and V) * num_layers * num_kv_heads * head_dim * bytes_per_element

The factor of 2 accounts for storing both a key vector and a value vector. The num_kv_heads is the number of key-value heads (which, as you learned in Chapter 8, can be fewer than the number of query heads when using Grouped Query Attention). The head_dim is the dimension of each head (typically 128 in modern models). The bytes_per_element depends on the numerical precision: 2 bytes for float16/bfloat16, 1 byte for FP8/INT8, or 0.5 bytes for INT4.
Let us compute the KV cache size for several real models.
LLaMA 3.1 405B
LLaMA 3.1 405B is Meta’s largest dense model, with the following architecture (from the HuggingFace config.json):
- num_hidden_layers: 126
- num_key_value_heads: 16 (GQA with 128 query heads, 16 KV heads)
- head_dim: 128 (hidden_size 16,384 / num_attention_heads 128)
- max_position_embeddings: 131,072 (128K context window)
```python
def kv_cache_size(num_layers, num_kv_heads, head_dim, seq_len,
                  batch_size=1, bytes_per_element=2):
    """
    Calculate KV cache memory in bytes.
    bytes_per_element: 2 for float16/bfloat16, 1 for FP8/INT8, 0.5 for INT4
    """
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    total_bytes = bytes_per_token * seq_len * batch_size
    return total_bytes

def format_bytes(b):
    """Format bytes as a human-readable string (binary units: 1 GB = 1024**3 bytes)."""
    if b >= 1024**3:
        return f"{b / 1024**3:.1f} GB"
    elif b >= 1024**2:
        return f"{b / 1024**2:.1f} MB"
    else:
        return f"{b / 1024:.1f} KB"

# LLaMA 3.1 405B
print("=== LLaMA 3.1 405B KV Cache (bfloat16) ===")
print("Architecture: 126 layers, 16 KV heads, head_dim=128")
bytes_per_token_405b = 2 * 126 * 16 * 128 * 2  # 2 bytes for bfloat16
print(f"Bytes per token: {bytes_per_token_405b:,} ({format_bytes(bytes_per_token_405b)}/token)")
for seq_len in [1_000, 8_000, 32_000, 128_000]:
    total = kv_cache_size(126, 16, 128, seq_len)
    print(f"  {seq_len:>7,} tokens: {format_bytes(total)}")
print()

# LLaMA 3 70B
print("=== LLaMA 3 70B KV Cache (bfloat16) ===")
print("Architecture: 80 layers, 8 KV heads, head_dim=128")
bytes_per_token_70b = 2 * 80 * 8 * 128 * 2
print(f"Bytes per token: {bytes_per_token_70b:,} ({format_bytes(bytes_per_token_70b)}/token)")
for seq_len in [1_000, 8_000, 32_000, 128_000]:
    total = kv_cache_size(80, 8, 128, seq_len)
    print(f"  {seq_len:>7,} tokens: {format_bytes(total)}")
print()

# LLaMA 3.1 405B at full context
full_context = kv_cache_size(126, 16, 128, 131_072)
weights_bytes = 820_162_494_464  # bfloat16 weights, per the safetensors index metadata
print(f"LLaMA 3.1 405B at full 128K context: {format_bytes(full_context)}")
print(f"Model weights (bfloat16): {weights_bytes:,} bytes (~820 GB)")
# Compare bytes to bytes so the ratio does not mix decimal and binary units
print(f"KV cache is {full_context / weights_bytes * 100:.0f}% of model weight size")
```

Here is a comparison table showing KV cache sizes for several models at different sequence lengths, all in bfloat16 and computed with the formula above (1 GB = 1,073,741,824 bytes, as in format_bytes):
| Model | Layers | KV Heads | Cache/Token | 8,000 Tokens | 32,000 Tokens | 128,000 Tokens |
|---|---|---|---|---|---|---|
| LLaMA 3 8B | 32 | 8 | 128 KB | 1.0 GB | 3.9 GB | 15.6 GB |
| LLaMA 3 70B | 80 | 8 | 320 KB | 2.4 GB | 9.8 GB | 39.1 GB |
| LLaMA 3.1 405B | 126 | 16 | 1,008 KB | 7.7 GB | 30.8 GB | 123.0 GB |
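The earlier code covers only the 405B and 70B models. This self-contained sketch reproduces the LLaMA 3 8B row from the table’s layer and KV-head counts, assuming head_dim = 128 (a hidden size of 4,096 divided by 32 query heads) and the same binary units.

```python
def kv_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_element=2):
    """KV cache size in bytes for one sequence (same formula as above)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element * seq_len

GiB = 1024**3
# LLaMA 3 8B: 32 layers, 8 KV heads, head_dim 128
per_token = kv_bytes(32, 8, 128, seq_len=1)
print(f"per token: {per_token // 1024} KB")  # 128 KB
for n in (8_000, 32_000, 128_000):
    print(f"{n:>7,} tokens: {kv_bytes(32, 8, 128, n) / GiB:.1f} GB")
```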
These numbers reveal a critical insight: for long-context inference, the KV cache can consume a significant fraction of total memory. LLaMA 3.1 405B’s weights are approximately 820 GB in bfloat16 (per the safetensors index metadata: 820,162,494,464 bytes, or about 764 GB in the binary units used in the table). At the full 131,072-token context window, the KV cache for a single request adds another 126 GB in those same binary units, about 16% of the weights. With a batch of 8 concurrent requests at full context, the KV cache alone would require about 1,008 GB, more than the weights themselves.
This is why, at long context lengths and high concurrency, the KV cache rather than the model weights is often the dominant memory cost in production serving. The model weights are fixed and shared across all requests; the KV cache is per-request and grows with sequence length.
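Because the weights are paid for once but the KV cache is paid for per request, a simple budget shows how context length limits concurrency. The sketch below assumes a hypothetical server with 8 × 80 GB of GPU memory (binary units), LLaMA 3 70B weights approximated as 70 billion parameters × 2 bytes, and the 320 KB-per-token figure from the table; it ignores activations, framework overhead, and fragmentation.

```python
GiB = 1024**3

def max_cached_tokens(gpu_bytes, weight_bytes, kv_bytes_per_token):
    """Tokens of KV cache that fit after the weights are loaded."""
    return (gpu_bytes - weight_bytes) // kv_bytes_per_token

gpu_bytes = 8 * 80 * GiB               # hypothetical 8-GPU server
weight_bytes = 70_000_000_000 * 2      # ~70B params in bfloat16 (approximation)
kv_per_token = 2 * 80 * 8 * 128 * 2    # LLaMA 3 70B, bfloat16: 327,680 bytes

budget = max_cached_tokens(gpu_bytes, weight_bytes, kv_per_token)
print(f"KV cache budget: {budget:,} tokens")
for context in (8_000, 32_000, 128_000):
    print(f"{context:>7,}-token requests that fit at once: {budget // context}")
```

Under these assumptions, roughly a dozen full 128K-context requests fit at once, versus about two hundred 8K-context requests, even though the weights are identical in both cases.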
Source: LLaMA 3.1 405B config.json from HuggingFace (huggingface.co/meta-llama/Llama-3.1-405B): num_hidden_layers=126, num_attention_heads=128, num_key_value_heads=16, hidden_size=16384, max_position_embeddings=131072. LLaMA 3 70B: 80 layers, 64 query heads, 8 KV heads, head_dim=128 (emergentmind.com, apxml.com). Frank Denneman, “The Dynamic World of LLM Runtime Memory,” January 12, 2026 (frankdenneman.nl).
How GQA Reduces KV Cache Size
In Chapter 8, you learned about Grouped Query Attention (GQA), where multiple query heads share a single set of key-value heads. The KV cache benefit of GQA is now concrete: fewer KV heads means fewer K and V vectors to store per token.
Consider what LLaMA 3.1 405B would look like with full Multi-Head Attention (MHA), where every query head has its own KV head:
```python
# Reuses kv_cache_size and format_bytes from the earlier code block.
print("KV Cache: GQA vs Full MHA for LLaMA 3.1 405B at 128K tokens")
print("=" * 60)
# Actual architecture: GQA with 16 KV heads
gqa_cache = kv_cache_size(126, 16, 128, 131_072)
print(f"GQA (16 KV heads):  {format_bytes(gqa_cache)}")
# Hypothetical: Full MHA with 128 KV heads
mha_cache = kv_cache_size(126, 128, 128, 131_072)
print(f"MHA (128 KV heads): {format_bytes(mha_cache)}")
print(f"\nGQA reduces KV cache by {mha_cache / gqa_cache:.0f}x")
print(f"Savings: {format_bytes(mha_cache - gqa_cache)}")
```

With full MHA, the KV cache at the full 131,072-token context window would be approximately 1,008 GB (about 1 TB) for a single request. GQA’s 8x reduction (128 query heads / 16 KV heads) brings this down to approximately 126 GB. This is a major practical reason GQA has become the dominant attention mechanism in modern LLMs: it makes long-context inference feasible by dramatically reducing KV cache memory.
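To make the sharing concrete, here is a minimal NumPy sketch of one GQA decode step with toy sizes (8 query heads sharing 2 cached KV heads). The shapes and the repeat-based expansion are illustrative assumptions; production kernels typically avoid materializing the repeated tensors. Only the KV heads live in the cache, and every query head in a group reads the same cached K and V.

```python
import numpy as np

def gqa_attention(q, k_cache, v_cache):
    """Grouped Query Attention for one decode step (single sequence).

    q:       (num_q_heads, head_dim)            queries for the new token
    k_cache: (seq_len, num_kv_heads, head_dim)  only the KV heads are cached
    v_cache: (seq_len, num_kv_heads, head_dim)
    """
    num_q_heads, head_dim = q.shape
    num_kv_heads = k_cache.shape[1]
    group = num_q_heads // num_kv_heads  # query heads per KV head
    # Expand KV heads so query head h reads KV head h // group
    k = np.repeat(k_cache, group, axis=1)  # (seq_len, num_q_heads, head_dim)
    v = np.repeat(v_cache, group, axis=1)
    scores = np.einsum("hd,shd->hs", q, k) / np.sqrt(head_dim)  # (num_q_heads, seq_len)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return np.einsum("hs,shd->hd", weights, v)  # (num_q_heads, head_dim)

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))           # 8 query heads (toy sizes)
k_cache = rng.standard_normal((5, 2, 16))  # 5 cached tokens, 2 KV heads
v_cache = rng.standard_normal((5, 2, 16))
out = gqa_attention(q, k_cache, v_cache)
print(out.shape)  # (8, 16)
```

Here the cache holds 2 heads instead of 8, a 4x saving; the same ratio logic gives the 8x figure for 128 query heads sharing 16 KV heads.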
DeepSeek-V3’s MLA: Even More Compression
As covered in Chapter 8, DeepSeek-V3 uses Multi-Head Latent Attention (MLA), which compresses all key-value information into a low-dimensional latent vector of 512 dimensions per token (plus a 64-dimension decoupled RoPE key). Instead of storing separate K and V vectors for each head, MLA stores a single compressed representation that is decompressed on the fly during attention computation.
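The compress-then-decompress idea can be sketched with toy dimensions. This is a deliberately simplified illustration, not DeepSeek’s implementation: it omits the decoupled RoPE key and query compression, and it rebuilds K and V explicitly rather than folding the up-projections into the query and output projections. The names W_DKV, W_UK, W_UV and all sizes are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, latent_dim, num_heads, head_dim = 64, 16, 4, 8  # toy sizes

# Down-projection to a shared latent, and per-head up-projections
W_DKV = rng.standard_normal((hidden_dim, latent_dim)) / np.sqrt(hidden_dim)
W_UK = rng.standard_normal((latent_dim, num_heads * head_dim)) / np.sqrt(latent_dim)
W_UV = rng.standard_normal((latent_dim, num_heads * head_dim)) / np.sqrt(latent_dim)

latent_cache = []  # one latent_dim vector per token: this is all that is cached

def mla_decode_step(h, q):
    """h: (hidden_dim,) new token's hidden state; q: (num_heads, head_dim) queries."""
    latent_cache.append(h @ W_DKV)  # cache the compressed latent only
    C = np.stack(latent_cache)      # (seq_len, latent_dim)
    K = (C @ W_UK).reshape(len(C), num_heads, head_dim)  # rebuilt on the fly
    V = (C @ W_UV).reshape(len(C), num_heads, head_dim)
    scores = np.einsum("hd,shd->hs", q, K) / np.sqrt(head_dim)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return np.einsum("hs,shd->hd", w, V)

for _ in range(6):
    out = mla_decode_step(rng.standard_normal(hidden_dim),
                          rng.standard_normal((num_heads, head_dim)))

cached = len(latent_cache) * latent_dim           # values actually stored
full = len(latent_cache) * 2 * num_heads * head_dim  # per-head K and V
print(f"cached values: {cached} vs. full per-head K/V: {full}")
```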
For DeepSeek-V3 with 61 layers:
```python
# Reuses format_bytes from the earlier code block.
# DeepSeek-V3 MLA: cache compressed latent + RoPE key per token per layer
# c_KV dimension: 512, decoupled RoPE key dimension: 64
# Total cached per token per layer: 512 + 64 = 576 values
deepseek_bytes_per_token = 61 * 576 * 2  # 61 layers, 576 values, 2 bytes (bfloat16)
print(f"DeepSeek-V3 MLA bytes per token: {deepseek_bytes_per_token:,} "
      f"({format_bytes(deepseek_bytes_per_token)}/token)")
# Compare with a hypothetical GQA equivalent
# DeepSeek-V3 has 128 attention heads; if it used GQA with 16 KV heads:
gqa_equivalent = 2 * 61 * 16 * 128 * 2  # 2 (K+V) * layers * kv_heads * head_dim * bytes
print(f"Hypothetical GQA (16 KV heads) bytes per token: {gqa_equivalent:,} "
      f"({format_bytes(gqa_equivalent)}/token)")
print(f"\nMLA compression ratio: {gqa_equivalent / deepseek_bytes_per_token:.1f}x smaller than GQA")
```

MLA’s cache (70,272 bytes, about 68.6 KB per token) comes out roughly 7x smaller than that of the hypothetical 16-KV-head GQA variant (499,712 bytes, 488 KB per token), at the cost of additional computation to decompress the latent vectors during attention. This tradeoff is favorable because the decode phase is memory-bandwidth-bound: reducing the amount of data that must be read from memory speeds up generation even if it requires more arithmetic.
Source: DeepSeek-V3 Technical Report, arXiv:2412.19437, December 2024, Section 2.1.1. MLA compresses KV into a latent vector of dimension d_c=512 plus a decoupled RoPE key of dimension d_h^R=64, totaling 576 values per token per layer.
KV Cache Compression: Making Long Contexts Practical
Even with GQA and MLA, the KV cache remains the primary memory bottleneck for long-context inference. A single LLaMA 3.1 405B request at 128K tokens uses 126 GB of KV cache memory in bfloat16. Serving multiple concurrent users at long context lengths quickly exhausts the memory of even large multi-GPU servers. Several compression techniques have been developed to address this problem.
Quantization: Reducing Precision
The most straightforward approach to reducing KV cache size is quantization: storing the cached K and V vectors at lower numerical precision. Instead of using 16-bit floating point (2 bytes per value), you can use 8-bit (1 byte), 4-bit (0.5 bytes), or even 2-bit (0.25 bytes) representations.
KIVI (Liu et al., arXiv:2402.02750, ICML 2024) is a notable KV cache quantization method that achieves 2-bit precision with minimal quality loss. The key insight from KIVI is that keys and values have different statistical distributions and should be quantized differently:
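- Keys should be quantized per-channel: each channel (feature dimension) gets its own quantization parameters, shared across the tokens being quantized together. This is because different channels in the key vectors have very different value ranges, so a single scale per token would be stretched by the few large-magnitude channels.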
- Values should be quantized per-token: group all dimensions for a single token and quantize them together. This is because value vectors tend to have more uniform distributions across dimensions but vary across tokens.
With this asymmetric approach, KIVI enables LLaMA-2, Falcon, and Mistral models to maintain near-original quality while using 2.6x less peak memory (including model weights). The memory savings also enable up to 4x larger batch sizes, translating to 2.35x to 3.47x throughput improvement on real inference workloads.
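The asymmetric grouping can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not KIVI’s implementation: it applies plain min-max quantization to a whole toy cache, whereas KIVI quantizes in groups and keeps a small window of recent tokens in full precision, and the outlier key channel is synthetic.

```python
import numpy as np

def quantize_asymmetric(x, bits, axis):
    """Min-max (asymmetric) quantization with one scale/zero-point per slice.

    Statistics are computed by reducing over `axis`, so every slice along
    the other axis gets its own scale and zero-point.
    """
    levels = 2**bits - 1
    lo = x.min(axis=axis, keepdims=True)
    hi = x.max(axis=axis, keepdims=True)
    scale = np.maximum(hi - lo, 1e-8) / levels
    q = np.clip(np.round((x - lo) / scale), 0, levels).astype(np.uint8)
    return q, scale, lo

def dequantize(q, scale, lo):
    return q.astype(np.float32) * scale + lo

rng = np.random.default_rng(0)
seq_len, head_dim = 64, 32
keys = rng.standard_normal((seq_len, head_dim)).astype(np.float32)
keys[:, 3] *= 20.0  # synthetic outlier channel
values = rng.standard_normal((seq_len, head_dim)).astype(np.float32)

# Keys per-channel: reduce over tokens (axis 0), one scale per channel
qk, sk, zk = quantize_asymmetric(keys, bits=2, axis=0)
# Values per-token: reduce over channels (axis 1), one scale per token
qv, sv, zv = quantize_asymmetric(values, bits=2, axis=1)
# For comparison: keys quantized per-token, where the outlier stretches every range
qk_tok, sk_tok, zk_tok = quantize_asymmetric(keys, bits=2, axis=1)

err_per_channel = np.abs(dequantize(qk, sk, zk) - keys).mean()
err_per_token = np.abs(dequantize(qk_tok, sk_tok, zk_tok) - keys).mean()
print(f"key error, per-channel: {err_per_channel:.3f}")
print(f"key error, per-token:   {err_per_token:.3f}")
```

With one large channel, per-token quantization spends almost all of its four levels covering the outlier, so the ordinary channels collapse onto one or two values; per-channel scales isolate the outlier to its own range.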
```python
def demonstrate_kv_cache_quantization():
    """
    Show memory savings from KV cache quantization.
    Counts only the quantized values; per-group scales/zero-points and
    KIVI's small full-precision window of recent tokens are ignored.
    """
    # LLaMA 3 70B: 80 layers, 8 KV heads, head_dim=128
    layers, kv_heads, head_dim = 80, 8, 128
    seq_len = 32_000
    print("KV Cache Size for LLaMA 3 70B at 32K tokens")
    print("=" * 50)
    precisions = [
        ("BFloat16 (default)", 2),
        ("FP8 / INT8", 1),
        ("INT4", 0.5),
        ("INT2 (KIVI)", 0.25),
    ]
    for name, bytes_per_val in precisions:
        total = 2 * layers * kv_heads * head_dim * seq_len * bytes_per_val
        total_gb = total / (1024**3)
        print(f"  {name:<22} {total_gb:>6.1f} GB")
    print()
    print("Quantizing from BFloat16 to INT4 cuts KV cache by 4x.")
    print("Quantizing to INT2 (KIVI) cuts it by 8x.")

demonstrate_kv_cache_quantization()
```

In practice, FP8 and INT8 KV cache quantization have become widely adopted because they halve memory usage with typically negligible quality degradation. The open-source inference framework llama.cpp supports KV cache quantization via the --cache-type-k and --cache-type-v flags, and Ollama (which uses llama.cpp internally) supports Q8_0 and Q4_0 KV cache quantization via the OLLAMA_KV_CACHE_TYPE environment variable. Q8_0 KV cache quantization roughly halves the VRAM required for the context compared with the default FP16, while Q4_0 reduces it to roughly one-third. As a concrete example, running LLaMA 3.2 8B at its full 128K context window with Q4_K_M model quantization requires 23.3 GB without KV cache quantization, 17.0 GB with Q8_0 KV cache quantization, and 13.8 GB with Q4_0 KV cache quantization.

For production serving, vLLM supports FP8 KV cache quantization via the kv_cache_dtype="fp8" parameter on NVIDIA Hopper and Ada Lovelace GPUs (H100, L40S, RTX 4090), halving KV cache memory with no calibration data required.
Source: Liu et al., “KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache,” arXiv:2402.02750, February 2024. ICML 2024 (proceedings.mlr.press/v235/liu24bz). KIVI achieves 2.6x peak memory reduction on LLaMA-2, Falcon, and Mistral with near-lossless quality; up to 4x larger batch size and 2.35x to 3.47x throughput on real workloads. Sam McLeod, “Bringing K/V Context Quantisation to Ollama,” December 4, 2024 (smcleod.net): Q8_0 halves VRAM for context; Q4_0 reduces to one-third. Mitja Martini, “K/V Cache Quantization in Ollama,” May 10, 2025 (mitjamartini.com): LLaMA 3.2 8B at 128K context: 23.3 GB (F16), 17.0 GB (Q8_0), 13.8 GB (Q4_0). vLLM FP8 KV cache: supported on Hopper and Ada Lovelace GPUs via kv_cache_dtype="fp8" (discuss.vllm.ai, vllm documentation).
Eviction: Dropping Unimportant Tokens
Instead of keeping all tokens in the KV cache at reduced precision, another approach is to selectively evict (remove) tokens that are unlikely to be important for future attention computations. The idea is that not all tokens in a long sequence are equally important: some tokens receive very little attention and can be safely removed without significantly affecting output quality.
H2O (Heavy-Hitter Oracle) by Zhang et al. (arXiv:2306.14048, NeurIPS 2023) is a foundational KV cache eviction method. The key observation behind H2O is that a small fraction of tokens, called “heavy hitters,” consistently receive high attention scores across many generation steps. These heavy hitters, combined with recent tokens (which are important for local coherence), account for most of the useful information in the KV cache.
H2O’s eviction policy is simple: maintain a fixed-size KV cache budget, and at each step, keep the heavy-hitter tokens (those with the highest cumulative attention scores) plus the most recent tokens. Evict everything else. With a budget of just 20% of the full cache (keeping only 20% of tokens), H2O improved throughput by up to 29x over DeepSpeed Zero-Inference and Hugging Face Accelerate, and up to 3x over FlexGen, on OPT-6.7B and OPT-30B, while maintaining comparable generation quality.
```python
import numpy as np

def h2o_eviction(k_cache, v_cache, attention_scores_history,
                 budget_ratio=0.2, recent_ratio=0.5):
    """
    Simplified H2O eviction: keep heavy hitters + recent tokens.
    attention_scores_history: shape (num_steps, seq_len), the attention each
        cached token received at each past step (0 before it existed)
    budget_ratio: fraction of total tokens to keep (e.g., 0.2 = 20%)
    recent_ratio: fraction of budget allocated to recent tokens
    """
    seq_len = k_cache.shape[0]
    budget = int(seq_len * budget_ratio)
    num_recent = int(budget * recent_ratio)
    num_heavy_hitters = budget - num_recent
    # Recent tokens: always keep the most recent ones
    recent_indices = list(range(seq_len - num_recent, seq_len))
    # Heavy hitters: tokens with highest cumulative attention scores
    # (excluding the recent tokens we already kept)
    cumulative_scores = np.sum(attention_scores_history, axis=0)
    candidate_indices = list(range(seq_len - num_recent))
    heavy_hitter_indices = sorted(
        candidate_indices,
        key=lambda i: cumulative_scores[i],
        reverse=True
    )[:num_heavy_hitters]
    # Combine and sort
    keep_indices = sorted(set(heavy_hitter_indices + recent_indices))
    # Evict everything else
    k_cache_evicted = k_cache[keep_indices]
    v_cache_evicted = v_cache[keep_indices]
    return k_cache_evicted, v_cache_evicted, keep_indices

# Example: 10,000 tokens, keep only 20% (2,000 tokens)
seq_len = 10_000
budget = int(seq_len * 0.2)
print(f"Original KV cache: {seq_len:,} tokens")
print(f"After H2O eviction (20% budget): {budget:,} tokens")
print(f"Memory savings: {(1 - budget/seq_len)*100:.0f}%")
```

Source: Zhang et al., “H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models,” arXiv:2306.14048, June 2023. NeurIPS 2023. H2O with 20% heavy hitters improved throughput by up to 29x over DeepSpeed Zero-Inference and Hugging Face Accelerate, and up to 3x over FlexGen, on OPT-6.7B and OPT-30B.
Attention Sinks: Why the First Tokens Matter
A related discovery is the attention sink phenomenon, identified by Xiao et al. in the StreamingLLM paper (arXiv:2309.17453, ICLR 2024). The researchers found that in autoregressive language models, the very first tokens in a sequence consistently receive disproportionately high attention scores from all later tokens, even when those initial tokens carry no special semantic meaning.
The authors attribute this to how softmax works in attention. The softmax function must distribute attention weights across all tokens, and the weights must sum to 1. When the model has no strong reason to attend to any particular token, it still needs somewhere to “dump” the excess attention probability. Because the first tokens are visible to every later token in an autoregressive model, they become this dumping ground, acting as “attention sinks.”
The practical implication is important for KV cache eviction: if you evict the first few tokens from the cache, model quality degrades catastrophically, even though those tokens seem semantically unimportant. StreamingLLM’s solution is simple: always keep the first few tokens (the attention sinks) plus a sliding window of recent tokens. This enables stable generation over sequences of 4 million tokens or more with a fixed-size KV cache, without any fine-tuning. In streaming settings, StreamingLLM achieves up to a 22.2x speedup over the sliding-window-with-recomputation baseline.
```python
import numpy as np

def streaming_llm_cache(k_cache, v_cache, window_size=4096, num_sink_tokens=4):
    """
    StreamingLLM: keep attention sink tokens + sliding window of recent tokens.
    This enables infinite-length generation with fixed memory.
    """
    seq_len = k_cache.shape[0]
    if seq_len <= window_size + num_sink_tokens:
        # Cache fits within budget, no eviction needed
        return k_cache, v_cache
    # Always keep the first few tokens (attention sinks)
    sink_k = k_cache[:num_sink_tokens]
    sink_v = v_cache[:num_sink_tokens]
    # Keep the most recent tokens (sliding window)
    recent_k = k_cache[-window_size:]
    recent_v = v_cache[-window_size:]
    # Combine: sinks + recent window
    k_evicted = np.vstack([sink_k, recent_k])
    v_evicted = np.vstack([sink_v, recent_v])
    return k_evicted, v_evicted

# Example: processing a very long document
total_tokens = 1_000_000
window_size = 4096
sink_tokens = 4
cache_size = window_size + sink_tokens
print(f"StreamingLLM with window_size={window_size}, sink_tokens={sink_tokens}")
print(f"Total sequence: {total_tokens:,} tokens")
print(f"KV cache size: {cache_size:,} tokens (fixed)")
print(f"Memory savings: {(1 - cache_size/total_tokens)*100:.1f}%")
```

Source: Xiao et al., “Efficient Streaming Language Models with Attention Sinks,” arXiv:2309.17453, September 2023. ICLR 2024. StreamingLLM enables stable generation over 4 million+ tokens with a fixed-size KV cache by preserving attention sink tokens. Up to 22.2x speedup over sliding window recomputation in streaming settings (confirmed from ICLR 2024 poster and OpenReview).
Cross-Layer Sharing: Reusing Cache Across Layers
All the techniques above compress the KV cache within each layer. A different approach is to share KV cache entries across layers. Cross-Layer Attention (CLA), proposed by Brandon et al. (arXiv:2405.12981, NeurIPS 2024), builds sharing into the architecture when the model is trained: in each pair of adjacent layers, only one layer computes and stores K and V, and the other reuses them. This cuts the total cache size by an additional 2x on top of whatever per-layer compression is already in use.
The key finding is that this 2x reduction comes with nearly the same accuracy as unmodified Multi-Query Attention. Cross-layer sharing is orthogonal to GQA and quantization: you can combine it with GQA (fewer KV heads per layer) and quantization (fewer bits per value) for compounding savings. MiniCache (arXiv:2405.14366, NeurIPS 2024) takes a related approach, building on the observation that KV cache states are highly similar between adjacent layers in the middle-to-deep portion of the model; it merges those states by interpolating the direction components of state vectors while preserving their magnitudes.
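The bookkeeping for cross-layer sharing, and how it compounds with the other savings, can be sketched as simple arithmetic. This sketch assumes a sharing factor of 2 (pairs of adjacent layers) and applies it to a LLaMA 3 70B-shaped configuration purely as a hypothetical; LLaMA 3 70B does not actually use CLA or an FP8 cache by default.

```python
def cla_cache_layout(num_layers, share_factor=2):
    """Map each layer to the KV cache slot it reads under cross-layer sharing.

    With share_factor=2, layers (0, 1) share slot 0, layers (2, 3) share slot 1, ...
    Only the first layer of each group computes and stores K and V.
    """
    return [layer // share_factor for layer in range(num_layers)]

def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element,
                       share_factor=1):
    num_slots = -(-num_layers // share_factor)  # ceiling division
    return 2 * num_slots * num_kv_heads * head_dim * bytes_per_element

layout = cla_cache_layout(8)
print(layout)  # [0, 0, 1, 1, 2, 2, 3, 3]

# Hypothetical compounding on a LLaMA 3 70B-shaped config (80 layers, head_dim 128)
mha_bf16 = kv_bytes_per_token(80, 64, 128, 2)                     # 64 KV heads, no GQA
gqa_bf16 = kv_bytes_per_token(80, 8, 128, 2)                      # GQA: 8 KV heads
gqa_fp8_cla = kv_bytes_per_token(80, 8, 128, 1, share_factor=2)   # + FP8 + sharing
print(mha_bf16, gqa_bf16, gqa_fp8_cla)
print(f"combined reduction: {mha_bf16 / gqa_fp8_cla:.0f}x")
```

Each technique removes a different factor from the same product (heads, bytes, layers), which is why the savings multiply: 8x from GQA, 2x from FP8, and 2x from sharing give 32x here.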
These cross-layer techniques are still primarily research contributions as of March 2026, but they point toward a future where KV cache compression operates along multiple dimensions simultaneously: fewer heads (GQA/MLA), fewer bits (quantization), fewer tokens (eviction), and fewer layers (cross-layer sharing).
Source: Brandon et al., “Reducing Transformer Key-Value Cache Size with Cross-Layer Attention,” arXiv:2405.12981, May 2024. NeurIPS 2024. CLA achieves 2x KV cache reduction with near-MQA accuracy. Liu et al., “MiniCache: KV Cache Compression in Depth Dimension for Large Language Models,” arXiv:2405.14366, May 2024. NeurIPS 2024.
PagedAttention: Solving Memory Fragmentation
Even when the KV cache fits in GPU memory, there is a subtler problem: memory fragmentation. In a serving system handling many concurrent requests, each request has a KV cache that grows dynamically as tokens are generated. The final length of each sequence is unknown in advance (the model might generate 10 tokens or 10,000), so the system must either pre-allocate the maximum possible memory for each request (wasting memory on short responses) or dynamically resize allocations (causing fragmentation).
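A small calculation makes the cost of pre-allocation concrete. The request lengths and the 8,192-token reservation are made-up illustrative values. The comparison allocates in fixed 16-token blocks, which is the approach PagedAttention takes, so each request wastes at most one partly filled block.

```python
import math

# Hypothetical request lengths (tokens actually stored) for illustration
request_lengths = [37, 512, 1_200, 90, 4_000, 15]
max_seq_len = 8_192   # what a naive server would reserve per request (assumed)
page_size = 16        # tokens per fixed-size block

used = sum(request_lengths)
reserved_naive = max_seq_len * len(request_lengths)
reserved_paged = sum(math.ceil(n / page_size) * page_size for n in request_lengths)

print(f"Tokens actually stored: {used:,}")
print(f"Naive pre-allocation:   {reserved_naive:,} slots "
      f"({100 * (1 - used / reserved_naive):.1f}% wasted)")
print(f"16-token blocks:        {reserved_paged:,} slots "
      f"({100 * (1 - used / reserved_paged):.1f}% wasted)")
```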
PagedAttention, introduced by Kwon et al. in the vLLM system (arXiv:2309.06180, SOSP 2023), solves this problem by borrowing a concept from operating systems: virtual memory with paging. Instead of allocating one contiguous block of memory for each request’s KV cache, PagedAttention divides the KV cache into fixed-size pages (also called blocks). Each page holds the K and V vectors for a fixed number of tokens (e.g., 16 tokens per page).
When a request needs more KV cache space, the system allocates a new page from a pool of free pages. Pages do not need to be contiguous in physical memory; a logical-to-physical page table maps each request’s sequence positions to their actual memory locations. This mirrors how operating systems map a process’s virtual pages to physical memory frames.
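The page-table lookup itself is simple arithmetic. This minimal sketch translates a token's logical position within a request into a physical page and an offset inside that page. The `locate` helper and the page table contents are hypothetical.

```python
def locate(position, page_table, page_size=16):
    """Map a token's logical position in a request to (physical_page, offset)."""
    logical_page, offset = divmod(position, page_size)
    return page_table[logical_page], offset

# Hypothetical page table: logical pages 0, 1, 2 live in scattered physical pages
page_table = [7, 2, 41]
for pos in [0, 15, 16, 40]:
    print(pos, "->", locate(pos, page_table))
```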
import numpy as np

class PagedKVCache:
    """
    Simplified PagedAttention KV cache manager.
    Inspired by vLLM's approach (Kwon et al., SOSP 2023).
    This sketch manages a page pool for a single layer; a real system keeps
    such a pool for every layer (num_layers is stored here only for reference).
    """
    def __init__(self, page_size=16, num_layers=80, num_kv_heads=8,
                 head_dim=128, total_pages=1000):
        self.page_size = page_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.total_pages = total_pages
        # Pre-allocate a pool of pages
        # Each page: (page_size, num_kv_heads, head_dim) for K and V
        self.k_pages = np.zeros(
            (total_pages, page_size, num_kv_heads, head_dim), dtype=np.float16
        )
        self.v_pages = np.zeros(
            (total_pages, page_size, num_kv_heads, head_dim), dtype=np.float16
        )
        # Free list used as a stack: the most recently freed page is reused first
        self.free_pages = list(range(total_pages))
        # Per-request page tables: request_id -> list of page indices
        self.page_tables = {}

    def allocate_page(self, request_id):
        """Allocate a new page for a request."""
        if not self.free_pages:
            raise MemoryError("No free pages available")
        page_idx = self.free_pages.pop()  # O(1), unlike pop(0)
        self.page_tables.setdefault(request_id, []).append(page_idx)
        return page_idx

    def free_request(self, request_id):
        """Free all pages for a completed request."""
        if request_id in self.page_tables:
            self.free_pages.extend(self.page_tables.pop(request_id))

    def get_cache_for_request(self, request_id):
        """Gather a request's KV cache from its (possibly scattered) pages.
        The last page may be only partly filled; a real system also tracks
        how many tokens each request has actually written."""
        pages = self.page_tables.get(request_id, [])
        if not pages:
            return None, None
        k_list = [self.k_pages[p] for p in pages]
        v_list = [self.v_pages[p] for p in pages]
        return np.concatenate(k_list, axis=0), np.concatenate(v_list, axis=0)

# Example: serving multiple requests with different lengths
cache = PagedKVCache(page_size=16, total_pages=100)
# Request A: 50 tokens (needs 4 pages: ceil(50/16) = 4)
for _ in range(4):
    cache.allocate_page("request_A")
# Request B: 200 tokens (needs 13 pages: ceil(200/16) = 13)
for _ in range(13):
    cache.allocate_page("request_B")
# Request A finishes, its pages are freed
cache.free_request("request_A")
# Request C can now reuse those freed pages
for _ in range(3):
    cache.allocate_page("request_C")
print(f"Pages allocated: Request B={len(cache.page_tables['request_B'])}, "
      f"Request C={len(cache.page_tables['request_C'])}")
print(f"Free pages remaining: {len(cache.free_pages)}")
print(f"Total pages: {cache.total_pages}")
The benefits of PagedAttention are significant:
- Near-zero memory waste: Pages are allocated on demand and freed immediately when a request completes. No memory is wasted on pre-allocated but unused space.
- No fragmentation: Because pages are fixed-size and non-contiguous, there is no fragmentation problem. Any free page can be used by any request.
- Cache sharing: Multiple requests with the same prefix (e.g., the same system prompt) can share KV cache pages through copy-on-write, similar to how operating systems share memory pages between forked processes.
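Copy-on-write sharing can be sketched with per-page reference counts. The `RefCountedPages` class below is a hypothetical simplification, not vLLM's block manager. It tracks only page ownership and leaves the actual copying of K/V contents as a comment.

```python
class RefCountedPages:
    """Hypothetical sketch of prefix sharing with copy-on-write reference counts."""

    def __init__(self, total_pages):
        self.free = list(range(total_pages))
        self.refcount = {}   # physical page -> number of requests using it
        self.tables = {}     # request id -> list of physical pages

    def allocate(self, req):
        page = self.free.pop()
        self.refcount[page] = 1
        self.tables.setdefault(req, []).append(page)
        return page

    def fork(self, parent, child):
        """Child reuses all of the parent's pages (e.g., a shared system prompt)."""
        self.tables[child] = list(self.tables[parent])
        for page in self.tables[child]:
            self.refcount[page] += 1

    def write_last_page(self, req):
        """Copy-on-write: before appending to a shared last page, give req its own copy."""
        table = self.tables[req]
        page = table[-1]
        if self.refcount[page] > 1:
            new_page = self.free.pop()
            # A real system would copy the K/V contents of `page` into `new_page` here
            self.refcount[page] -= 1
            self.refcount[new_page] = 1
            table[-1] = new_page
        return table[-1]

    def release(self, req):
        """Drop req's references; pages nobody uses go back to the free list."""
        for page in self.tables.pop(req):
            self.refcount[page] -= 1
            if self.refcount[page] == 0:
                del self.refcount[page]
                self.free.append(page)

pool = RefCountedPages(total_pages=10)
for _ in range(3):
    pool.allocate("A")      # e.g., a 40-token system prompt in 16-token pages
pool.fork("A", "B")         # B shares A's three pages without copying
print("Shared:", pool.tables["A"], pool.tables["B"], "free:", len(pool.free))
pool.write_last_page("B")   # B diverges inside the partly filled last page
print("Diverged:", pool.tables["A"], pool.tables["B"], "free:", len(pool.free))
pool.release("A")
print("After A finishes:", pool.tables["B"], "free:", len(pool.free))
```

Request B starts without copying anything. The only page that gets duplicated is the partly filled one it writes into, and the shared pages stay allocated until their last reader releases them.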
The vLLM paper demonstrated that PagedAttention improved serving throughput by 2 to 4x compared to existing systems like FasterTransformer and Orca, with the improvement being more pronounced for longer sequences and larger models.
Source: Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention,” arXiv:2309.06180, September 2023. SOSP 2023. vLLM achieved 2-4x throughput improvement over FasterTransformer and Orca.
Putting It All Together: KV Cache in the Full Generation Pipeline
Let us trace through a complete example to see how the KV cache fits into the generation pipeline from Chapter 17. We will use LLaMA 3 70B as our example model.
Setup:
- Model: LLaMA 3 70B (80 layers, 8 KV heads, head_dim=128)
- Prompt: 2,000 tokens
- Generated response: 500 tokens
- Precision: bfloat16 (2 bytes per value)
import numpy as np
# Model configuration
num_layers = 80
num_kv_heads = 8
head_dim = 128
bytes_per_value = 2 # bfloat16
# Sequence lengths
prompt_len = 2_000
response_len = 500
total_len = prompt_len + response_len
# KV cache size calculations
bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
print("=== KV Cache Through the Generation Pipeline ===")
print(f"Model: LLaMA 3 70B (80 layers, 8 KV heads, head_dim=128)")
print(f"KV cache per token: {bytes_per_token:,} bytes ({bytes_per_token/1024:.0f} KB)")
print()
# Phase 1: Prefill
prefill_cache = bytes_per_token * prompt_len
print(f"PHASE 1: PREFILL (processing {prompt_len:,}-token prompt)")
print(f" KV cache after prefill: {prefill_cache / 1024**2:.1f} MB")
print(f" K vectors stored: {prompt_len} per layer x {num_layers} layers "
f"x {num_kv_heads} heads = {prompt_len * num_layers * num_kv_heads:,}")
print(f" V vectors stored: same = {prompt_len * num_layers * num_kv_heads:,}")
print(f" Total vectors cached: {2 * prompt_len * num_layers * num_kv_heads:,}")
print()
# Phase 2: Decode (token by token)
print(f"PHASE 2: DECODE (generating {response_len} tokens)")
print(f" {'Step':>6} {'Seq Len':>8} {'KV Cache':>12} {'New K/V computed':>18}")
print(f" {'-'*48}")
for step in [1, 10, 100, 250, 500]:
    current_len = prompt_len + step
    cache_size = bytes_per_token * current_len
    print(f" {step:>6} {current_len:>8,} {cache_size/1024**2:>10.1f} MB "
          f"{'1 per layer (cached rest)':>18}")
print()
final_cache = bytes_per_token * total_len
print(f"Final KV cache size: {final_cache / 1024**2:.1f} MB")
print(f"Growth during decode: {(final_cache - prefill_cache) / 1024**2:.1f} MB "
f"({response_len} tokens x {bytes_per_token/1024:.0f} KB/token)")
print()
# Without KV cache comparison
# Counts token-level K/V projections per layer during decode only;
# prefill computes the prompt's K/V once in both cases.
without_cache_ops = sum(range(prompt_len + 1, total_len + 1))  # recompute every token each step
with_cache_ops = response_len  # Only 1 K/V computation per step
print(f"K/V computations WITHOUT cache: {without_cache_ops:,}")
print(f"K/V computations WITH cache: {with_cache_ops:,}")
print(f"Reduction: {without_cache_ops / with_cache_ops:.0f}x fewer K/V computations")
This trace shows the key dynamics:
- Prefill creates the initial KV cache by processing all 2,000 prompt tokens in parallel. This is the compute-intensive phase.
- Decode adds exactly one token’s worth of K/V vectors to the cache at each step. The cache grows linearly, from 625 MB to 781 MB over 500 steps.
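- Without the KV cache, the decode phase would perform 1,125,250 token-level K/V computations per layer (over 1.1 million), because every step recomputes K and V for the whole sequence. With the cache, it performs only 500. That is roughly a 2,250x reduction.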
Why First-Token Latency Differs from Per-Token Latency
We can now fully explain a phenomenon that users of language model APIs encounter every day: the first token takes noticeably longer to appear than subsequent tokens. This is a direct consequence of the two-phase inference architecture.
The Math Behind the Latency Difference
Consider a request with a 4,000-token prompt to LLaMA 3 70B:
First token (TTFT): The model must process all 4,000 prompt tokens through 80 transformer layers. At each layer, it computes Q, K, and V for all 4,000 tokens, runs attention (4,000 x 4,000 score matrix), and runs the feed-forward network on all 4,000 positions. This is a massive parallel computation that fully utilizes the GPU’s arithmetic units.
Subsequent tokens (decode): The model processes just 1 token through 80 layers. At each layer, it computes Q, K, and V for 1 token, runs attention (1 x N score vector, where N is the current sequence length), and runs the FFN on 1 position. The arithmetic is minimal, but every step must read all of the model weights and the entire KV cache from GPU memory, so memory bandwidth, not compute, sets the pace.
The ratio between these two phases depends on the prompt length and hardware. For a 4,000-token prompt on modern hardware, TTFT might be 200 to 500 milliseconds, while each subsequent token takes 10 to 30 milliseconds. For a 100,000-token prompt, TTFT could be 5 to 15 seconds, while per-token latency remains roughly the same (slightly slower because the KV cache is larger and takes longer to read).
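The gap can be estimated from first principles with a roofline-style lower bound. Prefill is limited by arithmetic throughput, while decode is limited by memory bandwidth. This sketch assumes approximate spec-sheet figures for an H100 (about 989 TFLOP/s dense BF16 and 3.35 TB/s of HBM bandwidth), perfect scaling across 8 GPUs, and about 2 FLOPs per parameter per token. It ignores attention FLOPs, so it understates prefill at very long prompts. Real systems also run well below these peaks.

```python
# Back-of-the-envelope lower bounds for prefill vs. decode (assumed figures,
# not measurements).
PARAMS = 70e9
WEIGHT_BYTES = PARAMS * 2                  # bf16 weights, ~140 GB
KV_BYTES_PER_TOKEN = 2 * 80 * 8 * 128 * 2  # LLaMA 3 70B, bf16
NUM_GPUS = 8
FLOPS = NUM_GPUS * 989e12                  # assumed dense BF16 peak
BANDWIDTH = NUM_GPUS * 3.35e12             # assumed HBM bytes/s

def prefill_lower_bound_ms(prompt_len):
    # Compute-bound: every prompt token passes through every weight (~2 FLOPs each)
    return 2 * PARAMS * prompt_len / FLOPS * 1e3

def decode_step_lower_bound_ms(seq_len):
    # Bandwidth-bound: each step reads all weights plus the current KV cache
    bytes_read = WEIGHT_BYTES + KV_BYTES_PER_TOKEN * seq_len
    return bytes_read / BANDWIDTH * 1e3

for prompt_len in [4_000, 100_000]:
    print(f"{prompt_len:>7,}-token prompt: "
          f"prefill >= {prefill_lower_bound_ms(prompt_len):7.1f} ms, "
          f"decode step >= {decode_step_lower_bound_ms(prompt_len):5.1f} ms")
```

These figures are idealized floors, yet they follow the same pattern as the text. Prefill time grows in proportion to prompt length. A decode step is dominated by rereading roughly 140 GB of weights and grows only modestly as the KV cache expands.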
def estimate_latency(prompt_len, response_len, model_params):
    """
    Rough latency estimation showing TTFT vs per-token latency.
    These are illustrative estimates, not benchmarks.
    model_params is an unused placeholder in this simplified version.
    """
    # TTFT is modeled as linear in prompt length (parallel processing, but
    # more tokens = more work). This ignores attention's quadratic term,
    # which becomes significant for very long prompts.
    base_ttft_per_1k_tokens = 50  # ms per 1K prompt tokens (rough estimate)
    ttft_ms = prompt_len / 1000 * base_ttft_per_1k_tokens
    # Per-token latency is relatively constant but grows slightly
    # with sequence length (larger KV cache to read)
    base_per_token = 15  # ms per token at short context
    # KV cache read time grows linearly with sequence length
    avg_seq_len = prompt_len + response_len / 2
    kv_overhead_ms = avg_seq_len / 100_000 * 5  # ~5 ms extra per 100K tokens
    per_token_ms = base_per_token + kv_overhead_ms
    total_decode_ms = per_token_ms * response_len
    total_ms = ttft_ms + total_decode_ms
    return {
        "ttft_ms": ttft_ms,
        "per_token_ms": per_token_ms,
        "tokens_per_second": 1000 / per_token_ms,
        "total_ms": total_ms,
    }

print("Latency Comparison: Short vs Long Prompts")
print("(Illustrative estimates for a 70B model)")
print("=" * 65)
print(f"{'Prompt':>10} {'Response':>10} {'TTFT':>10} {'Per Token':>10} {'TPS':>8} {'Total':>10}")
print("-" * 65)
scenarios = [
    (100, 500),
    (1_000, 500),
    (4_000, 500),
    (32_000, 500),
    (128_000, 500),
]
for prompt_len, response_len in scenarios:
    est = estimate_latency(prompt_len, response_len, None)
    print(f"{prompt_len:>10,} {response_len:>10,} "
          f"{est['ttft_ms']:>8.0f}ms {est['per_token_ms']:>8.1f}ms "
          f"{est['tokens_per_second']:>6.0f} {est['total_ms']/1000:>8.1f}s")
print()
print("Note: TTFT grows with prompt length (more tokens to process).")
print("Per-token latency stays relatively stable (decode is memory-bandwidth-bound:")
print("each step rereads the weights and the KV cache).")
print("These are rough estimates; actual numbers depend on hardware and serving setup.")
Why This Matters for Users
The TTFT vs. TPS distinction has practical implications:
Chatbots and interactive applications care most about TTFT. Users perceive the system as “slow” if there is a long pause before the first word appears, even if the subsequent streaming is fast. This is why API providers optimize prefill throughput and why prompt caching (covered in Chapter 19) is so valuable: it lets the server skip prefill for the cached portion of a repeated prompt.
Batch processing and code generation care more about TPS. If you are generating thousands of responses programmatically, the total time is dominated by the decode phase, and higher TPS means lower cost per token.
Long-context applications (analyzing entire codebases, processing long documents) face a double penalty: high TTFT because the prompt is long, and slightly lower TPS because the KV cache is large and takes longer to read at each step.
KV Cache and Batch Serving: The Real Bottleneck
In production, language model servers do not handle one request at a time. They batch multiple requests together to maximize GPU utilization. This is where the KV cache becomes the dominant constraint on serving capacity.
Consider a server running LLaMA 3 70B on 8 GPUs with a total of 640 GB of HBM (8 x 80 GB H100 GPUs). The model weights in bfloat16 consume approximately 140 GB. That leaves roughly 500 GB for KV caches, activations, and framework overhead. Let us say 400 GB is available for KV caches. (Newer hardware like the NVIDIA B200 with 192 GB HBM3e per GPU, or the B300 with 288 GB per GPU, would increase these budgets substantially, but the fundamental tradeoff between model weights and KV cache memory remains the same.)
def serving_capacity(available_memory_gb, bytes_per_token, avg_seq_len):
    """
    Estimate how many concurrent requests can be served.
    """
    available_bytes = available_memory_gb * (1024**3)
    cache_per_request = bytes_per_token * avg_seq_len
    max_requests = int(available_bytes / cache_per_request)
    return max_requests

# LLaMA 3 70B
bytes_per_token_70b = 2 * 80 * 8 * 128 * 2  # 327,680 bytes
print("Concurrent Request Capacity: LLaMA 3 70B on 8x H100 (640 GB total)")
print(f"Model weights: ~140 GB | Available for KV cache: ~400 GB")
print(f"KV cache per token: {bytes_per_token_70b / 1024:.0f} KB")
print("=" * 55)
print(f"{'Avg Seq Len':>12} {'Cache/Request':>14} {'Max Concurrent':>16}")
print("-" * 55)
for avg_len in [2_000, 8_000, 32_000, 128_000]:
    cache_per_req = bytes_per_token_70b * avg_len
    max_reqs = serving_capacity(400, bytes_per_token_70b, avg_len)
    print(f"{avg_len:>12,} {cache_per_req / 1024**3:>12.1f} GB {max_reqs:>16,}")
print()
print("At 2K tokens, you can serve ~655 concurrent requests.")
print("At 128K tokens, you can serve only ~10 concurrent requests.")
print("The KV cache is the binding constraint on serving capacity.")
This table reveals why long-context inference is so expensive: the number of concurrent requests you can serve drops dramatically as context length increases. At 2,000 tokens per request, you can serve over 600 concurrent users. At 128,000 tokens per request, you can serve roughly 10. The hardware is identical in both cases; the difference comes from how much KV cache memory each request occupies.
This is also why API providers charge more for longer contexts, and why prompt caching (Chapter 19) provides such significant cost savings: by reusing KV cache entries across requests, the server can serve more concurrent users with the same hardware.
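The following rough sketch shows why reusing KV cache entries raises concurrency. It assumes the server keeps one copy of a shared system prompt's KV pages and lets every concurrent request reference it, which copy-on-write prefix sharing in a paged system allows. The 6,000-token prefix, the 2,000 unique tokens per request, and the `max_concurrent` helper are hypothetical. The per-token size and 400 GB budget follow the LLaMA 3 70B example above.

```python
BYTES_PER_TOKEN = 2 * 80 * 8 * 128 * 2   # LLaMA 3 70B, bf16 (327,680 bytes)
BUDGET_GIB = 400

def max_concurrent(budget_gib, bytes_per_token, prefix_tokens, unique_tokens,
                   share_prefix):
    """Hypothetical helper: requests that fit when each has a common prefix
    plus its own unique tokens."""
    budget = budget_gib * 1024**3
    if share_prefix:
        budget -= prefix_tokens * bytes_per_token   # one copy, shared by all
        per_request = unique_tokens * bytes_per_token
    else:
        per_request = (prefix_tokens + unique_tokens) * bytes_per_token
    return budget // per_request

no_share = max_concurrent(BUDGET_GIB, BYTES_PER_TOKEN, 6_000, 2_000, share_prefix=False)
shared = max_concurrent(BUDGET_GIB, BYTES_PER_TOKEN, 6_000, 2_000, share_prefix=True)
print(f"Without prefix sharing: {no_share} concurrent requests")
print(f"With prefix sharing:    {shared} concurrent requests")
```

In this made-up workload, storing the prefix once lets about four times as many requests fit in the same memory.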
A Complete KV Cache Implementation
Let us put everything together into a complete, runnable implementation that demonstrates the KV cache for a single attention layer. This code shows exactly what happens at each generation step: how K and V vectors are computed, cached, and used.
import numpy as np

np.random.seed(42)

class AttentionLayerWithKVCache:
    """
    Single attention layer with KV cache.
    Demonstrates the exact mechanics of caching during generation.
    """
    def __init__(self, hidden_dim=64, num_heads=4, head_dim=16):
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Weight matrices (randomly initialized for demonstration)
        self.W_Q = np.random.randn(hidden_dim, num_heads * head_dim) * 0.02
        self.W_K = np.random.randn(hidden_dim, num_heads * head_dim) * 0.02
        self.W_V = np.random.randn(hidden_dim, num_heads * head_dim) * 0.02
        self.W_O = np.random.randn(num_heads * head_dim, hidden_dim) * 0.02
        # KV cache: starts empty
        self.k_cache = None  # shape will be (seq_len, num_heads, head_dim)
        self.v_cache = None

    def reset_cache(self):
        self.k_cache = None
        self.v_cache = None

    def prefill(self, hidden_states):
        """
        Phase 1: Process all prompt tokens in parallel.
        Populates the KV cache.
        hidden_states: (prompt_len, hidden_dim)
        Returns: (prompt_len, hidden_dim)
        """
        prompt_len = hidden_states.shape[0]
        # Compute Q, K, V for all tokens
        Q = hidden_states @ self.W_Q  # (prompt_len, num_heads * head_dim)
        K = hidden_states @ self.W_K
        V = hidden_states @ self.W_V
        # Reshape for multi-head attention
        Q = Q.reshape(prompt_len, self.num_heads, self.head_dim)
        K = K.reshape(prompt_len, self.num_heads, self.head_dim)
        V = V.reshape(prompt_len, self.num_heads, self.head_dim)
        # Store K, V in cache
        self.k_cache = K.copy()
        self.v_cache = V.copy()
        # Compute attention (with causal mask)
        output = self._compute_attention(Q, K, V, causal=True)
        # Output projection
        output = output.reshape(prompt_len, self.num_heads * self.head_dim)
        return output @ self.W_O

    def decode_step(self, hidden_state):
        """
        Phase 2: Process a single new token.
        Uses cached K, V for previous tokens; only computes new K, V.
        hidden_state: (hidden_dim,) - single token
        Returns: (hidden_dim,)
        """
        # Compute Q, K, V for the single new token
        q = hidden_state @ self.W_Q  # (num_heads * head_dim,)
        new_k = hidden_state @ self.W_K
        new_v = hidden_state @ self.W_V
        # Reshape
        q = q.reshape(1, self.num_heads, self.head_dim)
        new_k = new_k.reshape(1, self.num_heads, self.head_dim)
        new_v = new_v.reshape(1, self.num_heads, self.head_dim)
        # Append new K, V to cache (or start the cache if prefill was skipped).
        # np.concatenate copies the whole cache each step; real systems write
        # into preallocated or paged buffers instead.
        if self.k_cache is None:
            self.k_cache, self.v_cache = new_k, new_v
        else:
            self.k_cache = np.concatenate([self.k_cache, new_k], axis=0)
            self.v_cache = np.concatenate([self.v_cache, new_v], axis=0)
        # Attention: new token's Q against ALL cached K, V
        # q shape: (1, num_heads, head_dim)
        # K cache shape: (seq_len, num_heads, head_dim)
        # scores shape: (1, num_heads, seq_len)
        scores = np.einsum('qhd,shd->qhs', q, self.k_cache) / np.sqrt(self.head_dim)
        weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        # Weighted sum of values
        output = np.einsum('qhs,shd->qhd', weights, self.v_cache)
        # Output projection
        output = output.reshape(self.num_heads * self.head_dim)
        return output @ self.W_O

    def _compute_attention(self, Q, K, V, causal=True):
        """Multi-head attention with optional causal mask."""
        seq_len = Q.shape[0]
        # Compute scores: (seq_len, num_heads, seq_len)
        scores = np.einsum('qhd,khd->qhk', Q, K) / np.sqrt(self.head_dim)
        if causal:
            # Apply causal mask: position i can only attend to positions <= i
            mask = np.triu(np.ones((seq_len, seq_len)), k=1) * (-1e9)
            scores = scores + mask[:, np.newaxis, :]
        weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        # Weighted sum: (seq_len, num_heads, head_dim)
        return np.einsum('qhk,khd->qhd', weights, V)

# Demonstration
layer = AttentionLayerWithKVCache(hidden_dim=64, num_heads=4, head_dim=16)
# Simulate a prompt of 5 tokens
prompt = np.random.randn(5, 64)
print("=== Phase 1: Prefill ===")
output = layer.prefill(prompt)
print(f"Processed {prompt.shape[0]} prompt tokens in parallel")
print(f"KV cache shape: K={layer.k_cache.shape}, V={layer.v_cache.shape}")
print(f"KV cache entries: {layer.k_cache.shape[0]} tokens x "
      f"{layer.k_cache.shape[1]} heads x {layer.k_cache.shape[2]} dims")
print(f"\n=== Phase 2: Decode (3 tokens) ===")
for step in range(3):
    # Stand-in for the next token's hidden state; in a real model this would
    # come from embedding the token sampled at the previous step
    new_token_hidden = np.random.randn(64)
    output = layer.decode_step(new_token_hidden)
    print(f"Step {step+1}: Generated token, "
          f"KV cache now has {layer.k_cache.shape[0]} entries "
          f"(5 prompt + {step+1} generated)")
print(f"\nFinal KV cache shape: K={layer.k_cache.shape}, V={layer.v_cache.shape}")
print(f"Total cached vectors: {2 * layer.k_cache.shape[0] * layer.k_cache.shape[1]:,} "
      f"(2 x {layer.k_cache.shape[0]} tokens x {layer.k_cache.shape[1]} heads)")
This implementation shows the exact mechanics:
- During prefill, K and V are computed for all 5 prompt tokens and stored in the cache. The cache shape is (5, 4, 16): 5 tokens, 4 heads, 16 dimensions per head.
- During decode, each step computes K and V for just 1 new token and appends them to the cache. After 3 decode steps, the cache has 8 entries (5 prompt + 3 generated).
- At each decode step, the new token’s query attends to all cached keys and values. The attention computation grows linearly with the cache size, but the K/V computation is constant (just 1 token).
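You can check that caching changes only the cost, not the result, by comparing incremental decoding against full recomputation. This minimal single-head sketch uses random weights and omits the output projection, and its names are chosen for illustration. It computes causal attention over 8 tokens both ways and checks that the outputs agree.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 8, 16
x = rng.standard_normal((seq_len, d))            # hidden states for 8 tokens
W_q, W_k, W_v = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Reference: full causal attention recomputed over the whole sequence (no cache)
Q, K, V = x @ W_q, x @ W_k, x @ W_v
scores = Q @ K.T / np.sqrt(d)
scores = scores + np.triu(np.full((seq_len, seq_len), -1e9), k=1)  # causal mask
full_out = softmax(scores) @ V

# Incremental: one token at a time, appending its K and V to a cache
k_cache, v_cache, step_outs = [], [], []
for t in range(seq_len):
    q = x[t] @ W_q
    k_cache.append(x[t] @ W_k)
    v_cache.append(x[t] @ W_v)
    K_c, V_c = np.stack(k_cache), np.stack(v_cache)
    step_outs.append(softmax(q @ K_c.T / np.sqrt(d)) @ V_c)
cached_out = np.stack(step_outs)

print("Max difference:", np.abs(full_out - cached_out).max())
print("Outputs match:", np.allclose(full_out, cached_out))
```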
Key Takeaways
The KV cache stores the key and value vectors computed during attention so they do not need to be recomputed at every generation step. Without it, generating a sequence of length n would require O(n^2) redundant K/V computations. With it, each step computes K and V for only the single new token.
Language model inference has two distinct phases. The prefill phase processes the entire prompt in parallel, populating the KV cache; it is compute-bound and determines Time to First Token (TTFT). The decode phase generates tokens one at a time, reading the KV cache at each step; it is memory-bandwidth-bound and determines Tokens Per Second (TPS).
The KV cache memory per token, in bytes, is:
2 * num_layers * num_kv_heads * head_dim * bytes_per_element
where the leading 2 accounts for storing both K and V. For LLaMA 3.1 405B (126 layers, 8 KV heads, head_dim=128) in bfloat16, this is 516,096 bytes (504 KB) per token. At the full 131,072-token context window, a single request’s KV cache consumes roughly 63 GB.
For long-context inference, the KV cache often exceeds the model weights in memory consumption. LLaMA 3 70B’s weights are ~140 GB in bfloat16, but serving 8 concurrent requests at 128K tokens each would require over 310 GB of KV cache memory alone.
Grouped Query Attention (GQA) reduces KV cache size by sharing each KV head across a group of query heads. LLaMA 3.1 405B uses 8 KV heads for its 128 query heads, reducing the KV cache by 16x compared to full Multi-Head Attention.
Multi-Head Latent Attention (MLA), used by DeepSeek-V3, compresses KV information into a 512-dimensional latent vector plus a 64-dimensional RoPE key per token per layer. That is 576 cached values, versus 2 x 16 x 128 = 4,096 for GQA with 16 KV heads of dimension 128, or roughly 7x more compression.
KV cache quantization reduces memory by storing cached vectors at lower precision. KIVI (Liu et al., arXiv:2402.02750, ICML 2024) achieves 2-bit quantization with near-lossless quality using asymmetric per-channel key quantization and per-token value quantization, enabling up to 4x larger batch sizes and 2.35x to 3.47x throughput improvement. In practice, 8-bit KV cache quantization roughly halves KV cache memory relative to 16-bit with negligible quality loss: llama.cpp and Ollama offer a Q8_0 cache type, and vLLM offers an FP8 KV cache on Hopper/Ada GPUs.
KV cache eviction selectively removes unimportant tokens from the cache. H2O (Zhang et al., arXiv:2306.14048, NeurIPS 2023) keeps “heavy hitter” tokens (those with high cumulative attention scores) plus recent tokens, achieving up to 29x throughput improvement over DeepSpeed Zero-Inference and Hugging Face Accelerate with just 20% of the full cache.
The attention sink phenomenon (Xiao et al., arXiv:2309.17453, ICLR 2024) shows that the first few tokens in a sequence receive disproportionately high attention regardless of their content. StreamingLLM exploits this by keeping attention sink tokens plus a sliding window, enabling stable generation over 4 million+ tokens with fixed memory and up to 22.2x speedup over sliding window recomputation.
Cross-layer KV cache sharing (Brandon et al., arXiv:2405.12981, NeurIPS 2024) allows adjacent transformer layers to share a single KV cache, reducing cache size by an additional 2x on top of per-layer compression. MiniCache (arXiv:2405.14366, NeurIPS 2024) takes a similar approach by merging KV states between adjacent layers. These techniques are orthogonal to GQA, quantization, and eviction.
PagedAttention (Kwon et al., arXiv:2309.06180, SOSP 2023) solves KV cache memory fragmentation by dividing the cache into fixed-size pages, borrowing virtual memory concepts from operating systems. The vLLM system built on PagedAttention achieved 2 to 4x throughput improvement over prior serving systems.
The KV cache is the binding constraint on serving capacity. At 2,000 tokens per request, a LLaMA 3 70B server on 8x H100 GPUs can handle over 600 concurrent requests. At 128,000 tokens per request, it can handle roughly 10. This is why long-context API calls are more expensive and why prompt caching provides significant cost savings.
What’s Next
You now understand why the KV cache exists, how it works, how much memory it consumes, and the techniques used to compress it. But there is a related optimization that builds directly on the KV cache: prompt caching. When you send the same system prompt or conversation history to a language model repeatedly, the server can reuse the KV cache from a previous request instead of recomputing it. In Chapter 19, we will explore how prompt caching works, what gets cached, and how it reduces both latency and cost in production deployments.