LLM systems primitives rebuilt from scratch — PyTorch baselines → custom Triton GPU kernels.
This repository reconstructs and optimizes the core computational primitives that every large language model is built from. Not a training framework. Not a Hugging Face wrapper. A systems-level implementation focused on correctness, numerical stability, and GPU memory performance.
The central thesis: modern LLM performance is not about FLOPs. It is about memory bandwidth. Every kernel here is designed around that constraint — fusing operations to minimize global memory round-trips, and benchmarked to prove it.
Benchmarked on a local NVIDIA GPU using triton.testing.do_bench (median latency). All comparisons use fp16.

FlashAttention (Triton) vs. naive PyTorch attention:
| Config | PyTorch (fp16) | Triton Flash | Speedup | Flash GB/s |
|---|---|---|---|---|
| B=1 H=1 T=512 D=64 | 0.033 ms | 0.024 ms | 1.35× | 10.8 GB/s |
| B=1 H=1 T=1024 D=64 | 0.065 ms | 0.038 ms | 1.70× | 13.8 GB/s |
| B=1 H=1 T=2048 D=64 | 0.233 ms | 0.092 ms | 2.52× | 11.4 GB/s |
| B=2 H=4 T=512 D=64 | 0.136 ms | 0.050 ms | 2.71× | 41.8 GB/s |
Speedup grows with sequence length — naive attention materialises the full T×T score matrix in HBM (O(T²) memory); FlashAttention tiles it in SRAM (O(T) memory).
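The online-softmax trick behind the O(T) memory claim can be sketched in plain NumPy (a reference for the algorithm, not the Triton kernel itself): scores are consumed block by block while a running max and normaliser are maintained, so the full T×T score matrix never exists at once.

```python
import numpy as np

def online_softmax_attention(q, K, V, block=4):
    """Streaming attention for one query: process K/V in blocks,
    keeping only a running max (m), normalizer (l), and accumulator."""
    m = -np.inf                                  # running max of scores seen so far
    l = 0.0                                      # running softmax normalizer
    acc = np.zeros_like(V[0], dtype=np.float64)  # output accumulator
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q           # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                # rescale old state to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[start:start + block]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
K = rng.standard_normal((16, 8))
V = rng.standard_normal((16, 8))

# Matches the full (materialised) softmax exactly
ref = np.exp(K @ q - (K @ q).max())
ref = (ref / ref.sum()) @ V
assert np.allclose(online_softmax_attention(q, K, V), ref)
```

The same rescale-and-accumulate recurrence is what the Triton kernel runs per tile in SRAM; the speedup comes from never writing the intermediate scores to HBM.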
Fused bias + GELU (Triton) vs. unfused PyTorch:
| Config | Unfused (PyTorch) | Fused (Triton) | Speedup | Triton GB/s |
|---|---|---|---|---|
| B=2 T=256 O=4096 | 0.333 ms | 0.054 ms | 6.21× | 156.3 GB/s |
| B=2 T=512 O=4096 | 1.073 ms | 0.102 ms | 10.55× | 164.9 GB/s |
| B=2 T=1024 O=4096 | 2.865 ms | 0.200 ms | 14.36× | 168.2 GB/s |
| B=2 T=1024 O=16384 | 11.458 ms | 0.782 ms | 14.65× | 171.7 GB/s |
Unfused reads the GEMM output twice — once for bias, once for GELU. Triton reads it once, applies both in registers, writes once. At large O the kernel saturates memory bandwidth (~172 GB/s peak on this GPU).
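The fused epilogue itself is just one formula applied per element; a scalar Python sketch of the math (the exact erf-based GELU, not the tanh approximation):

```python
import math

def bias_gelu(x, b):
    """Fused bias + exact GELU in one pass: y = GELU(x + b), where
    GELU(z) = 0.5 * z * (1 + erf(z / sqrt(2)))."""
    z = x + b
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

# GELU(1.0) = 1.0 * Phi(1.0), the standard normal CDF at 1
print(round(bias_gelu(1.0, 0.0), 4))  # → 0.8413
```

In the Triton kernel this runs in registers between one global load and one global store, which is where the bandwidth saving comes from.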
Fused AdamW (Triton) vs. PyTorch AdamW:
| N params | PyTorch AdamW | Triton fused | Speedup | Triton GB/s |
|---|---|---|---|---|
| 1M | 0.237 ms | 0.169 ms | 1.41× | 165.8 GB/s |
| 10M | 5.456 ms | 1.624 ms | 3.36× | 172.5 GB/s |
| 50M | 27.279 ms | 7.906 ms | 3.45× | 177.1 GB/s |
PyTorch AdamW issues 5+ separate CUDA kernels per step. Triton fuses all state updates into one — read w, g, m, v once; write w, m, v once. At 50M params the fused kernel runs at 177 GB/s, close to peak bandwidth.
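The per-element update that gets fused is the standard AdamW recurrence with decoupled weight decay; a pure-Python sketch of the single pass (reference semantics, not the Triton kernel):

```python
import math

def adamw_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """One fused AdamW step: update moments, apply bias correction,
    decoupled weight decay, and the parameter update in a single
    pass over (w, g, m, v) — one read and one write per tensor."""
    for i in range(len(w)):
        m[i] = b1 * m[i] + (1 - b1) * g[i]          # first moment
        v[i] = b2 * v[i] + (1 - b2) * g[i] * g[i]   # second moment
        m_hat = m[i] / (1 - b1 ** t)                # bias correction
        v_hat = v[i] / (1 - b2 ** t)
        w[i] -= lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * w[i])
    return w, m, v

w, m, v = adamw_step([1.0], [0.5], [0.0], [0.0], t=1)
```

Unfused, each of those five lines is (at least) its own CUDA kernel, each re-reading its inputs from HBM.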
Single-token inference attention (Triton) vs. naive PyTorch:
| Config | PyTorch (naive) | Triton | Speedup | Triton GB/s |
|---|---|---|---|---|
| B=1 H=8 D=128 T=512 | 0.079 ms | 0.037 ms | 2.13× | 56.8 GB/s |
| B=1 H=8 D=128 T=1024 | 0.117 ms | 0.060 ms | 1.94× | 69.5 GB/s |
| B=1 H=8 D=128 T=2048 | 0.203 ms | 0.105 ms | 1.94× | 80.1 GB/s |
| B=2 H=8 D=128 T=1024 | 0.348 ms | 0.088 ms | 3.94× | 95.0 GB/s |
Single-token inference is entirely memory-bandwidth bound — the query vector is tiny, the KV cache is large. Triton's fused online-softmax kernel reads the cache once with no intermediate materialisation.
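For one head, the decode step reduces to a dot product against the cache, a softmax, and a weighted sum; a NumPy reference of those semantics (not the fused Triton kernel):

```python
import numpy as np

def decode_attention(q, k_cache, v_cache, t):
    """Single-token decode: attend one query vector to the first t
    cached key/value rows. The cache is read exactly once."""
    scores = k_cache[:t] @ q / np.sqrt(q.shape[0])  # (t,)
    scores -= scores.max()                           # numerical stability
    p = np.exp(scores)
    p /= p.sum()
    return p @ v_cache[:t]                           # (D,)

rng = np.random.default_rng(0)
D, T_max = 64, 1024
k_cache = rng.standard_normal((T_max, D)).astype(np.float32)
v_cache = rng.standard_normal((T_max, D)).astype(np.float32)
q = rng.standard_normal(D).astype(np.float32)
out = decode_attention(q, k_cache, v_cache, t=512)
```

With a cache length of one the softmax is trivially 1, so the output is just the cached value row — a handy sanity check. The Triton kernel applies the online-softmax form of this so the (t,) score vector never hits HBM either.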
triton-llm-kernels/
├── requirements.txt
│
├── normalization/ # Transformer normalization layers
│ ├── triton_rmsnorm.py # Fused RMSNorm (LLaMA / Mistral)
│ ├── triton_layernorm.py # Fused LayerNorm (GPT-2 / BERT)
│ ├── triton_residual_rmsnorm.py # Fused residual add + RMSNorm
│ ├── rmsnorm.py / layernorm.py # PyTorch references
│ └── test_*.py
│
├── attention/ # Attention mechanisms
│ ├── flash_attention_triton.py # FlashAttention: multi-head, causal mask
│ ├── blocked_attention.py # Online softmax in Python (reference)
│ ├── naive_attention.py # Naive O(T²) attention (reference)
│ └── test_*.py
│
├── mlp/ # MLP / FFN layers
│ ├── triton_bias_gelu.py # Fused bias + exact GELU
│ ├── gelu.py / linear.py # PyTorch references
│ └── test_*.py
│
├── optimizers/ # Training optimizers
│ ├── triton_adam.py # Fully fused AdamW kernel
│ ├── adam.py # PyTorch reference
│ └── test_*.py
│
├── inference/ # Inference-specific kernels
│ ├── triton_inference_attention.py # Single-token attention + KV cache
│ ├── kv_cache.py # KV cache layout
│ ├── attention_inference*.py # PyTorch references
│ └── test_*.py
│
└── benchmarks/ # Head-to-head benchmarks
├── attention_bench.py
├── mlp_bench.py
├── optimizer_bench.py
├── inference_bench.py
└── utils.py
- RMSNorm — single-pass fused kernel; used in LLaMA, Mistral, Qwen
- LayerNorm — mean + variance in one pass; used in GPT-2, BERT, T5
- Fused Residual RMSNorm — residual add + RMSNorm in one kernel; returns both the normalised output and the pre-norm residual sum for the next block's skip connection
- Naive attention — O(T²) memory PyTorch reference
- Blocked attention — online softmax in Python; shows algorithmic correctness alone doesn't help without kernel fusion
- FlashAttention (Triton) — tiled O(T) memory; multi-head; optional causal mask for autoregressive generation
- Inference attention (Triton) — single-token decoding against full KV cache; multi-head; online softmax in Triton
- Fused Bias + GELU — bias add and exact GELU in one kernel; eliminates intermediate HBM write; up to 14.65× faster than unfused at large O
- Fused AdamW — weight update, first moment, second moment, and weight decay in one kernel dispatch; up to 3.45× faster than PyTorch at 50M params
- KV cache — pre-allocated (B, T_max, H, D) layout matching production inference engines
- Triton inference attention — attends single query token to full KV cache; multi-head
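The pre-allocated cache idea can be sketched in a few lines (an illustrative NumPy model; the class and method names here are assumptions, not the repo's API):

```python
import numpy as np

class KVCache:
    """Pre-allocated (B, T_max, H, D) key/value cache: append one
    token's K/V per decode step, with no reallocation or copying."""
    def __init__(self, B, T_max, H, D, dtype=np.float16):
        self.k = np.zeros((B, T_max, H, D), dtype=dtype)
        self.v = np.zeros((B, T_max, H, D), dtype=dtype)
        self.t = 0  # number of tokens currently cached

    def append(self, k_new, v_new):
        """k_new, v_new: (B, H, D) for the current token."""
        self.k[:, self.t] = k_new
        self.v[:, self.t] = v_new
        self.t += 1

cache = KVCache(B=2, T_max=1024, H=8, D=64)
cache.append(np.ones((2, 8, 64)), np.ones((2, 8, 64)))
```

Pre-allocating to T_max trades memory for zero per-step allocation, which is why production engines lay the cache out this way.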
pip install -r requirements.txt
Requires a CUDA-capable NVIDIA GPU (Ampere or newer recommended).
Run all tests:
python normalization/test_rmsnorm.py
python normalization/test_layernorm.py
python normalization/test_residual_rmsnorm.py
python attention/test_flash_attention.py
python attention/test_blocked_attention.py
python mlp/test_triton_bias_gelu.py
python optimizers/test_triton_adam.py
python inference/test_triton_inference_attention.py
python inference/test_kv_cache.py
Run benchmarks:
python benchmarks/attention_bench.py
python benchmarks/mlp_bench.py
python benchmarks/optimizer_bench.py
python benchmarks/inference_bench.py
The first run of any Triton kernel will be slow (~30s) while Triton JIT-compiles and caches the kernels. Subsequent runs are fast.
FlashAttention kernel validation
-----------------------------------------------------------------
[PASS] B=1 H=1 T= 128 D= 64 | max_diff=4.88e-04
[PASS] B=2 H=4 T= 128 D= 64 | max_diff=4.88e-04
[PASS] B=1 H=1 T= 256 D= 64 | max_diff=4.88e-04
[PASS] B=2 H=4 T= 128 D= 32 | max_diff=4.88e-04
[PASS] B=1 H=1 T= 128 D= 64 causal | max_diff=9.77e-04
[PASS] B=2 H=4 T= 128 D= 64 causal | max_diff=9.77e-04
-----------------------------------------------------------------
6/6 passed ✓
Triton inference attention validation
-------------------------------------------------------
[PASS] B=1 H=1 D= 64 T= 1024 | max_diff=2.67e-05
[PASS] B=2 H=4 D= 64 T= 512 | max_diff=1.11e-04
[PASS] B=1 H=8 D= 128 T= 2048 | max_diff=3.02e-05
[PASS] B=2 H=4 D= 32 T= 256 | max_diff=1.13e-04
-------------------------------------------------------
4/4 passed ✓
AdamW kernel validation
-----------------------------------------------------------------
[PASS] N= 1,024 | diff_w=2.38e-07 diff_m=5.96e-08 diff_v=1.18e-07
[PASS] N= 4,096 | diff_w=2.38e-07 diff_m=8.94e-08 diff_v=2.07e-07
[PASS] N= 10,000,000 | diff_w=4.77e-07 diff_m=1.19e-07 diff_v=3.71e-07
-----------------------------------------------------------------
3/3 passed ✓
Triton Bias+GELU kernel validation
----------------------------------------------------
[PASS] B=2 T= 4 O= 32 | max_diff=1.95e-03
[PASS] B=2 T= 1024 O= 4096 | max_diff=3.91e-03
[PASS] B=4 T= 512 O= 16384 | max_diff=3.91e-03
----------------------------------------------------
3/3 passed ✓
| Kernel | ATOL | Reason |
|---|---|---|
| RMSNorm | 5e-3 | fp16 rounding accumulates with D; single reduction |
| LayerNorm | 1e-2 | Two reductions (mean + variance) compound at large non-pow2 D; matches PyTorch's own fp16 test tolerances |
| Residual RMSNorm | 1e-3 | Fusion doesn't add error — stays in fp32 registers |
| FlashAttention | 1e-2 | Online softmax; matches fp32 reference within fp16 precision |
| Bias + GELU | 1e-2 | fp16 erf; consistent with PyTorch's own fp16 tolerances |
| AdamW | 1e-5 | fp32 throughout; near-exact |
All test scripts exit with code 1 on failure — CI compatible.
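A minimal version of the pass/fail pattern the logs above follow (a hypothetical helper for illustration, not the repo's actual test code):

```python
import sys
import numpy as np

def check(name, out, ref, atol):
    """Compare a kernel output to its reference; report max abs diff."""
    diff = float(np.max(np.abs(out - ref)))
    ok = diff <= atol
    print(f"[{'PASS' if ok else 'FAIL'}] {name} | max_diff={diff:.2e}")
    return ok

results = [check("demo", np.ones(4), np.ones(4) + 1e-4, atol=1e-2)]
if not all(results):
    sys.exit(1)  # non-zero exit on any failure keeps this CI-compatible
```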
1. Memory bandwidth first, FLOPs second. Normalization, elementwise ops, and attention softmax are all bandwidth-bound. Kernel fusion reduces HBM round-trips — that is where the speedup comes from.
2. Correctness before performance. Every Triton kernel is validated against a fp32 PyTorch reference before benchmarking. Tests cover non-power-of-two dimensions that expose common tiling bugs.
3. No hidden abstractions. Strides, tile sizes, and masking are explicit. You can read any kernel and understand exactly what is happening in memory.
4. Fusion only where justified. Residual + RMSNorm is fused because both ops are bandwidth-bound and always co-located in transformer blocks. GEMM is not fused because it is compute-bound — cuBLAS handles it better.
5. Production-compatible interfaces. triton_residual_rmsnorm returns (y, u), matching flash-attn and vLLM conventions; triton_inference_attention handles multi-head and arbitrary KV cache lengths.
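The (y, u) convention can be sketched in NumPy (reference semantics only, not the Triton kernel):

```python
import numpy as np

def residual_rmsnorm(x, residual, weight, eps=1e-6):
    """Fused residual add + RMSNorm. Returns both the normalised
    output y and the pre-norm sum u; u feeds the next block's
    skip connection (the flash-attn / vLLM convention)."""
    u = x + residual                                           # pre-norm residual sum
    rms = np.sqrt(np.mean(u * u, axis=-1, keepdims=True) + eps)
    y = u / rms * weight
    return y, u

y, u = residual_rmsnorm(np.ones((2, 4)), np.ones((2, 4)), np.ones(4))
```

Returning u avoids a second kernel (and a second HBM round-trip) just to recompute the residual sum downstream.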
- FlashAttention: Fast and Memory-Efficient Exact Attention — Dao et al., 2022
- Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations — Tillet et al., 2019
- Root Mean Square Layer Normalization — Zhang & Sennrich, 2019
- LLaMA: Open and Efficient Foundation Language Models — Touvron et al., 2023
- Triton tutorials