A minimal, educational implementation of Ring Attention logic powered by custom OpenAI Triton kernels.
This repository demonstrates how to bridge FlashAttention-v2 (single-device optimization) and Ring Attention (distributed sequence parallelism) by exposing the LogSumExp (LSE) statistics from the Triton kernel to enable global Online Softmax merging.
- Custom Triton Kernel: A modified FlashAttention-v2 kernel that exports the L (LogSumExp) buffer to HBM, the mathematical prerequisite for Ring Attention's distributed reduction.
- Ring Logic Simulation: A Python-level simulation of the Ring Attention workflow (blockwise compute -> communicate -> merge) on a single GPU using `torch.chunk`.
- Online Softmax Merging: Numerically stable merging of partial attention outputs and LSE statistics across the simulated "devices" (see the identity below).
- Correctness Verification: Validated against PyTorch's standard `scaled_dot_product_attention`.
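The merging in the third bullet rests on a standard identity: partial attention outputs computed over disjoint K/V blocks can be combined exactly, as long as each carries its per-row LogSumExp. With partial outputs O_1, O_2 and LSE values L_1, L_2:

```math
L = \log\left(e^{L_1} + e^{L_2}\right), \qquad
O = e^{L_1 - L}\,O_1 \;+\; e^{L_2 - L}\,O_2
```

Because the rescaling factors are computed in log space, the merge never materializes the unnormalized softmax sums and stays numerically stable.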
- The Missing Link: L_ptr
Standard FlashAttention kernels only write the final attention output back to HBM; the per-row softmax statistics stay internal to the kernel. This implementation modifies the Triton kernel signature to also store the LogSumExp (LSE) statistics:
```python
# In ring_flash_attn_triton.py: persist the per-row LogSumExp so partial
# results from different K/V blocks can be merged later on the host.
tl.store(l_ptrs, m_i + tl.log(l_i))
```
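For intuition about what that line writes out: `m_i` is the running row maximum and `l_i` the running sum of `exp(score - m_i)`, so `m_i + log(l_i)` is exactly the row-wise LogSumExp of the attention scores. A small pure-PyTorch sanity check (illustrative only, not part of the kernel):

```python
import torch

# Illustrative check: one block of Q·K^T logits
scores = torch.randn(8, 16)
m_i = scores.max(dim=-1).values                        # running row maximum
l_i = torch.exp(scores - m_i[:, None]).sum(dim=-1)     # sum of shifted exponentials
lse = m_i + torch.log(l_i)                             # what tl.store writes to L_ptr
assert torch.allclose(lse, torch.logsumexp(scores, dim=-1), atol=1e-6)
```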
- Ring Simulation
We simulate a ring of 4 virtual devices on a single GPU by splitting the sequence with `torch.chunk`:
1. Stationary Q: each "device" holds a shard of the Query.
2. Rotating K/V: Key and Value blocks rotate through the ring.
3. Partial Compute: calculate the local attention output and LSE with the Triton kernel.
4. Global Merge: update the global output and LSE using the Online Softmax trick (see the sketch below).
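A minimal sketch of that merge step in plain PyTorch, assuming partial outputs of shape (B, H, N, D) and LSE tensors of shape (B, H, N); the function name and shapes are illustrative, not the repository's exact API:

```python
import torch

def merge_partial_attention(o_acc, lse_acc, o_blk, lse_blk):
    """Numerically stable online-softmax merge of two partial attention results.

    o_acc, o_blk     : (B, H, N, D) partial attention outputs
    lse_acc, lse_blk : (B, H, N)    per-row LogSumExp of the corresponding score blocks
    """
    lse_new = torch.logaddexp(lse_acc, lse_blk)              # merged LogSumExp
    o_new = (torch.exp(lse_acc - lse_new).unsqueeze(-1) * o_acc
             + torch.exp(lse_blk - lse_new).unsqueeze(-1) * o_blk)
    return o_new, lse_new
```

Initializing `o_acc` to zeros and `lse_acc` to -inf lets the first K/V block go through the same code path: `logaddexp(-inf, x) = x`, and the rescaling weight for the empty accumulator is `exp(-inf) = 0`.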
Clone the repository and run the simulation script:
```bash
git clone https://github.com/lyj20071013/Triton-Ring-FlashAttn.git
cd Triton-Ring-FlashAttn
python ring_flash_attn_triton.py
```
Expected Output
```
=== Starting Ring Attention Simulation (Virtual Devices: 4) ===
[Device 0] Processing...
[Device 1] Processing...
[Device 2] Processing...
[Device 3] Processing...
Running PyTorch SDPA...
Running Ring Attention Simulation...
Max Error: 0.000xxx
Ring Attention Test Passed!
```
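For reference, the whole blockwise-compute-then-merge loop can be reproduced in a few lines of pure PyTorch. This is an illustrative miniature, not the repository's Triton path; Q is kept whole here for brevity while K/V rotate in chunks:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D, devices = 1, 2, 128, 64, 4
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)
scale = D ** -0.5

# Reference: full-sequence attention in one shot.
ref = F.scaled_dot_product_attention(q, k, v)

# "Ring": rotate K/V chunks past Q and merge partial results with online softmax.
o = torch.zeros_like(q)
lse = torch.full((B, H, N), float("-inf"))
for k_blk, v_blk in zip(k.chunk(devices, dim=2), v.chunk(devices, dim=2)):
    scores = (q @ k_blk.transpose(-2, -1)) * scale
    blk_lse = torch.logsumexp(scores, dim=-1)
    blk_o = torch.softmax(scores, dim=-1) @ v_blk
    new_lse = torch.logaddexp(lse, blk_lse)
    o = (torch.exp(lse - new_lse).unsqueeze(-1) * o
         + torch.exp(blk_lse - new_lse).unsqueeze(-1) * blk_o)
    lse = new_lse

print("Max Error:", (o - ref).abs().max().item())
assert torch.allclose(o, ref, atol=1e-4)
```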