hustvl/MoDA
Mixture-of-Depths Attention

Scaling Attention Along the Depth

Lianghui Zhu1,2, Yuxin Fang2,†, Bencheng Liao1,2, Shijie Wang2, Tianheng Cheng2, Zilong Huang2, Chen Chen2, Lai Wei2, Yutao Zeng2, Ya Wang2, Yi Lin2, Yu Li2, Xinggang Wang1,#

1 School of EIC, Huazhong University of Science & Technology, 2 ByteDance Seed

(†) project lead, (#) corresponding author.

arXiv preprint (arXiv:2603.15619)

News

  • Mar. 16th, 2026: We released the Mixture-of-Depths Attention paper on arXiv. Code is coming soon.

TODO

  • Release Mixture-of-Depths Attention (MoDA) paper on arXiv.
  • Release MoDA Triton kernel and corresponding test units.
  • Release Chunk-Visible MoDA Triton kernel and corresponding test units.
  • Release Non-Causal MoDA Triton kernel and corresponding test units.
  • Release full LLM training recipe and reproducible configs.
  • Release full vision-task training recipe, e.g., classification on ImageNet.

Abstract

Scaling depth is a key driver for large language models (LLMs). Yet, as LLMs become deeper, they often suffer from signal degradation: informative features formed in shallow layers are gradually diluted by repeated residual updates, making them harder to recover in deeper layers. We introduce mixture-of-depths attention (MoDA), a mechanism that allows each attention head to attend to sequence KV pairs at the current layer and depth KV pairs from preceding layers. We further describe a hardware-efficient algorithm for MoDA that resolves non-contiguous memory-access patterns, achieving 97.3% of FlashAttention-2's efficiency at a sequence length of 64K. Experiments on 1.5B-parameter models demonstrate that MoDA consistently outperforms strong baselines. Notably, it reduces average perplexity by 0.2 across 10 validation benchmarks and increases average performance by 2.11% on 10 downstream tasks, with a negligible 3.7% FLOPs overhead. We also find that combining MoDA with post-norm yields better performance than using it with pre-norm. These results suggest that MoDA is a promising primitive for depth scaling.

Overview Comparison

Conceptual comparison of mechanisms that utilize the depth stream. (a) Depth Residual reads the current representation and writes back by addition. (b) Depth Dense reads a set of historical representations and linearly projects them back; it writes back by concatenation along depth. (c) Depth Attention uses attention to read historical depth KV pairs in a data-dependent way. (d) Mixture-of-Depths Attention (MoDA) combines depth attention with standard sequence attention and writes both the current layer's output and its KV pairs to depth streams for subsequent layers.
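The combined read over sequence and depth KV in (d) can be sketched as a single softmax over concatenated key/value blocks. The following is an illustrative NumPy sketch of that idea for one head and one query position, not the repository's Triton kernel; all shapes and names here are my own assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moda_attention(q, seq_k, seq_v, depth_k, depth_v):
    """One head, one query position (illustrative sketch).

    q:         (d,)    current query
    seq_k/v:   (T, d)  sequence KV at the current layer
    depth_k/v: (L, d)  depth KV collected from preceding layers
    A single softmax is taken over the concatenated [sequence | depth]
    keys, mirroring the combined-softmax formulation described here.
    """
    k = np.concatenate([seq_k, depth_k], axis=0)   # (T + L, d)
    v = np.concatenate([seq_v, depth_v], axis=0)   # (T + L, d)
    scores = k @ q / np.sqrt(q.shape[-1])          # (T + L,)
    w = softmax(scores)                            # attention over both blocks
    return w @ v                                   # (d,)
```

Because the softmax is shared, attention mass allocated to the depth block directly competes with sequence attention, which is what the heatmaps in the Attention Visualization section display.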

Hardware-Efficient Implementation

Left: Flash-compatible hardware-efficient MoDA achieves higher efficiency than torch-implemented MoDA. However, it keeps a depth KV cache of length T×L for each sequence, so each query potentially scans a long concatenated depth KV. Right: Chunk/Group-aware MoDA groups queries by chunk size C and reorganizes depth KV by chunk, reducing the effective depth span from T×L to (C×L)/G per chunk, where G is the GQA group number. This layout improves depth KV calculation efficiency and reduces memory access overhead.
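The span reduction described above is simple arithmetic and can be sketched as follows; the helper name is hypothetical, while T, L, C, and G follow the notation in the paragraph.

```python
def depth_kv_span(T, L, C=None, G=1):
    """Effective depth-KV length each query must scan (sketch, not the kernel).

    T: sequence length; L: model depth (layers contributing depth KV);
    C: chunk size (None = no chunking); G: GQA group number.
    Without chunking, every query may scan the full T*L concatenated depth KV;
    with chunk size C and G-way grouping, this drops to (C*L)/G per chunk.
    """
    if C is None:
        return T * L
    return (C * L) // G

full = depth_kv_span(T=65536, L=64)                # 4,194,304 entries
chunked = depth_kv_span(T=65536, L=64, C=64, G=8)  # 512 entries per chunk
```

With the benchmark settings below (C=64, G=8, L=64), the per-chunk depth span is independent of sequence length, which is why the relative overhead shrinks as T grows.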

Results

Downstream Performance (400B tokens)

| Model | PIQA | HellaSwag | WinoGrande | OpenBookQA | BoolQ | SciQ | ARC-E | ARC-C | COPA | MMLU | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| OLMo2-700M | 73.72 | 58.77 | 55.33 | 35.60 | 56.24 | 89.50 | 66.84 | 33.44 | 77.00 | 24.69 | 57.11 |
| MoDA-700M | 73.39 | 59.19 | 60.22 | 37.20 | 59.33 | 89.60 | 67.37 | 34.78 | 82.00 | 25.61 | 58.87 |
| OLMo2-1.5B | 76.55 | 65.86 | 63.22 | 38.80 | 63.61 | 90.60 | 72.98 | 42.47 | 81.00 | 27.73 | 62.28 |
| MoDA-1.5B | 76.82 | 66.24 | 65.59 | 41.60 | 67.34 | 92.10 | 72.81 | 46.82 | 85.00 | 29.59 | 64.39 |

Validation Perplexity (Lower is Better)

| Model | C4 | ICE | m2d2-s2orc | Pile | Wiki-text | Books | CC | peS2o | Reddit | Stack | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|
| OLMo2-700M | 18.32 | 17.43 | 24.37 | 9.53 | 12.26 | 16.78 | 20.53 | 9.17 | 23.84 | 3.93 | 15.61 |
| MoDA-700M | 18.29 | 17.24 | 23.64 | 9.48 | 12.06 | 16.58 | 20.52 | 9.14 | 23.75 | 3.90 | 15.46 |
| OLMo2-1.5B | 16.16 | 15.37 | 21.10 | 8.45 | 10.41 | 14.19 | 18.13 | 8.19 | 21.21 | 3.57 | 13.67 |
| MoDA-1.5B | 15.97 | 15.08 | 20.92 | 8.33 | 10.16 | 13.95 | 17.88 | 8.09 | 20.85 | 3.52 | 13.47 |
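The abstract's headline gains can be recovered from the 1.5B rows of the two tables above. A quick sanity check, with the per-task numbers copied verbatim:

```python
# 1.5B rows from the downstream-accuracy and validation-perplexity tables.
olmo_acc = [76.55, 65.86, 63.22, 38.80, 63.61, 90.60, 72.98, 42.47, 81.00, 27.73]
moda_acc = [76.82, 66.24, 65.59, 41.60, 67.34, 92.10, 72.81, 46.82, 85.00, 29.59]
olmo_ppl = [16.16, 15.37, 21.10, 8.45, 10.41, 14.19, 18.13, 8.19, 21.21, 3.57]
moda_ppl = [15.97, 15.08, 20.92, 8.33, 10.16, 13.95, 17.88, 8.09, 20.85, 3.52]

avg = lambda xs: sum(xs) / len(xs)
acc_gain = avg(moda_acc) - avg(olmo_acc)   # about 2.11 points on downstream tasks
ppl_gain = avg(olmo_ppl) - avg(moda_ppl)   # about 0.20 lower average perplexity
```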

Kernel Efficiency (A100, bf16, Forward & Backward, B=1, d=64, C=64)

Scaling Sequence Length T (G=8, Hq=64, Hk=8, L=64)

| T | FA2-Triton (ms) | MoDA-Triton (ms) | Depth Utilization | Extra Time |
|---|---|---|---|---|
| 4096 | 7.970 | 10.750 | 12.50% | 25.86% |
| 8192 | 28.700 | 35.427 | 12.50% | 18.99% |
| 16384 | 116.700 | 127.661 | 12.50% | 8.59% |
| 32768 | 459.854 | 480.914 | 12.50% | 4.38% |
| 65536 | 1831.668 | 1883.026 | 12.50% | 2.73% |

Scaling GQA Group Size G (T=16384, Hk=8, L=64)

| G | Hq | FA2-Triton (ms) | MoDA-Triton (ms) | Depth Utilization | Extra Time |
|---|---|---|---|---|---|
| 2 | 16 | 28.982 | 39.741 | 3.12% | 27.07% |
| 4 | 32 | 58.071 | 68.939 | 6.25% | 15.76% |
| 8 | 64 | 116.700 | 127.661 | 12.50% | 8.59% |
| 16 | 128 | 233.700 | 244.900 | 25.00% | 4.57% |
| 32 | 256 | 467.107 | 480.767 | 50.00% | 2.84% |

Scaling Model Depth L (T=16384, G=8, Hq=64, Hk=8)

| L | FA2-Triton (ms) | MoDA-Triton (ms) | Depth Utilization | Extra Time |
|---|---|---|---|---|
| 64 | 116.700 | 127.661 | 12.50% | 8.59% |
| 128 | 116.700 | 138.224 | 12.50% | 15.57% |
| 256 | 116.700 | 167.958 | 12.50% | 30.52% |

MoDA reaches 97.3% of FlashAttention-2 efficiency at a sequence length of 64K. Extra time consistently decreases as sequence length or group size increases.
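The headline 97.3% figure follows directly from the 64K row of the sequence-length table; as a quick check (latencies in ms, copied from the table):

```python
# T = 65536 row of the sequence-length scaling table (forward + backward, ms).
fa2_ms, moda_ms = 1831.668, 1883.026

efficiency = fa2_ms / moda_ms   # fraction of FlashAttention-2 throughput
extra_time = 1.0 - efficiency   # relative overhead, matching the "Extra Time" column
print(f"{efficiency:.1%} of FA2, {extra_time:.2%} extra time")
```

Note that "Extra Time" here is the overhead relative to the MoDA latency (1 − FA2/MoDA), which reproduces the 2.73% entry in the table.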

Attention Visualization

MoDA attention heatmaps with the combined-softmax formulation. Columns correspond to uniformly sampled layers {0, 11, 23, 35}, and rows correspond to randomly selected heads in each layer. The first column shows attention over Sequence KV only, while the remaining columns show the concatenated Sequence KV | Depth KV; the red dashed line marks the boundary between the two KV blocks. Across layers and heads, substantial attention mass is consistently assigned to the Depth KV block, indicating that MoDA effectively leverages depth information in addition to standard sequence attention.

Installation

First make sure the basic environment requirements are satisfied (a working PyTorch and Triton setup, since the kernels are Triton-based). Then install the local MoDA-enabled fla package from this repository:

cd libs/moda_triton
pip install -e .
cd ../..

Note: Please install from libs/moda_triton instead of PyPI, since the MoDA Triton kernels are maintained in this local directory.

Test Your MoDA

python3 libs/moda_triton/fla/ops/moda/moda_v14.py

Acknowledgement ❤️

This project is based on OLMo2 (paper) and Flash Linear Attention (paper). Thanks for their wonderful work.

Citation

If you find MoDA useful in your research or applications, please consider giving us a star ⭐ and citing it with the following BibTeX entry.

@article{zhu2026moda,
  title   = {Mixture-of-Depths Attention},
  author  = {Zhu, Lianghui and Fang, Yuxin and Liao, Bencheng and Wang, Shijie and Cheng, Tianheng and Huang, Zilong and Chen, Chen and Wei, Lai and Zeng, Yutao and Wang, Ya and Lin, Yi and Li, Yu and Wang, Xinggang},
  journal = {arXiv preprint arXiv:2603.15619},
  year    = {2026}
}

About

A hardware-aware, efficient implementation of "Mixture-of-Depths Attention".
