Battre PyTorch avec Rust et 180 lignes d'assembleur

Est-il possible de battre PyTorch sur son propre terrain avec Rust et 180 lignes d'assembleur ? Quand l'inférence est M=1, c'est le cache qui décide.

Le mur initial

En travaillant sur une implémentation de GPT-2 en Rust sur Apple Silicon, le constat est brutal : même optimisée, la version Rust reste plus lente que PyTorch + Accelerate. Les BLAS industrielles sont excellentes, et c'est normal.

Mais les frameworks comme PyTorch sont généralistes. L'inférence token par token est tout sauf un cas général.

Explorer la hiérarchie mémoire

Après des années de travail sur l'optimisation CPU, les instructions SIMD et la hiérarchie mémoire, la question revient toujours : comment améliorer l'inférence des LLM dans des cas très spécifiques, notamment en local ?

La démarche commence simple : un moteur GPT-2 en numpy. Sentir où ça coûte. Observer. Mesurer. Grâce aux travaux de Karpathy et Raschka, c'est un vrai plaisir.

Le problème n'est pas le calcul

Ce qui ressort clairement :

Le pattern K-outer en assembleur

En réécrivant les kernels en assembleur ARM64 NEON f16, avec un pattern "K-outer" spécialisé, le résultat est spectaculaire : 156.9 tokens/s sur GPT-2 Small.

Soit 3x plus rapide que PyTorch sur la même machine.

Les benchmarks sur Apple Silicon M3 :

ModèlePyTorch (tok/s)Rust V3 (tok/s)Speedup
GPT-2 Small53156.93.0x
GPT-2 Medium2259.32.7x
GPT-2 Large1329.12.2x
GPT-2 XL7.716.22.1x

Et ensuite ?

Aujourd'hui, un Qwen3-VL-30B-A3B tourne déjà sur un moteur d'inférence en Rust et assembleur. Et il reste beaucoup de pistes à explorer : Metal, Vulkan, AVX512...

La spécialisation GEMV contre la généralisation GEMM, le f16 natif sur NEON comme arme redoutable pour l'edge : s'il n'y avait qu'une chose à retenir, c'est que comprendre son hardware change complètement le problème.