Triton-Ring-FlashAttn ⚡️

A minimal, educational implementation of Ring Attention logic powered by custom OpenAI Triton kernels.

This repository demonstrates how to bridge FlashAttention-v2 (single-device optimization) and Ring Attention (distributed sequence parallelism) by exposing the LogSumExp (LSE) statistics from the Triton kernel to enable global Online Softmax merging.

Key Features 🚀

Custom Triton Kernel: Modified FlashAttention-v2 kernel that exports L (LogSumExp) buffers to HBM. This is the mathematical prerequisite for Ring Attention's distributed reduction.

Ring Logic Simulation: A Python-level simulation of the Ring Attention workflow (Blockwise calculation -> Communication -> Merge) on a single GPU using torch.chunk.

Online Softmax Merging: Implements the numerically stable merging of partial attention outputs and LSE stats across "devices".

Correctness Verification: Validated against PyTorch's standard scaled_dot_product_attention.

How It Works 💡

  1. The Missing Link: L_ptr

Standard FlashAttention kernels output only the final result $O$ and discard the softmax normalization statistics. To implement Ring Attention, we need to merge partial results computed over multiple K/V blocks (potentially on different GPUs), and that merge requires those statistics.

This implementation modifies the Triton kernel signature to store the LogSumExp statistics:

# In ring_flash_attn_triton.py
# Store the per-row LSE: running max m_i plus log of the running sum l_i
tl.store(l_ptrs, m_i + tl.log(l_i))
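
Why this matters: once each block's LSE is available, partial results for the same queries over disjoint key blocks can be combined exactly. Writing two partials as $(O_1, L_1)$ and $(O_2, L_2)$ (notation introduced here for exposition, not taken from the kernel code), the online-softmax merge is

$$
L = \log\!\left(e^{L_1} + e^{L_2}\right), \qquad
O = e^{L_1 - L}\, O_1 + e^{L_2 - L}\, O_2 .
$$

The same update is applied repeatedly as each new K/V block arrives in the ring.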
  2. Ring Simulation

We simulate a ring of $N$ devices on a single GPU by chunking the sequence dimension. The logic follows the standard Ring Attention pattern (a runnable sketch follows the list):

Stationary Q: Each "device" holds a shard of Query.

Rotating K/V: Key and Value blocks rotate through the ring.

Partial Compute: Calculate local Attention and LSE using Triton.

Global Merge: Update global Output and LSE using the Online Softmax trick.
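
Below is a minimal pure-PyTorch sketch of this loop on a single GPU. The helper `attn_with_lse` and the function `ring_attention_sim` are illustrative names, not the repository's API; the real script calls the Triton kernel instead of the naive attention used here (which materializes the full score matrix).

```python
# Minimal single-GPU sketch of the ring loop (assumed names, not the repo's API).
import torch

def attn_with_lse(q, k, v, scale):
    # q, k, v: (batch, heads, seq, dim). Stand-in for the Triton kernel:
    # returns the partial output O and the LogSumExp L over this K/V block.
    s = torch.matmul(q, k.transpose(-2, -1)) * scale   # raw attention scores
    lse = torch.logsumexp(s, dim=-1)                   # (batch, heads, seq_q)
    o = torch.matmul(torch.softmax(s, dim=-1), v)      # partial output
    return o, lse

def ring_attention_sim(q, k, v, n_devices=4):
    scale = q.shape[-1] ** -0.5
    q_shards = q.chunk(n_devices, dim=2)               # stationary Q shards
    kv_blocks = list(zip(k.chunk(n_devices, dim=2), v.chunk(n_devices, dim=2)))
    outputs = []
    for i, q_i in enumerate(q_shards):                 # each virtual "device"
        o_acc, lse_acc = None, None
        for step in range(n_devices):                  # K/V blocks rotate through the ring
            k_j, v_j = kv_blocks[(i + step) % n_devices]
            o_j, lse_j = attn_with_lse(q_i, k_j, v_j, scale)
            if o_acc is None:
                o_acc, lse_acc = o_j, lse_j
            else:
                # Online softmax merge of two partial (O, LSE) pairs
                lse_new = torch.logaddexp(lse_acc, lse_j)
                o_acc = (torch.exp(lse_acc - lse_new).unsqueeze(-1) * o_acc
                         + torch.exp(lse_j - lse_new).unsqueeze(-1) * o_j)
                lse_acc = lse_new
        outputs.append(o_acc)
    return torch.cat(outputs, dim=2)                   # reassemble the full sequence
```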

Quick Start ⚡️

Clone the repository and run the simulation script:

git clone https://github.com/lyj20071013/Triton-Ring-FlashAttn.git
cd Triton-Ring-FlashAttn
python ring_flash_attn_triton.py

Expected Output

=== Starting Ring Attention Simulation (Virtual Devices: 4) ===
  [Device 0] Processing...
  [Device 1] Processing...
  [Device 2] Processing...
  [Device 3] Processing...

Running PyTorch SDPA...
Running Ring Attention Simulation...

Max Error: 0.000xxx
Ring Attention Test Passed!
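
The check behind this output boils down to comparing the merged result against PyTorch's SDPA over the full sequence. A sketch of that comparison, reusing the `ring_attention_sim` stand-in from the section above (the actual script drives the Triton kernel and may use different shapes and tolerances):

```python
import torch
import torch.nn.functional as F

# Random inputs: (batch, heads, seq_len, head_dim); shapes are illustrative
q = torch.randn(1, 8, 1024, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)

ref = F.scaled_dot_product_attention(q, k, v)   # full-sequence reference
out = ring_attention_sim(q, k, v, n_devices=4)  # sketch from the section above

max_err = (ref - out).abs().max().item()
print(f"Max Error: {max_err:.6f}")
assert max_err < 1e-3, "Ring Attention Test Failed"
print("Ring Attention Test Passed!")
```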
