Layout tile-major et allocations mémoire d'un transformer
Quand on écrit un moteur d'inférence CPU, deux questions reviennent en permanence : comment sont stockées les matrices en mémoire, et combien de mémoire faut-il au total. Cet article couvre les deux. D'abord le layout tile-major utilisé pour toutes les matrices de projection, puis le dimensionnement complet -- poids du modèle, buffers runtime et KV cache -- pour mon moteur d'inférence Qwen3.
Layout tile-major
Le problème du row-major
En layout row-major classique, une matrice [N, K] est stockée ligne par ligne : les K éléments de la ligne 0, puis ceux de la ligne 1, etc. Pour un embedding lookup (lire la ligne correspondant à un token), c'est parfait : on lit K valeurs contiguës en mémoire.
Mais pour une multiplication matrice-vecteur (matvec), la situation est différente. Le calcul de y = W * x nécessite, pour chaque élément de sortie y[i], le produit scalaire de la ligne i de W avec le vecteur x. Jusque-là, pas de problème. Mais si on veut vectoriser en traitant 32 lignes simultanément (ce qui est naturel avec AVX-512 en f16), il faut lire la même colonne de 32 lignes adjacentes. En row-major, ces 32 valeurs sont espacées de K éléments chacune -- des accès stridés qui détruisent la localité du cache.
Le principe du tile-major
Le layout tile-major réorganise une matrice [N, K] en [N/32, K, 32]. Concrètement :
- On découpe la matrice en tuiles de 32 lignes
- À l'intérieur de chaque tuile, on stocke les données colonne par colonne : les 32 valeurs de la colonne 0, puis les 32 de la colonne 1, etc.
- Lire 32 valeurs de la même colonne (dans une tuile) revient à lire 64 octets contigus (32 valeurs f16 × 2 octets)
C'est exactement une ligne de cache. Un seul accès mémoire charge les 32 éléments nécessaires pour un produit scalaire partiel sur 32 lignes simultanées.
Pourquoi ça change tout
En row-major, traiter 32 lignes en parallèle sur une colonne nécessite 32 accès mémoire (un par ligne, avec un stride de K × 2 octets). En tile-major, c'est un seul accès contigu de 64 octets. Pour une matrice [1024, 1024] en f16, on passe de 32 cache misses potentiels à 1 seul par colonne traitée. Sur des milliers de colonnes et des dizaines de couches, le gain est massif.
Matrices par couche
Chaque couche transformer stocke les matrices suivantes :
| Matrice | Dimensions | Taille | Layout |
|---|---|---|---|
| q_proj | [1024, 1024] | 2 Mo | Tile-major |
| k_proj | [512, 1024] | 1 Mo | Tile-major |
| v_proj | [512, 1024] | 1 Mo | Tile-major |
| o_proj | [1024, 1024] | 2 Mo | Tile-major |
| gate_proj | [3072, 1024] | 6 Mo | Tile-major |
| up_proj | [3072, 1024] | 6 Mo | Tile-major |
| down_proj | [1024, 3072] | 6 Mo | Tile-major |
Total par couche : ~24 Mo de matrices.
Les quatre premières (q_proj, k_proj, v_proj, o_proj) sont les projections de l'attention. Q et O ont la dimension complète du modèle (1024) en sortie, tandis que K et V projettent vers une dimension réduite (512) grâce au Grouped-Query Attention.
Les trois suivantes (gate_proj, up_proj, down_proj) forment le réseau feed-forward (FFN). gate_proj et up_proj expandent de 1024 vers 3072 (facteur 3x). down_proj ramène de 3072 vers 1024. Le gate_proj contrôle l'activation via SiLU avant multiplication avec la sortie de up_proj.
Toutes ces matrices sont stockées en tile-major. Aucune exception.
Allocations mémoire totales
Poids du modèle
| Composant | Calcul | Taille |
|---|---|---|
| Embeddings (row-major) | 151 936 × 1 024 × 2 | 311 Mo |
| Embeddings (tile-major) | 151 936 × 1 024 × 2 | 311 Mo |
| Final norm | 1 024 × 2 | 2 Ko |
| 28 couches | 28 × 24 Mo | 672 Mo |
| Total | ~1,3 Go |
Pourquoi deux copies des embeddings ? La matrice d'embeddings sert à deux opérations distinctes :
- Embedding lookup (début du modèle) : on lit la ligne correspondant à un token. Le row-major est optimal -- une lecture séquentielle de 1024 valeurs f16.
- LM head (fin du modèle) : multiplication matrice-vecteur pour projeter le hidden state vers le vocabulaire (151 936 classes). Le tile-major est optimal -- accès contigus par tuiles de 32 lignes.
Dupliquer la matrice coûte 311 Mo supplémentaires. C'est un compromis mémoire/performance délibéré : chaque layout est optimal pour son usage spécifique. La copie tile-major est réorganisée à partir de embed_tokens au chargement du modèle.
Buffers runtime
Mode decode (génération token par token) :
| Buffer | Taille |
|---|---|
| Buffers temporaires | ~320 Ko |
| Cache RoPE | ~16 Ko (pour max_seq_len = 32 768) |
| KV cache | allocation dynamique |
Mode prefill (traitement du prompt par batch) :
| Buffer | Taille |
|---|---|
| Buffers batch | ~12 Mo (pour MAX_PREFILL_LEN = 512) |
| KV cache | allocation dynamique |
Les buffers de prefill sont plus grands car on traite jusqu'à 512 tokens simultanément. En decode, on ne traite qu'un seul token à la fois, d'où des buffers bien plus compacts.
KV cache par couche
Le KV cache est alloué par chunks de 256 tokens. Pas de réallocation à chaque token -- on alloue un chunk entier quand le précédent est plein.
Dimensionnement d'un chunk
Un chunk couvre 256 positions pour toutes les têtes d'attention :
- 8 têtes × 64 dimensions × 2 octets × 256 tokens = 256 Ko pour K
- 8 têtes × 64 dimensions × 2 octets × 256 tokens = 256 Ko pour V
- Total par chunk : 512 Ko (K + V combinés)
Exemples concrets
Pour une séquence de longueur L :
| Séquence | Chunks | KV cache par couche | KV cache total (28 couches) |
|---|---|---|---|
| 5 tokens | ceil(5/256) = 1 | 512 Ko | 14 Mo |
| 8 tokens (5 + 3) | ceil(8/256) = 1 | 512 Ko | 14 Mo |
| 300 tokens | ceil(300/256) = 2 | 1 Mo | 28 Mo |
| 1024 tokens | ceil(1024/256) = 4 | 2 Mo | 56 Mo |
Avec 5 tokens, un seul chunk est alloué. En ajoutant 3 tokens (pour un total de 8), on reste dans le même chunk -- pas de nouvelle allocation. Le chunk ne sera étendu que lorsqu'on dépassera 256 tokens.
Pourquoi des chunks
L'allocation par chunks de 256 tokens évite la réallocation à chaque token généré. C'est un compromis entre fragmentation mémoire (un chunk de 512 Ko pour 5 tokens, c'est du gaspillage) et coût d'allocation (réallouer à chaque token serait catastrophique). 256 est un bon point d'équilibre : assez grand pour amortir les allocations, assez petit pour ne pas gaspiller trop de mémoire en fin de chunk.
Récapitulatif
Pour un prompt de 8 tokens sur Qwen3 (28 couches, 1024 hidden, 8 KV heads × 64 dim) :
| Composant | Taille |
|---|---|
| Poids du modèle | ~1,3 Go |
| Buffers runtime (decode) | ~336 Ko |
| KV cache (28 couches × 1 chunk) | 14 Mo |
| Total | ~1,3 Go |
Le KV cache est négligeable pour des séquences courtes. C'est sur des contextes longs (milliers de tokens) qu'il commence à peser. À 32 768 tokens (la limite du modèle), il atteindrait 28 × 128 × 512 Ko = 1,75 Go -- presque autant que les poids du modèle.