Tile-major layout and transformer memory allocations
When writing a CPU inference engine, two questions keep coming back: how are matrices stored in memory, and how much memory is needed overall. This article covers both: first the tile-major layout used for all projection matrices, then the full memory sizing -- model weights, runtime buffers, and KV cache -- for my Qwen3 inference engine.
Tile-major layout
The row-major problem
In standard row-major layout, a matrix [N, K] is stored row by row: the K elements of row 0, then row 1, and so on. For an embedding lookup (reading the row corresponding to a token), this is perfect: K contiguous values in memory.
But for a matrix-vector multiplication (matvec), the situation is different. Computing y = W * x requires, for each output element y[i], the dot product of row i of W with vector x. So far so good. But if you want to vectorize by processing 32 rows simultaneously (which is natural with AVX-512 in f16), you need to read the same column across 32 adjacent rows. In row-major, those 32 values are spaced K elements apart -- strided accesses that destroy cache locality.
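To make the access pattern concrete, here is a minimal row-major matvec sketch in plain C (illustrative only, not the engine's code) that processes 32 rows per column; the `fp16_to_f32` helper is hypothetical. The inner loop shows the stride of K elements between the 32 reads of a single column.

```c
#include <stddef.h>
#include <stdint.h>

// Hypothetical scalar f16 -> f32 conversion helper (stand-in for F16C).
extern float fp16_to_f32(uint16_t h);

// Row-major y = W * x, processing 32 rows per column. W is [N, K] in f16
// (raw uint16_t). The 32 reads of column k sit K * 2 bytes apart, so each
// one can land on a different cache line.
void matvec_rowmajor_32rows(const uint16_t *W, const float *x,
                            float *y, size_t N, size_t K) {
    for (size_t i0 = 0; i0 + 32 <= N; i0 += 32) {
        float acc[32] = {0};
        for (size_t k = 0; k < K; k++) {
            for (size_t r = 0; r < 32; r++) {
                acc[r] += fp16_to_f32(W[(i0 + r) * K + k]) * x[k];
            }
        }
        for (size_t r = 0; r < 32; r++) y[i0 + r] = acc[r];
    }
}
```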
The tile-major principle
Tile-major reorganizes a matrix [N, K] as [N/32, K, 32]. Concretely:
- The matrix is split into tiles of 32 rows
- Within each tile, data is stored column by column: the 32 values of column 0, then the 32 values of column 1, etc.
- Reading 32 values from the same column (within a tile) is a contiguous 64-byte read (32 f16 values x 2 bytes)
That is exactly one cache line. A single memory access loads the 32 elements needed for a partial dot product across 32 simultaneous rows.
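A minimal sketch of the reorganization, assuming N is a multiple of 32 (which holds for every dimension that appears later in this article) and f16 values stored as raw uint16_t; the function and parameter names are mine, not the engine's:

```c
#include <stddef.h>
#include <stdint.h>

// Reorganize a row-major [N, K] matrix into tile-major [N/32, K, 32].
// Element (row, col) of the source lands at:
//   tile = row / 32, lane = row % 32
//   dst[(tile * K + col) * 32 + lane]
void to_tile_major(const uint16_t *src, uint16_t *dst, size_t N, size_t K) {
    for (size_t row = 0; row < N; row++) {
        size_t tile = row / 32, lane = row % 32;
        for (size_t col = 0; col < K; col++) {
            dst[(tile * K + col) * 32 + lane] = src[row * K + col];
        }
    }
}
```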
Why this changes everything
In row-major, processing 32 rows in parallel on a column requires 32 memory accesses (one per row, with a stride of K x 2 bytes). In tile-major, it is a single contiguous 64-byte access. For a [1024, 1024] matrix in f16, you go from 32 potential cache misses to just 1 per column processed. Across thousands of columns and dozens of layers, the gain is massive.
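For illustration, here is a hedged AVX-512 kernel sketch over one tile-major tile (assumes AVX-512F; not the engine's actual kernel). Each iteration issues a single 64-byte load for the 32 f16 weights of column k, widens them to f32, and accumulates 32 partial dot products in two registers.

```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

// One 32-row tile of a tile-major matvec: W_tile is [K][32] f16 (raw
// uint16_t), x is the f32 input of length K, y receives the 32 outputs.
void matvec_tile32_f16(const uint16_t *W_tile, const float *x,
                       float *y, int K) {
    __m512 acc_lo = _mm512_setzero_ps();                    // rows 0..15
    __m512 acc_hi = _mm512_setzero_ps();                    // rows 16..31
    for (int k = 0; k < K; k++) {
        // Single contiguous 64-byte load: the 32 f16 weights of column k.
        __m512i col = _mm512_loadu_si512(W_tile + (size_t)k * 32);
        __m512 w_lo = _mm512_cvtph_ps(_mm512_castsi512_si256(col));
        __m512 w_hi = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(col, 1));
        __m512 xk   = _mm512_set1_ps(x[k]);                 // broadcast x[k]
        acc_lo = _mm512_fmadd_ps(w_lo, xk, acc_lo);
        acc_hi = _mm512_fmadd_ps(w_hi, xk, acc_hi);
    }
    _mm512_storeu_ps(y,      acc_lo);
    _mm512_storeu_ps(y + 16, acc_hi);
}
```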
Per-layer matrices
Each transformer layer stores the following matrices:
| Matrix | Dimensions | Size | Layout |
|---|---|---|---|
| q_proj | [1024, 1024] | 2 MB | Tile-major |
| k_proj | [512, 1024] | 1 MB | Tile-major |
| v_proj | [512, 1024] | 1 MB | Tile-major |
| o_proj | [1024, 1024] | 2 MB | Tile-major |
| gate_proj | [3072, 1024] | 6 MB | Tile-major |
| up_proj | [3072, 1024] | 6 MB | Tile-major |
| down_proj | [1024, 3072] | 6 MB | Tile-major |
Total per layer: ~24 MB of matrices.
The first four (q_proj, k_proj, v_proj, o_proj) are the attention projections. Q and O have the full model dimension (1024) as output, while K and V project to a reduced dimension (512) thanks to Grouped-Query Attention.
The next three (gate_proj, up_proj, down_proj) form the feed-forward network (FFN). gate_proj and up_proj expand from 1024 to 3072 (3x factor). down_proj brings it back from 3072 to 1024. gate_proj controls the activation via SiLU before multiplication with the up_proj output.
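As a sketch of that data flow (not the engine's code), with a hypothetical matvec_tile_major(W, x, y, N, K) standing in for a tile-major matvec y = W * x:

```c
#include <math.h>
#include <stdint.h>

#define HIDDEN 1024   // model dimension
#define FFN    3072   // FFN expansion dimension

// Assumed tile-major matvec: y = W * x, W has N rows and K columns.
extern void matvec_tile_major(const uint16_t *W, const float *x,
                              float *y, int N, int K);

// FFN: out = down_proj( SiLU(gate_proj(x)) * up_proj(x) )
void ffn_forward(const uint16_t *gate_proj, const uint16_t *up_proj,
                 const uint16_t *down_proj, const float *x, float *out) {
    float gate[FFN], up[FFN];
    matvec_tile_major(gate_proj, x, gate, FFN, HIDDEN);
    matvec_tile_major(up_proj,   x, up,   FFN, HIDDEN);
    for (int i = 0; i < FFN; i++) {
        float g = gate[i];
        float silu = g / (1.0f + expf(-g));   // SiLU(g) = g * sigmoid(g)
        gate[i] = silu * up[i];               // gated activation
    }
    matvec_tile_major(down_proj, gate, out, HIDDEN, FFN);
}
```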
All these matrices are stored in tile-major. No exceptions.
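One possible per-layer container, for reference (the struct and field names are illustrative, not the engine's definitions; shapes follow the [N/32, K, 32] tile-major convention):

```c
#include <stdint.h>

// All pointers reference tile-major f16 data, stored as raw uint16_t.
typedef struct {
    uint16_t *q_proj;     // [1024/32, 1024, 32] -> 2 MB
    uint16_t *k_proj;     // [ 512/32, 1024, 32] -> 1 MB
    uint16_t *v_proj;     // [ 512/32, 1024, 32] -> 1 MB
    uint16_t *o_proj;     // [1024/32, 1024, 32] -> 2 MB
    uint16_t *gate_proj;  // [3072/32, 1024, 32] -> 6 MB
    uint16_t *up_proj;    // [3072/32, 1024, 32] -> 6 MB
    uint16_t *down_proj;  // [1024/32, 3072, 32] -> 6 MB
} LayerWeights;           // ~24 MB of matrix data per layer
```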
Total memory allocations
Model weights
| Component | Calculation | Size |
|---|---|---|
| Embeddings (row-major) | 151,936 x 1,024 x 2 bytes | 311 MB |
| Embeddings (tile-major) | 151,936 x 1,024 x 2 bytes | 311 MB |
| Final norm | 1,024 x 2 bytes | 2 KB |
| 28 layers | 28 x 24 MB | 672 MB |
| Total | | ~1.3 GB |
Why two copies of the embeddings? The embedding matrix serves two distinct operations:
- Embedding lookup (beginning of the model): read the row corresponding to a token. Row-major is optimal -- a sequential read of 1024 f16 values.
- LM head (end of the model): matrix-vector multiplication to project the hidden state to the vocabulary (151,936 classes). Tile-major is optimal -- contiguous accesses in tiles of 32 rows.
Duplicating the matrix costs an extra 311 MB. This is a deliberate memory/performance tradeoff: each layout is optimal for its specific use. The tile-major copy is reorganized from embed_tokens at model loading time.
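At load time, that copy could be built with the reorganization sketched earlier; a hypothetical sketch (reusing the to_tile_major function from above, with 151,936 rows dividing evenly into 4,748 tiles of 32):

```c
#include <stdint.h>
#include <stdlib.h>

enum { VOCAB_SIZE = 151936, HIDDEN_DIM = 1024 };

// Assumed helper from the earlier sketch.
void to_tile_major(const uint16_t *src, uint16_t *dst, size_t N, size_t K);

// Build the tile-major LM-head copy from the row-major embed_tokens.
// Costs an extra 151,936 * 1,024 * 2 bytes (~311 MB).
uint16_t *build_lm_head(const uint16_t *embed_tokens) {
    uint16_t *lm_head = malloc((size_t)VOCAB_SIZE * HIDDEN_DIM * sizeof(uint16_t));
    if (lm_head) to_tile_major(embed_tokens, lm_head, VOCAB_SIZE, HIDDEN_DIM);
    return lm_head;
}
```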
Runtime buffers
Decode mode (token-by-token generation):
| Buffer | Size |
|---|---|
| Temporary buffers | ~320 KB |
| RoPE cache | ~16 KB (for max_seq_len = 32,768) |
| KV cache | dynamically allocated |
Prefill mode (batch prompt processing):
| Buffer | Size |
|---|---|
| Batch buffers | ~12 MB (for MAX_PREFILL_LEN = 512) |
| KV cache | dynamically allocated |
Prefill buffers are larger because up to 512 tokens are processed simultaneously. In decode, only one token is processed at a time, so the buffers are far more compact.
KV cache per layer
The KV cache is allocated in chunks of 256 tokens. No per-token reallocation -- a full chunk is allocated when the previous one is full.
Chunk sizing
A chunk covers 256 positions for all attention heads:
- 8 heads x 64 dimensions x 2 bytes x 256 tokens = 256 KB for K
- 8 heads x 64 dimensions x 2 bytes x 256 tokens = 256 KB for V
- Total per chunk: 512 KB (K + V combined)
Concrete examples
For a sequence of length L:
| Sequence | Chunks | KV cache per layer | Total KV cache (28 layers) |
|---|---|---|---|
| 5 tokens | ceil(5/256) = 1 | 512 KB | 14 MB |
| 8 tokens (5 + 3) | ceil(8/256) = 1 | 512 KB | 14 MB |
| 300 tokens | ceil(300/256) = 2 | 1 MB | 28 MB |
| 1024 tokens | ceil(1024/256) = 4 | 2 MB | 56 MB |
With 5 tokens, a single chunk is allocated. Adding 3 tokens (for a total of 8) stays within the same chunk -- no new allocation. A new chunk is only allocated once the sequence grows past 256 tokens.
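These numbers fall out of a tiny sizing helper (a sketch; the function and constant names are mine, and the per-chunk figure matches the 8 heads x 64 dims x 2 bytes x 256 tokens above):

```c
#include <stddef.h>
#include <stdio.h>

enum {
    KV_HEADS     = 8,
    HEAD_DIM     = 64,
    CHUNK_TOKENS = 256,
    N_LAYERS     = 28,
};

// Bytes of KV cache for one layer: K and V, allocated in 256-token chunks.
static size_t kv_bytes_per_layer(size_t seq_len) {
    size_t chunks = (seq_len + CHUNK_TOKENS - 1) / CHUNK_TOKENS;   // ceil
    size_t chunk_bytes = (size_t)2 /* K and V */ * KV_HEADS * HEAD_DIM
                         * 2 /* f16 bytes */ * CHUNK_TOKENS;       // 512 KB
    return chunks * chunk_bytes;
}

int main(void) {
    size_t lens[] = {5, 8, 300, 1024};
    for (int i = 0; i < 4; i++) {
        size_t per_layer = kv_bytes_per_layer(lens[i]);
        printf("%4zu tokens: %6zu KB/layer, %4zu MB total\n",
               lens[i], per_layer / 1024,
               per_layer * N_LAYERS / (1024 * 1024));
    }
    return 0;
}
```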
Why chunks
Chunk-based allocation in groups of 256 tokens avoids reallocating on every generated token. It is a tradeoff between wasted memory (a 512 KB chunk for a 5-token sequence is mostly empty) and allocation cost (reallocating per token would be catastrophic). 256 is a good balance point: large enough to amortize allocations, small enough not to waste too much memory in the last, partially filled chunk.
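A minimal growth sketch under the same assumptions (the kv_cache_t layout and names are illustrative, not the engine's; error handling omitted):

```c
#include <stdint.h>
#include <stdlib.h>

#define CHUNK_TOKENS 256
#define KV_HEADS 8
#define HEAD_DIM 64

// One layer's KV cache as a growable array of fixed-size chunks.
typedef struct {
    uint16_t **chunks;   // each chunk: K then V, CHUNK_TOKENS positions each
    size_t n_chunks;
    size_t n_tokens;
} kv_cache_t;

// Make room for one more token; a new 512 KB chunk is allocated only when
// the sequence crosses a 256-token boundary.
static void kv_reserve_token(kv_cache_t *c) {
    if (c->n_tokens == c->n_chunks * CHUNK_TOKENS) {
        size_t chunk_bytes = (size_t)2 * KV_HEADS * HEAD_DIM
                             * sizeof(uint16_t) * CHUNK_TOKENS;   // 512 KB
        c->chunks = realloc(c->chunks, (c->n_chunks + 1) * sizeof(*c->chunks));
        c->chunks[c->n_chunks++] = malloc(chunk_bytes);
    }
    c->n_tokens++;
}
```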
Summary
For an 8-token prompt on Qwen3 (28 layers, 1024 hidden, 8 KV heads x 64 dim):
| Component | Size |
|---|---|
| Model weights | ~1.3 GB |
| Runtime buffers (decode) | ~336 KB |
| KV cache (28 layers x 1 chunk) | 14 MB |
| Total | ~1.3 GB |
The KV cache is negligible for short sequences. It starts to matter on long contexts (thousands of tokens). At 32,768 tokens (the model's limit), it would reach 28 x 128 x 512 KB = 1.75 GB -- almost as much as the model weights themselves.