KV Cache Optimization in LLMs
The KV cache (Key and Value vectors cache) is an essential optimization technique for accelerating inference in large language models. It can reduce inference time by a factor of five, but it introduces significant challenges in terms of memory consumption. This article explores how it works, its limitations, and optimization strategies at the inference engine level, including an original approach leveraging CPU microarchitecture.
Understanding the KV Cache
The Basic LLM Mechanism
LLMs use a transformer to produce hidden states from the input tokens. These states are projected into the vocabulary space to produce logits, and only the logits at the last position are used to predict the next token. This process repeats for each newly generated token.
The Role of Attention
In the attention mechanism, generating a new token only requires the query vector of the last token and all previous key and value vectors. The crucial point: the key and value vectors of previous tokens remain constant once computed.
The KV Cache Principle
Rather than recomputing the key and value vectors of every previous token at each step, they are stored in a cache. To generate a new token, only the key and value vectors of the most recent token are computed; those of all earlier tokens are read back from the cache, which sharply reduces computation time.
This explains why models like ChatGPT take longer for the first token (the so-called "time-to-first-token" or TTFT, corresponding to computing the KV cache for the prompt) and are faster afterward.
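As a minimal illustration of this principle (a hypothetical structure: a single layer and attention head, f32 vectors, growable Vecs instead of a preallocated tensor), a decode step only appends the newest token's key and value and reads everything older back from the cache:
// Minimal sketch of the caching step (hypothetical types; a real cache is a
// preallocated tensor per layer, not a pair of growable Vecs).
#[derive(Default)]
struct KvCache {
    keys: Vec<Vec<f32>>,   // one key vector per cached token
    values: Vec<Vec<f32>>, // one value vector per cached token
}

impl KvCache {
    // Called once per generated token: only the newest token's K and V are
    // computed by the model; everything older is simply read back from here.
    fn append(&mut self, new_key: Vec<f32>, new_value: Vec<f32>) -> (&[Vec<f32>], &[Vec<f32>]) {
        self.keys.push(new_key);
        self.values.push(new_value);
        (self.keys.as_slice(), self.values.as_slice())
    }
}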
The Scale of the Memory Problem
For a model like Llama3-70B (80 layers, hidden size of 8192, 4k-token context), and assuming fp16 values with both K and V stored in every layer, each token occupies approximately 2.5 MB in the cache, totaling roughly 10.5 GB for a 4k-token context. With multiple simultaneous users, GPU memory is quickly saturated.
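The arithmetic behind these figures can be checked in a few lines (assuming fp16 values and standard multi-head attention, where K and V each span the full hidden size; the constants are the ones quoted above):
// Back-of-the-envelope sizing for the figures above (assumptions: fp16 values,
// standard multi-head attention, so K and V each span the full hidden size).
const LAYERS: usize = 80;
const HIDDEN: usize = 8192;
const KV_PER_LAYER: usize = 2;    // one K vector and one V vector per layer
const BYTES_PER_VALUE: usize = 2; // fp16

fn main() {
    let per_token = LAYERS * HIDDEN * KV_PER_LAYER * BYTES_PER_VALUE;
    let per_context = per_token * 4096;
    println!("{:.1} MB per token", per_token as f64 / 1e6);       // ~2.6 MB
    println!("{:.1} GB for 4k tokens", per_context as f64 / 1e9); // ~10.7 GB
}
Note that grouped-query attention, mentioned below among model-level optimizations, reduces these numbers considerably by sharing key/value heads across query heads.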
Inference Engine Optimizations
It is important to distinguish model-level optimizations (like FlashAttention or Grouped-Query Attention) from inference engine optimizations. The following techniques specifically concern the engine, meaning the management of the KV cache, its memory allocation, and its use to accelerate token generation.
PagedAttention
The engine divides the KV cache into fixed-size pages allocated on demand, instead of reserving a contiguous region sized for the maximum context length, which virtually eliminates memory fragmentation. Inactive pages can be swapped out to CPU memory or disk, while only active pages remain in GPU memory. This technique, used notably in vLLM, is ideal for applications with long prompts or many concurrent sessions.
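A rough sketch of the bookkeeping behind this idea (hypothetical structures; the real block manager in vLLM also handles block sharing, eviction and swapping, and the attention kernels read through the page table on the GPU):
// Sketch of paged allocation: each sequence maps logical token blocks to
// physical pages taken from a shared pool, one page at a time.
const PAGE_SIZE: usize = 16; // tokens per page

struct PagePool {
    free_pages: Vec<usize>, // indices of free physical pages in the GPU buffer
}

struct Sequence {
    page_table: Vec<usize>, // logical block -> physical page index
    num_tokens: usize,
}

impl Sequence {
    // A new physical page is taken from the pool only when the current one is
    // full, so no sequence reserves memory for tokens it has not produced yet.
    fn append_token(&mut self, pool: &mut PagePool) -> Option<()> {
        if self.num_tokens % PAGE_SIZE == 0 {
            self.page_table.push(pool.free_pages.pop()?); // None => pool exhausted
        }
        self.num_tokens += 1;
        Some(())
    }
}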
RadixAttention
The engine organizes cached KV blocks in a radix tree keyed by token sequences, so that requests sharing a common prefix (a system prompt, few-shot examples, a chat history) can reuse the corresponding cache entries instead of recomputing them. This approach, introduced by SGLang, is particularly effective when many requests share long prefixes.
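A simplified sketch of the lookup (a plain trie keyed by single token ids; the real structure compresses runs of tokens into radix-tree edges and evicts cold branches under memory pressure):
use std::collections::HashMap;

// Simplified prefix-reuse sketch: each node optionally points to a cached KV
// block covering the prefix that leads to it.
#[derive(Default)]
struct Node {
    children: HashMap<u32, Node>, // next token id -> child node
    kv_block: Option<usize>,      // cached KV block for this prefix, if any
}

impl Node {
    // How many leading tokens of `prompt` already have their KV vectors cached.
    fn longest_cached_prefix(&self, prompt: &[u32]) -> usize {
        match prompt.split_first() {
            Some((tok, rest)) => match self.children.get(tok) {
                Some(child) if child.kv_block.is_some() => 1 + child.longest_cached_prefix(rest),
                _ => 0,
            },
            None => 0,
        }
    }
}
Before prefilling a new request, the engine checks how many leading tokens of the prompt are already covered and only computes KV vectors for the remainder.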
Partial Offloading
The engine moves less frequently used parts of the KV cache (for example, vectors from older tokens) to CPU memory or disk, keeping only critical data on the GPU. This technique is supported by frameworks like HuggingFace Accelerate, DeepSpeed-Inference, and FlexGen.
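A sketch of the tiering logic (hypothetical: the "device" and "host" tiers are plain in-memory buffers here; a real engine moves tensors between GPU and CPU memory asynchronously, per layer or per block):
use std::collections::VecDeque;

// Sketch of partial offloading: recent KV blocks stay in the "device" tier,
// older ones spill into the "host" tier once a budget is exceeded.
struct TieredKvCache {
    device: VecDeque<Vec<f32>>, // most recent KV blocks, kept on the GPU
    host: Vec<Vec<f32>>,        // older KV blocks, offloaded to CPU RAM or disk
    device_budget: usize,       // maximum number of blocks resident on the GPU
}

impl TieredKvCache {
    fn push_block(&mut self, block: Vec<f32>) {
        self.device.push_back(block);
        // Offload the oldest blocks once the device budget is exceeded.
        while self.device.len() > self.device_budget {
            let oldest = self.device.pop_front().expect("device tier is non-empty");
            self.host.push(oldest);
        }
    }
}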
Dynamic Cache Sizing
The engine adjusts the KV cache size in real time based on context length and memory availability, removing obsolete entries. This prevents memory overflows while prioritizing relevant data.
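A sketch of the budgeting step (hypothetical helper names): recompute how many KV blocks fit in the currently free memory, keep some headroom, and evict or offload everything beyond that budget, least recently used first:
// Sketch of dynamic cache sizing: derive a block budget from free memory,
// then trim the least recently used entries down to that budget.
fn kv_block_budget(free_bytes: usize, bytes_per_block: usize, headroom: f64) -> usize {
    ((free_bytes as f64 * headroom) / bytes_per_block as f64) as usize
}

fn trim_to_budget(lru_block_ids: &mut Vec<u32>, budget: usize) {
    // `lru_block_ids` is ordered from least to most recently used; the entries
    // removed here would be freed (or offloaded) by the engine.
    if lru_block_ids.len() > budget {
        lru_block_ids.drain(..lru_block_ids.len() - budget);
    }
}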
Multi-GPU Distribution
The engine distributes the KV cache across multiple GPUs, synchronizing access to generate tokens in parallel or handle massive contexts. This approach, used in NVIDIA's TensorRT-LLM, makes it possible to scale to contexts of hundreds of thousands of tokens.
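One common way to realize this under tensor parallelism is to shard the cache by attention head, so each GPU stores and attends over its own subset of heads; the sketch below shows only the index arithmetic, with no GPU API involved:
use std::ops::Range;

// Sketch of head-sharding under tensor parallelism: each GPU owns a contiguous
// range of attention heads and keeps only their K/V vectors for every token.
fn heads_for_gpu(gpu: usize, num_heads: usize, num_gpus: usize) -> Range<usize> {
    assert_eq!(num_heads % num_gpus, 0, "heads must divide evenly across GPUs");
    let per_gpu = num_heads / num_gpus;
    gpu * per_gpu..(gpu + 1) * per_gpu
}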
An Original Approach: Exploiting CPU Microarchitecture
Beyond conventional techniques, a promising optimization path involves rethinking how the inference engine uses hardware, particularly on CPUs.
The Problem with Current Engines
Current inference engines follow a relatively straightforward loop, with little regard for the CPU cache hierarchy:
- Load the model into memory
- For each request: feed the tokenizer output through the model, multiplying the activations by the weight matrices (Q, K, V projections, MLPs, etc.) until the next token is produced
- Repeat
At each iteration, the model weights are likely to be evicted from the CPU caches (L1/L2/L3), generating costly cache misses (100 to 200 cycles for a RAM access versus a few cycles for a hit in the L1 cache).
Simultaneous Processing of Multiple Requests
The idea is to process 3 or 4 requests simultaneously to amortize the cost of cache evictions. Model weights, which are common to all requests, remain in cache longer when reused for multiple consecutive requests. Over billions of operations, this optimization can prove significant.
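A sketch of the loop order that makes this work (hypothetical shapes, scalar code): with a single request, each weight row is used once per generated token and has likely been evicted before the next token needs it; with several in-flight requests, the inner loop reuses the row while it is still hot:
// Sketch of batching at the matrix-vector level: the outer loop walks the
// weight rows, the inner loop reuses each row for every in-flight request.
fn matvec_batched(weights: &[f32], rows: usize, cols: usize,
                  activations: &[Vec<f32>], outputs: &mut [Vec<f32>]) {
    for r in 0..rows {
        // This row is brought into the CPU cache once...
        let row = &weights[r * cols..(r + 1) * cols];
        // ...and reused immediately for every request before moving on.
        for (req, x) in activations.iter().enumerate() {
            let mut acc = 0.0f32;
            for c in 0..cols {
                acc += row[c] * x[c];
            }
            outputs[req][r] = acc;
        }
    }
}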
Weight Distribution Across CPU Cores
Rather than loading the entire model on each core, we can divide the weights into blocks and assign them to different cores. Each core keeps its portion of the weights in L1/L2 cache, and intermediate data (activations, KV cache) flows from one core to another. This maximizes the use of each core's private cache and reduces contention in the shared L3 cache.
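A sketch with plain threads and channels (hypothetical: each "stage" is reduced to adding its weight vector to the activation, and pinning threads to specific cores, which a real implementation would need, is left out):
use std::sync::mpsc;
use std::thread;

// Sketch of a pipeline in which each core owns one block of weights and only
// the activations travel from core to core.
fn pipeline(stages: Vec<Vec<f32>>, input: Vec<f32>) -> Vec<f32> {
    let (first_tx, mut rx) = mpsc::channel::<Vec<f32>>();
    let mut handles = Vec::new();

    for stage_weights in stages {
        let (tx, next_rx) = mpsc::channel::<Vec<f32>>();
        let prev_rx = std::mem::replace(&mut rx, next_rx);
        handles.push(thread::spawn(move || {
            // This thread's slice of the weights stays hot in its private L1/L2 cache.
            for mut activation in prev_rx {
                for (a, w) in activation.iter_mut().zip(&stage_weights) {
                    *a += *w;
                }
                if tx.send(activation).is_err() {
                    break;
                }
            }
        }));
    }

    first_tx.send(input).expect("first stage is alive");
    drop(first_tx); // closing the input channel lets the pipeline drain and stop
    let output = rx.recv().expect("last stage produced an output");
    for handle in handles {
        let _ = handle.join();
    }
    output
}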
Using AVX512
AVX512 SIMD instructions allow processing vectors from multiple requests simultaneously in a single 512-bit register. Combined with pre-fetching (_mm_prefetch) to anticipate data in cache, these instructions can significantly reduce computation time.
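A sketch of the idea (assuming an AVX-512 capable CPU and a Rust toolchain on which these intrinsics are stable; the caller is expected to check is_x86_feature_detected!("avx512f") first): activations are laid out so that one 512-bit register holds the same element for 16 different requests, and each weight is broadcast and multiplied against all of them at once:
use std::arch::x86_64::*;

// Sketch: one weight row applied to 16 requests at once. Activations use a
// "structure of arrays" layout: acts[c * 16 + r] is element c of request r,
// so acts.len() must be at least row.len() * 16.
#[target_feature(enable = "avx512f")]
unsafe fn row_times_16_requests(row: &[f32], acts: &[f32], out: &mut [f32; 16]) {
    debug_assert!(acts.len() >= row.len() * 16);
    let mut acc = _mm512_setzero_ps(); // 16 partial sums, one per request
    for (c, &w) in row.iter().enumerate() {
        // Hint the prefetcher a few cache lines ahead (a hint only, never faults).
        _mm_prefetch::<_MM_HINT_T0>(acts.as_ptr().wrapping_add((c + 16) * 16) as *const i8);
        let wv = _mm512_set1_ps(w);                          // broadcast one weight
        let av = _mm512_loadu_ps(acts.as_ptr().add(c * 16)); // element c of all 16 requests
        acc = _mm512_fmadd_ps(wv, av, acc);                  // acc[r] += w * acts[c][r]
    }
    _mm512_storeu_ps(out.as_mut_ptr(), acc);
}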
Huge Pages for the TLB
On Linux, using 1 GB huge pages (via the hugepagesz=1G kernel boot option) can drastically reduce pressure on the TLB. For a 10.5 GB model, the mapping shrinks from roughly 2.6 million entries (4 KB pages), far more than the TLB can track, to only 10 or 11 entries (1 GB pages), virtually eliminating TLB misses.
Linux kernel configuration:
GRUB_CMDLINE_LINUX_DEFAULT="quiet splash hugepagesz=1G hugepages=10"
In Rust, allocation is done via mmap with the MAP_HUGETLB flag:
use libc::{mmap, MAP_ANONYMOUS, MAP_FAILED, MAP_HUGETLB, MAP_HUGE_1GB, MAP_PRIVATE, PROT_READ, PROT_WRITE};

let size: usize = 1 << 30; // 1 GB
let ptr = unsafe {
    mmap(
        std::ptr::null_mut(),
        size,
        PROT_READ | PROT_WRITE,
        // MAP_HUGE_1GB explicitly requests the 1 GB page size; without it the
        // kernel falls back to the default huge page size (usually 2 MB).
        MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB | MAP_HUGE_1GB,
        -1,
        0,
    )
};
assert!(ptr != MAP_FAILED, "mmap with 1 GB huge pages failed");
Measuring the Impact
To validate these optimizations, several metrics are essential:
- Cache misses: perf stat -e cache-misses to compare with and without batching
- TLB misses: perf stat -e dTLB-load-misses to measure the impact of huge pages
- TTFT (time-to-first-token): time before the first generated token
- Throughput: number of tokens generated per second
The ideal approach is to prototype with a reduced model (100M parameters) on a multi-core CPU with AVX512, then extrapolate to larger models.
Conclusion
The KV cache is indispensable for LLM inference, but its optimal management remains an open challenge. Conventional techniques (PagedAttention, offloading, multi-GPU distribution) address memory issues at scale. Fine-grained exploitation of CPU microarchitecture through intelligent batching, weight distribution across cores, AVX512 instructions, and huge pages opens a complementary path for significant gains, particularly on CPU deployments where resources are constrained. These optimizations, coded in Rust and assembly, remain to be validated through rigorous benchmarks, but the potential is real.