anviit/triton-llm-kernels

triton-llm-kernels

LLM systems primitives rebuilt from scratch — PyTorch baselines → custom Triton GPU kernels.



This repository reconstructs and optimizes the core computational primitives that every large language model is built from. Not a training framework. Not a Hugging Face wrapper. A systems-level implementation focused on correctness, numerical stability, and GPU memory performance.

The central thesis: modern LLM performance is not about FLOPs. It is about memory bandwidth. Every kernel here is designed around that constraint — fusing operations to minimize global memory round-trips, and benchmarked to prove it.


Triton vs PyTorch: Measured Performance

Benchmarked on a local NVIDIA GPU using triton.testing.do_bench (median latency). All comparisons use fp16.

FlashAttention vs Naive Attention

| Config | PyTorch (fp16) | Triton Flash | Speedup | Flash GB/s |
|---|---|---|---|---|
| B=1 H=1 T=512 D=64 | 0.033 ms | 0.024 ms | 1.35× | 10.8 GB/s |
| B=1 H=1 T=1024 D=64 | 0.065 ms | 0.038 ms | 1.70× | 13.8 GB/s |
| B=1 H=1 T=2048 D=64 | 0.233 ms | 0.092 ms | 2.52× | 11.4 GB/s |
| B=2 H=4 T=512 D=64 | 0.136 ms | 0.050 ms | 2.71× | 41.8 GB/s |

At B=1, H=1 the speedup grows with sequence length, from 1.35× at T=512 to 2.52× at T=2048. Naive attention materialises the full T×T score matrix in HBM (O(T²) memory), while FlashAttention processes it tile by tile in SRAM (O(T) memory).
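The tiling relies on online softmax: a running maximum, a running sum, and an output accumulator are rescaled as each K/V tile arrives, so the T scores never need to exist at once. The following is a minimal pure-Python sketch for a single query row, not the repository's Triton kernel; the function names and the tile size are illustrative assumptions.

```python
import math
import random


def naive_attention_row(q, keys, values):
    """Reference: compute all T scores, softmax them, then take the weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def online_attention_row(q, keys, values, block=4):
    """Visit K/V in tiles, keeping only a running max, running sum, and accumulator."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), block):
        k_tile = keys[start:start + block]
        v_tile = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum
        # (the factor is 0.0 on the first tile, where running_max is -inf).
        correction = math.exp(running_max - new_max)
        p = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * correction + sum(p)
        acc = [
            a * correction + sum(pi * v[d] for pi, v in zip(p, v_tile))
            for d, a in enumerate(acc)
        ]
        running_max = new_max
    return [a / running_sum for a in acc]


rng = random.Random(0)
T, D = 10, 4  # T is deliberately not a multiple of the tile size
q = [rng.uniform(-1, 1) for _ in range(D)]
keys = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(T)]
values = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(T)]

ref = naive_attention_row(q, keys, values)
out = online_attention_row(q, keys, values, block=4)
max_diff = max(abs(a - b) for a, b in zip(ref, out))
```

Both functions compute the same result up to floating-point rounding; only the tiled version bounds its working set by the tile size rather than by T.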


Fused Bias + GELU vs Unfused

| Config | Unfused (PyTorch) | Fused (Triton) | Speedup | Triton GB/s |
|---|---|---|---|---|
| B=2 T=256 O=4096 | 0.333 ms | 0.054 ms | 6.21× | 156.3 GB/s |
| B=2 T=512 O=4096 | 1.073 ms | 0.102 ms | 10.55× | 164.9 GB/s |
| B=2 T=1024 O=4096 | 2.865 ms | 0.200 ms | 14.36× | 168.2 GB/s |
| B=2 T=1024 O=16384 | 11.458 ms | 0.782 ms | 14.65× | 171.7 GB/s |

Unfused reads the GEMM output twice: once for bias, once for GELU. Triton reads it once, applies both in registers, and writes once. At large O the kernel reaches about 172 GB/s, close to the highest bandwidth measured on this GPU in any of these benchmarks (177 GB/s, for fused AdamW at 50M params).
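The difference between the two paths can be sketched in plain Python. This is an illustration of the data flow only, assuming the exact (erf-based) GELU the repository names; the function names are hypothetical and the lists stand in for tensors in global memory.

```python
import math


def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def unfused_bias_gelu(rows, bias):
    # Pass 1: the bias add writes a full intermediate tensor.
    tmp = [[x + b for x, b in zip(row, bias)] for row in rows]
    # Pass 2: GELU reads that intermediate back and writes the output.
    return [[gelu_exact(x) for x in row] for row in tmp]


def fused_bias_gelu(rows, bias):
    # One pass: each element is read once, both ops are applied while the
    # value is "in registers", and the result is written once.
    return [[gelu_exact(x + b) for x, b in zip(row, bias)] for row in rows]


rows = [[-2.0, -0.5, 0.0, 0.5], [1.0, 1.5, 2.0, 3.0]]
bias = [0.1, -0.1, 0.0, 0.25]

unfused = unfused_bias_gelu(rows, bias)
fused = fused_bias_gelu(rows, bias)
```

The outputs are identical; the fused version simply never materialises `tmp`.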


Fused AdamW vs PyTorch AdamW

| N params | PyTorch AdamW | Triton fused | Speedup | Triton GB/s |
|---|---|---|---|---|
| 1M | 0.237 ms | 0.169 ms | 1.41× | 165.8 GB/s |
| 10M | 5.456 ms | 1.624 ms | 3.36× | 172.5 GB/s |
| 50M | 27.279 ms | 7.906 ms | 3.45× | 177.1 GB/s |

PyTorch AdamW issues 5+ separate CUDA kernels per step. Triton fuses all state updates into one — read w, g, m, v once; write w, m, v once. At 50M params the fused kernel runs at 177 GB/s, close to peak bandwidth.
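As a rough sketch of what "one kernel" means here, the loop below performs the whole AdamW update per element in a single pass, using the standard decoupled-weight-decay formulation. The function name and default hyperparameters are assumptions for illustration, not the repository's interface. The bandwidth line at the end assumes the GB/s column counts the seven fp32 transfers per parameter listed above (four reads, three writes); the README does not state how it was computed.

```python
import math


def fused_adamw_step(w, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
                     eps=1e-8, weight_decay=0.01):
    """Update w, m, v in place: each element is read once and written once."""
    bias1 = 1.0 - beta1 ** step
    bias2 = 1.0 - beta2 ** step
    for i in range(len(w)):
        gi = g[i]
        mi = beta1 * m[i] + (1.0 - beta1) * gi        # first moment
        vi = beta2 * v[i] + (1.0 - beta2) * gi * gi   # second moment
        wi = w[i] * (1.0 - lr * weight_decay)         # decoupled weight decay
        wi -= lr * (mi / bias1) / (math.sqrt(vi / bias2) + eps)
        w[i], m[i], v[i] = wi, mi, vi


w, g, m, v = [1.0], [0.5], [0.0], [0.0]
fused_adamw_step(w, g, m, v, step=1)

# Assumed traffic model: read w, g, m, v and write w, m, v, all fp32.
bytes_per_param = 7 * 4
effective_gbps = 50_000_000 * bytes_per_param / 7.906e-3 / 1e9
```

Under that traffic model, the 50M-parameter row works out to roughly 177 GB/s, which agrees with the table.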


Inference Attention (single-token decoding) vs Naive

| Config | PyTorch (naive) | Triton | Speedup | Triton GB/s |
|---|---|---|---|---|
| B=1 H=8 D=128 T=512 | 0.079 ms | 0.037 ms | 2.13× | 56.8 GB/s |
| B=1 H=8 D=128 T=1024 | 0.117 ms | 0.060 ms | 1.94× | 69.5 GB/s |
| B=1 H=8 D=128 T=2048 | 0.203 ms | 0.105 ms | 1.94× | 80.1 GB/s |
| B=2 H=8 D=128 T=1024 | 0.348 ms | 0.088 ms | 3.94× | 95.0 GB/s |

Single-token inference is entirely memory-bandwidth bound — the query vector is tiny, the KV cache is large. Triton's fused online-softmax kernel reads the cache once with no intermediate materialisation.
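A back-of-the-envelope count shows why. The sketch below assumes an fp16 cache, counts a multiply-add as two FLOPs, and ignores the softmax itself as well as the small query and output vectors; the function name is hypothetical.

```python
def decode_attention_cost(batch, heads, seq_len, head_dim, dtype_bytes=2):
    """Return (bytes read from the KV cache, FLOPs) for one decoding step."""
    # K and V are each (batch, heads, seq_len, head_dim) and are read once.
    kv_bytes = 2 * batch * heads * seq_len * head_dim * dtype_bytes
    # q·K costs 2*T*D FLOPs per head, and p·V costs another 2*T*D.
    flops = 4 * batch * heads * seq_len * head_dim
    return kv_bytes, flops


kv_bytes, flops = decode_attention_cost(batch=1, heads=8, seq_len=2048, head_dim=128)
intensity = flops / kv_bytes  # FLOPs per byte of cache traffic
```

With fp16 storage the ratio is one FLOP per byte regardless of T, far below the ratio at which a GPU becomes compute-limited, so the time is set by how fast the cache can be streamed.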


Repository Structure

triton-llm-kernels/
├── requirements.txt
│
├── normalization/                   # Transformer normalization layers
│   ├── triton_rmsnorm.py            # Fused RMSNorm (LLaMA / Mistral)
│   ├── triton_layernorm.py          # Fused LayerNorm (GPT-2 / BERT)
│   ├── triton_residual_rmsnorm.py   # Fused residual add + RMSNorm
│   ├── rmsnorm.py / layernorm.py    # PyTorch references
│   └── test_*.py
│
├── attention/                       # Attention mechanisms
│   ├── flash_attention_triton.py    # FlashAttention: multi-head, causal mask
│   ├── blocked_attention.py         # Online softmax in Python (reference)
│   ├── naive_attention.py           # Naive O(T²) attention (reference)
│   └── test_*.py
│
├── mlp/                             # MLP / FFN layers
│   ├── triton_bias_gelu.py          # Fused bias + exact GELU
│   ├── gelu.py / linear.py          # PyTorch references
│   └── test_*.py
│
├── optimizers/                      # Training optimizers
│   ├── triton_adam.py               # Fully fused AdamW kernel
│   ├── adam.py                      # PyTorch reference
│   └── test_*.py
│
├── inference/                       # Inference-specific kernels
│   ├── triton_inference_attention.py  # Single-token attention + KV cache
│   ├── kv_cache.py                  # KV cache layout
│   ├── attention_inference*.py      # PyTorch references
│   └── test_*.py
│
└── benchmarks/                      # Head-to-head benchmarks
    ├── attention_bench.py
    ├── mlp_bench.py
    ├── optimizer_bench.py
    ├── inference_bench.py
    └── utils.py

Implemented Primitives

Normalization

  • RMSNorm — single-pass fused kernel; used in LLaMA, Mistral, Qwen
  • LayerNorm — mean + variance in one pass; used in GPT-2, BERT, T5
  • Fused Residual RMSNorm — residual add + RMSNorm in one kernel; returns both normalised output y and pre-norm residual sum u for the next block's skip connection
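The fused residual variant can be sketched for a single row as follows. This is a pure-Python illustration under assumed names and an assumed `eps`; it mirrors the (y, u) return described above rather than reproducing the Triton kernel.

```python
import math


def residual_rmsnorm(x, residual, weight, eps=1e-6):
    """Return (y, u): the normalised output and the pre-norm residual sum."""
    # Residual add: u is what the next block's skip connection consumes.
    u = [a + b for a, b in zip(x, residual)]
    # RMSNorm: a single reduction over the row, then a scale.
    mean_sq = sum(t * t for t in u) / len(u)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    y = [t * inv_rms * w for t, w in zip(u, weight)]
    return y, u


y, u = residual_rmsnorm(x=[1.0, 1.0], residual=[2.0, 3.0], weight=[1.0, 1.0])
```

Because `u` is produced in the same pass as `y`, the sum never has to be written out by one kernel and re-read by another.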

Attention

  • Naive attention — O(T²) memory PyTorch reference
  • Blocked attention — online softmax in Python; shows algorithmic correctness alone doesn't help without kernel fusion
  • FlashAttention (Triton) — tiled O(T) memory; multi-head; optional causal mask for autoregressive generation
  • Inference attention (Triton) — single-token decoding against full KV cache; multi-head; online softmax in Triton

MLP

  • Fused Bias + GELU — bias add and exact GELU in one kernel; eliminates intermediate HBM write; up to 14.65× faster than unfused at large O

Optimizers

  • Fused AdamW — weight update, first moment, second moment, and weight decay in one kernel dispatch; up to 3.45× faster than PyTorch at 50M params

Inference

  • KV cache — pre-allocated (B, T_max, H, D) layout matching production inference engines
  • Triton inference attention — attends single query token to full KV cache; multi-head
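A pre-allocated cache with the (B, T_max, H, D) layout might look like the sketch below. The class name, method names, and flat-list storage are assumptions chosen to make the strides visible; the repository's kv_cache.py may differ.

```python
class KVCache:
    """Pre-allocated (B, T_max, H, D) cache in flat buffers with explicit strides."""

    def __init__(self, batch, t_max, heads, head_dim):
        self.shape = (batch, t_max, heads, head_dim)
        # Row-major strides, in elements, for the (B, T_max, H, D) layout.
        self.strides = (t_max * heads * head_dim, heads * head_dim, head_dim, 1)
        size = batch * t_max * heads * head_dim
        self.k = [0.0] * size
        self.v = [0.0] * size
        self.length = 0  # number of tokens written so far

    def _offset(self, b, t, h):
        s_b, s_t, s_h, _ = self.strides
        return b * s_b + t * s_t + h * s_h

    def append(self, k_new, v_new):
        """Write one token; k_new and v_new are nested lists indexed [b][h][d]."""
        batch, t_max, heads, head_dim = self.shape
        if self.length >= t_max:
            raise IndexError("KV cache is full")
        for b in range(batch):
            for h in range(heads):
                start = self._offset(b, self.length, h)
                self.k[start:start + head_dim] = k_new[b][h]
                self.v[start:start + head_dim] = v_new[b][h]
        self.length += 1

    def key(self, b, t, h):
        """Return the D-element key vector for batch b, position t, head h."""
        start = self._offset(b, t, h)
        return self.k[start:start + self.shape[3]]


cache = KVCache(batch=1, t_max=2, heads=2, head_dim=3)
cache.append([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]],
             [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]])
```

Allocating for T_max up front means appending a token is a plain write at a known offset, with no reallocation or copying as the sequence grows.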

Getting Started

pip install -r requirements.txt

Requires a CUDA-capable NVIDIA GPU (Ampere or newer recommended).

Run all tests:

python normalization/test_rmsnorm.py
python normalization/test_layernorm.py
python normalization/test_residual_rmsnorm.py

python attention/test_flash_attention.py
python attention/test_blocked_attention.py

python mlp/test_triton_bias_gelu.py

python optimizers/test_triton_adam.py

python inference/test_triton_inference_attention.py
python inference/test_kv_cache.py

Run benchmarks:

python benchmarks/attention_bench.py
python benchmarks/mlp_bench.py
python benchmarks/optimizer_bench.py
python benchmarks/inference_bench.py

The first run of any Triton kernel will be slow (~30s) while Triton JIT-compiles and caches the kernels. Subsequent runs are fast.


Testing

Validation results

FlashAttention kernel validation
-----------------------------------------------------------------
  [PASS] B=1 H=1 T= 128 D= 64         | max_diff=4.88e-04
  [PASS] B=2 H=4 T= 128 D= 64         | max_diff=4.88e-04
  [PASS] B=1 H=1 T= 256 D= 64         | max_diff=4.88e-04
  [PASS] B=2 H=4 T= 128 D= 32         | max_diff=4.88e-04
  [PASS] B=1 H=1 T= 128 D= 64 causal  | max_diff=9.77e-04
  [PASS] B=2 H=4 T= 128 D= 64 causal  | max_diff=9.77e-04
-----------------------------------------------------------------
  6/6 passed ✓

Triton inference attention validation
-------------------------------------------------------
  [PASS] B=1 H=1 D=  64 T= 1024 | max_diff=2.67e-05
  [PASS] B=2 H=4 D=  64 T=  512 | max_diff=1.11e-04
  [PASS] B=1 H=8 D= 128 T= 2048 | max_diff=3.02e-05
  [PASS] B=2 H=4 D=  32 T=  256 | max_diff=1.13e-04
-------------------------------------------------------
  4/4 passed ✓

AdamW kernel validation
-----------------------------------------------------------------
  [PASS] N=       1,024 | diff_w=2.38e-07  diff_m=5.96e-08  diff_v=1.18e-07
  [PASS] N=       4,096 | diff_w=2.38e-07  diff_m=8.94e-08  diff_v=2.07e-07
  [PASS] N=  10,000,000 | diff_w=4.77e-07  diff_m=1.19e-07  diff_v=3.71e-07
-----------------------------------------------------------------
  3/3 passed ✓

Triton Bias+GELU kernel validation
----------------------------------------------------
  [PASS] B=2 T=    4 O=    32 | max_diff=1.95e-03
  [PASS] B=2 T= 1024 O=  4096 | max_diff=3.91e-03
  [PASS] B=4 T=  512 O= 16384 | max_diff=3.91e-03
----------------------------------------------------
  3/3 passed ✓

Tolerance rationale

| Kernel | ATOL | Reason |
|---|---|---|
| RMSNorm | 5e-3 | fp16 rounding accumulates with D; single reduction |
| LayerNorm | 1e-2 | Two reductions (mean + variance) compound at large non-pow2 D; matches PyTorch's own fp16 test tolerances |
| Residual RMSNorm | 1e-3 | Fusion doesn't add error; stays in fp32 registers |
| FlashAttention | 1e-2 | Online softmax; matches fp32 reference within fp16 precision |
| Bias + GELU | 1e-2 | fp16 erf; consistent with PyTorch's own fp16 tolerances |
| AdamW | 1e-5 | fp32 throughout; near-exact |

All test scripts exit with code 1 on failure — CI compatible.
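The pattern behind that behaviour can be sketched as a small harness: compare against a reference with an absolute tolerance, print a PASS/FAIL line per case, and exit non-zero if anything failed. The helper names and the sample values are hypothetical, not taken from the repository's test scripts.

```python
import sys


def max_abs_diff(got, expected):
    """Largest element-wise absolute difference between two flat sequences."""
    return max(abs(a - b) for a, b in zip(got, expected))


def run_checks(cases, atol):
    """cases is a list of (name, got, expected); returns the number of failures."""
    failures = 0
    for name, got, expected in cases:
        diff = max_abs_diff(got, expected)
        ok = diff <= atol
        failures += 0 if ok else 1
        print(f"  [{'PASS' if ok else 'FAIL'}] {name} | max_diff={diff:.2e}")
    print(f"  {len(cases) - failures}/{len(cases)} passed")
    return failures


if __name__ == "__main__":
    # Illustrative data only; a real script would compare kernel output
    # against the PyTorch reference.
    sample_cases = [("demo case", [0.5001, 1.0], [0.5, 1.0])]
    if run_checks(sample_cases, atol=1e-2):
        sys.exit(1)  # non-zero exit code marks the CI job as failed
```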


Design Principles

1. Memory bandwidth first, FLOPs second. Normalization, elementwise ops, and attention softmax are all bandwidth-bound. Kernel fusion reduces HBM round-trips — that is where the speedup comes from.

2. Correctness before performance. Every Triton kernel is validated against an fp32 PyTorch reference before benchmarking. Tests cover non-power-of-two dimensions that expose common tiling bugs.

3. No hidden abstractions. Strides, tile sizes, and masking are explicit. You can read any kernel and understand exactly what is happening in memory.

4. Fusion only where justified. Residual + RMSNorm is fused because both ops are bandwidth-bound and always co-located in transformer blocks. GEMM is not fused because it is compute-bound — cuBLAS handles it better.

5. Production-compatible interfaces. triton_residual_rmsnorm returns (y, u) matching flash-attn and vLLM conventions. triton_inference_attention handles multi-head and arbitrary KV cache lengths.
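Principles 2 and 3 meet in the handling of row lengths that are not a multiple of the tile size. The sketch below imitates, in plain Python, the masked-load pattern a Triton kernel would typically express with `tl.load(..., mask=..., other=0.0)`: out-of-range lanes are masked and replaced with a neutral value instead of being read. The function name and block size are illustrative assumptions.

```python
def tiled_sum_of_squares(row, block=8):
    """Reduce a row in fixed-size tiles, masking lanes that fall past the end."""
    n = len(row)
    total = 0.0
    for start in range(0, n, block):
        offsets = [start + i for i in range(block)]
        # Lanes beyond the end of the row are masked off...
        mask = [off < n for off in offsets]
        # ...and contribute a neutral 0.0 instead of reading out of bounds.
        tile = [row[off] if ok else 0.0 for off, ok in zip(offsets, mask)]
        total += sum(x * x for x in tile)
    return total


row = [float(i) for i in range(13)]  # 13 is not a multiple of the block size
result = tiled_sum_of_squares(row, block=8)
```

Omitting the mask is the usual tiling bug: with power-of-two sizes every tile is full and the mistake goes unnoticed, which is why the tests include non-power-of-two dimensions.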

