PyTorch https://pytorch.org Sat, 14 Mar 2026 22:57:44 +0000

Building Voice Agents with ExecuTorch: A Cross-Platform Foundation for On-Device Audio https://pytorch.org/blog/building-voice-agents-with-executorch-a-cross-platform-foundation-for-on-device-audio/ Sun, 15 Mar 2026 16:00:45 +0000 https://pytorch.org/?p=58616

TL;DR
  • Open source voice models are proliferating, but there’s no unified native inference platform for voice agent workloads (transcription, real-time streaming, diarization, voice activity detection, live translation) across devices and hardware.
  • ExecuTorch fills this gap. As a general-purpose PyTorch-native inference platform, it enables developers to export voice models directly from PyTorch and run them across CPU, GPU, and NPU on Linux, macOS, Windows, Android, and iOS.
  • We provide reference implementations for five voice models spanning four distinct tasks, with working C++ application layers and mobile apps ready to build on. LM Studio is already shipping voice transcription powered by ExecuTorch in production.

Voice on the Edge Today

AI agents are increasingly expected to hear and speak. Whether it’s a personal assistant on smart glasses, a real-time translator on a phone, or a voice-driven coding companion on a laptop, voice is becoming a key modality for how agents interact with users. A voice-capable agent needs more than just offline transcription: it needs streaming speech recognition, speaker diarization, voice activity detection, noise suppression, speech-to-text, live translation, and full-duplex support, all running locally with low latency.

This demand is fueling a wave of open source voice models. In just the past few months we’ve seen Qwen3-ASR, Parakeet ASR, Voxtral Realtime, Kyutai Hibiki-Zero, Kokoro TTS, SAM-3-Audio, Liquid LFM2.5-Audio, Sortformer Diarization, and many more. What’s missing is a uniform way to deploy them natively on edge devices, as compiled C/C++ libraries that run directly on device hardware without a Python runtime or cloud dependency.

Most of these models can run in Python, but production-level edge deployments require native C++ libraries. Existing native solutions tend to be either model-specific C++ rewrites that need to be rebuilt for each new architecture, or platform-specific frameworks tied to a single hardware ecosystem. As voice models diversify in architecture and complexity, neither approach scales.

We built ExecuTorch as a general-purpose native inference platform that works across models, backends, and devices. Last year we reached general availability with production-ready support including LLMs, vision, and multimodal models. Now we’re extending the same platform to voice. We see voice as a key frontier for on-device AI, and we wanted to prove that ExecuTorch’s architecture could handle the diversity of voice workloads across diverse hardware. In this post, we provide reference implementations for five voice models spanning four distinct tasks, along with sample applications and mobile apps ready to build on. LM Studio is already shipping voice transcription powered by ExecuTorch in their desktop application.

Design Principles

Three principles underpin this approach:

Minimal model changes, not full rewrites. The model author’s PyTorch code is the starting point. Instead of rewriting models in other languages or converting them to other formats, we use torch.export() directly on the original PyTorch model’s core components (audio encoder, text decoder, token embedding, mel spectrogram) with minimal edits. For example, when Mistral released Voxtral Realtime and NVIDIA published Parakeet TDT and Sortformer, we exported their PyTorch source directly with targeted edits to satisfy torch.export() constraints. No format conversion, no reimplementation in C++.

Export the model, orchestrate in C++. The model and the application logic live in different layers. Model components are exported into a compiled artifact. A thin C++ application layer ties everything together, handling the complex orchestration: streaming-window bookkeeping, audio overlap handling, spectrogram alignment, stateful decoding loops. ExecuTorch handles the hard part: efficient inference across hardware backends.

Write once, run on any backend. One export serves every target platform. The same exported model runs on XNNPACK (CPU), Metal Performance Shaders (Apple GPU), CUDA (NVIDIA GPU), or Qualcomm (NPU) with minimal backend-specific logic in the model or export script. Quantization (int4, int8) is applied in PyTorch before export, shrinking models significantly without manual kernel work.

Voice Models in Practice

We’ve validated this approach across five voice models with very different architectures:

Voxtral Realtime (streaming transcription, ~4B params). Mistral’s streaming transcription model delivers real-time transcription with offline-level accuracy, and is a good example of the “export the model, orchestrate in C++” approach. The C++ application layer handles audio signal processing: overlapping audio windows with past context and lookahead, spectrogram frame alignment, and encoder position tracking. The exported model handles the heavy compute: transformers with ring-buffer KV caches that enable unlimited-duration streaming within fixed memory. All streaming constants are derived at export time and baked into the exported model as self-describing metadata. Int4 quantization shrinks the model from 20GB to 5–6GB.
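The fixed-memory property of a ring-buffer KV cache can be illustrated with a small Python sketch. This is not the ExecuTorch or Voxtral Realtime implementation: the class name `RingKVCache`, its methods, and the capacity are hypothetical, and a real cache stores key/value tensors per layer rather than Python objects.

```python
class RingKVCache:
    """Fixed-capacity cache that overwrites the oldest entry once full."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.slots = [None] * capacity  # memory is allocated once, up front
        self.next_pos = 0               # total number of entries ever written

    def append(self, kv) -> None:
        # The write position wraps around, so the buffer never grows.
        self.slots[self.next_pos % self.capacity] = kv
        self.next_pos += 1

    def window(self) -> list:
        """Return the cached entries from oldest to newest."""
        if self.next_pos <= self.capacity:
            return self.slots[: self.next_pos]
        start = self.next_pos % self.capacity
        return self.slots[start:] + self.slots[:start]


cache = RingKVCache(capacity=4)
for step in range(10):           # stream 10 steps through a 4-slot cache
    cache.append(f"kv{step}")
print(cache.window())            # the 4 most recent entries, oldest first
```

However long the stream runs, the buffer holds only the most recent `capacity` entries, which is what keeps memory constant.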

Parakeet TDT (offline transcription, 0.6B params). NVIDIA’s high-accuracy speech recognition model uses a Token-and-Duration Transducer architecture, where the model predicts both what token to emit and how far to advance in the audio at each step. This non-standard decoding loop is a good example of ExecuTorch’s multi-method export: the encoder, decoder, and joint network are exported as three separate methods in a single artifact, while the C++ application layer implements the TDT-specific greedy decode with LSTM state management. The application layer also includes timestamp extraction in C++ (word boundaries, sentence segmentation), making this a fully standalone on-device transcription pipeline.
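A greedy Token-and-Duration Transducer loop can be sketched in Python as follows. The real application layer is C++ and calls the exported encoder, decoder, and joint methods. Here `joint` is a hypothetical stand-in callable, and the duration set, the `max_symbols_per_frame` guard, and `toy_joint` are assumptions made for illustration only.

```python
BLANK = 0
DURATIONS = [0, 1, 2, 3, 4]  # assumed set of frame skips the model can predict


def tdt_greedy_decode(num_frames, joint, max_symbols_per_frame=10):
    """Greedy TDT decoding sketch.

    `joint(t, prev_token)` is a hypothetical stand-in for the exported
    methods; it returns (token_id, duration_index).
    """
    tokens, t, prev, emitted_here = [], 0, BLANK, 0
    while t < num_frames:
        token, dur_idx = joint(t, prev)
        skip = DURATIONS[dur_idx]
        if token != BLANK:
            tokens.append((token, t))  # keep the frame index for timestamps
            prev = token               # a real decoder would update LSTM state here
            emitted_here += 1
        # Always make progress on blank, or after too many symbols on one frame.
        if skip == 0 and (token == BLANK or emitted_here >= max_symbols_per_frame):
            skip = 1
        if skip > 0:
            t += skip
            emitted_here = 0
    return tokens


def toy_joint(t, prev_token):
    """Hypothetical scripted model: two tokens on frame 0, blanks afterwards."""
    if t == 0 and prev_token == BLANK:
        return 5, 0      # emit token 5 and stay on the same frame
    if t == 0:
        return 6, 2      # emit token 6, then jump ahead two frames
    return BLANK, 1      # nothing to emit, advance one frame


print(tdt_greedy_decode(4, toy_joint))
```

The point of the sketch is that the predicted duration, not a fixed one-frame step, decides how far the loop advances through the audio.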

Sortformer (speaker diarization, 117M params). NVIDIA’s diarization model answers “who spoke when” for up to four speakers in an audio stream. The model itself is stateless: it takes audio embeddings in and outputs per-frame speaker probabilities. All streaming complexity lives in the C++ application layer: a speaker cache that retains the most discriminative frames, a sliding FIFO window for short-term context, and cache compression that drops the least informative frames when memory fills up. This is one of the clearest demonstrations of ExecuTorch’s separation between model and orchestration.

Whisper (offline transcription, 39M–1.5B params). OpenAI’s widely adopted speech recognition model, with the widest backend coverage in ExecuTorch (CPU, Apple GPU, NVIDIA GPU, and Qualcomm NPU).

Silero VAD (voice activity detection, 2MB). A lightweight model that detects whether someone is speaking. A building block for any voice agent, and a good starting point for contributors.

Model | Task | Backends | Platforms
Parakeet TDT | Transcription | XNNPACK, CUDA, Metal Performance Shaders, Vulkan | Linux, macOS, Windows, Android
Voxtral Realtime | Streaming Transcription | XNNPACK, Metal Performance Shaders, CUDA | Linux, macOS, Windows
Whisper | Transcription | XNNPACK, Metal Performance Shaders, CUDA, Qualcomm | Linux, macOS, Windows, Android
Sortformer | Speaker Diarization | XNNPACK, CUDA | Linux, macOS, Windows
Silero VAD | Voice Activity Detection | XNNPACK | Linux, macOS

Sample Applications

Beyond model enablement, we’ve built a few end-to-end applications to demonstrate what’s possible. These are starting points, and we encourage application developers to build on them for their own use cases:

Real-time transcription on desktop. The demo reads live audio from the microphone and outputs transcribed text as you speak, running entirely on-device. This is the foundation for voice input in any desktop application: coding assistants, note-taking tools, accessibility features. Download the dmg file and try the app today:

Video: Standalone real-time voice transcription macOS application powered by ExecuTorch and Voxtral Realtime.

Speech recognition on Android. The Parakeet and Whisper Android apps let users record audio and transcribe it on-device. These are fully functional apps with model download, microphone recording, and transcription, available in the executorch-examples repository.

Video: Voice transcription with timestamps on Android (Samsung Galaxy S24) powered by ExecuTorch and Parakeet.

Adoption Case Study in production: LM Studio

LM Studio is a popular desktop application for running LLMs locally. They recently added voice transcription to their product, powered by ExecuTorch running the Parakeet TDT model. LM Studio exposes transcription in the app UI, with an API endpoint coming soon that will let developers integrate local speech recognition into their workflows. They chose ExecuTorch for its cross-platform support and competitive performance, shipping on macOS (Metal Performance Shaders) and Windows (CUDA) from the same model and application layer.

LM Studio adopts ExecuTorch for cross-platform, on-device transcription

Get Involved

These reference implementations are starting points, and the landscape of voice models we want to support is much larger. Models like Qwen3-ASR, Kyutai Hibiki-Zero, Kokoro TTS, SAM-3-Audio, and Liquid LFM2.5-Audio are all PyTorch-native and natural candidates for ExecuTorch enablement. We want the community’s help to get there:

  • Adopt ExecuTorch for voice inference in your frameworks and applications.
  • Contribute new models — pick a voice model, export it, write an application layer, and open a PR. Live translation, speech enhancement, wake word detection, noise reduction, text-to-speech. The architecture is ready for all of them.
  • Contribute backends and platforms — help us close the remaining gaps and improve performance across hardware.

ExecuTorch isn’t just for voice. It’s the same platform powering on-device LLMs, vision models, and multimodal AI. 

Start building: ExecuTorch Documentation | ExecuTorch repo | ExecuTorch Examples | ExecuTorch Discord

Acknowledgement

This work wouldn’t have been possible without support and core contributions from PyTorch team members, including Bilgin Cagatay, Tanvir Islam, Hamid Shojanazeri, Siddartha Pothapragada, Jack Khuu, Kaiming Cheng, Nikita Shulga, Angela Yi, Bin Bao, Shangdi Yu, Sherlock Huang, Yanan Cao, Digant Desai, Anthony Shoumikhin, Mark Saroufim, Chris Gottbrath, Joe Spisak, Jerry Zhang, and Supriya Rao.

Thank you to Patrick von Platen from Mistral AI for building the Voxtral Realtime model, open sourcing it, and reviewing and testing our integration code.

MXFP8 Training for MoEs: 1.3x training speedup vs BF16 for Llama4 Scout on GB200 cluster using TorchAO and TorchTitan https://pytorch.org/blog/mxfp8-training-for-moes-1-3x-training-speedup-vs-bf16-for-llama4-scout-on-gb200-cluster-using-torchao-and-torchtitan/ Thu, 12 Mar 2026 17:17:13 +0000 https://pytorch.org/?p=58489

TL;DR

We recently demonstrated a +30.2% training speedup for Llama4 Scout with equivalent convergence to bfloat16 by using MXFP8 MoE training primitives in TorchAO! This is ~81% of the theoretical maximum speedup for this training config when the given general matrix multiplications (GEMMs) and grouped GEMMs are converted to MXFP8 (roofline 1.37x). These experiments were performed on Crusoe Cloud.

In this blog, we will discuss:

  1. Training run results and how you can reproduce them with TorchTitan and TorchAO
  2. An illustrated deep-dive of the forward and backward pass of the dynamically quantized MXFP8 grouped GEMM in MoE training

Convergence experiment and results

Our training runs on a 64 node / 256 device GB200 cluster with TorchTitan Llama4 Scout demonstrated equivalent convergence to the bfloat16 training baseline. This is consistent with our scaling experiments with MXFP8 training for dense models. We used the training configuration below:

  • Model: Llama4 Scout
  • Dataset: C4
  • Sequence length: 8192
  • Local batch size: 1
  • Learning rate: 1e-4
  • LR scheduler warmup steps: 2000
  • Parallelisms:
    • FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts)
    • EP=16 (on routed experts)
  • Activation checkpointing mode: full (recompute all intermediate activations instead of storing them, to reduce peak memory requirements)
  • torch.compile enabled in TorchTitan on components: model, loss
  • mxfp8 applied to routed experts computation (grouped GEMMs)
  • mxfp8 applied to all linear layers except those matching the FQNs: output, router.gate, attention.wk, attention.wv 
    • The output embedding projection is too sensitive to low precision; quantizing it adversely impacts convergence
    • Wk/Wv are too small to see a net performance benefit from dynamic mxfp8 quantization

Versions

  • torch: 2.11.0.dev20260122+cu130
  • torchtitan: 0.2.1
  • torchao: 0.17.0+git41e02b5fb

Training loss curves demonstrating equivalent convergence to bf16 for 3k+ steps

We performed a long-running convergence experiment (3k steps) to evaluate whether the convergence behavior of the bfloat16 baseline and mxfp8 is equivalent. To keep the time window on the cluster manageable, we used a small local batch size of 1 for the run. The resulting training loss curves are virtually identical.

Performance benchmarks 

Next, to evaluate the achievable performance improvements with MXFP8, we increased the local batch size from 1 to 16, as is typical for improving GEMM efficiency in sparsely activated MoEs. This resulted in an end-to-end training speedup of +30.2% over bf16 with equivalent configs. 

Number of GPUs | BF16 tokens/sec | MXFP8 tokens/sec | MXFP8 speedup vs BF16
256 | 5317 | 6921 | 30.2%
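The headline numbers can be checked with a few lines of Python. The throughput figures come from the table above and the 1.37x roofline from the TL;DR. Reading "~81% of the theoretical max" as the ratio of the achieved gain to the roofline gain is an assumption about how that percentage was computed.

```python
bf16_tps = 5317            # BF16 tokens/sec from the table
mxfp8_tps = 6921           # MXFP8 tokens/sec from the table
roofline_speedup = 1.37    # theoretical maximum quoted in the TL;DR

speedup = mxfp8_tps / bf16_tps
gain_pct = (speedup - 1) * 100
# Assumed interpretation: achieved gain divided by the maximum achievable gain.
fraction_of_roofline = (speedup - 1) / (roofline_speedup - 1)

print(f"speedup: {speedup:.3f}x (+{gain_pct:.1f}%)")
print(f"fraction of roofline gain: {fraction_of_roofline:.1%}")
```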

The engine powering this speedup is the _to_mxfp8_then_scaled_grouped_mm op, which is ~1.8x faster than compiled bf16 for these shapes. Using it for routed experts results in a 1.43x faster MoE layer and 1.2x faster e2e training with Llama4 Scout vs compiled bf16. Additionally, when we use MXFP8 for the shared expert linear layers as well, we reach a 1.3x e2e training speedup vs compiled bf16. See the Appendix for microbenchmark tables.

In the next sections we introduce the TorchAO APIs and give further technical details on mxfp8 and its application in scaled grouped GEMMs.

TorchTitan Config for MXFP8 MoE training

For the results of the previous section we relied on TorchTitan as our training framework. To use MXFP8 training for MoEs in TorchTitan, check out the documentation which details the necessary configs and examples.

TorchAO MXFP8 MoE training APIs

If you’re not using TorchTitan, you can also use the TorchAO primitives directly. TorchAO recently added a prototype API, _to_mxfp8_then_scaled_grouped_mm, which does exactly what it sounds like: it quantizes grouped GEMM inputs (activations and weights) to mxfp8, then does a scaled grouped GEMM with the mxfp8 operands, producing an output in the original precision. This primitive is differentiable, so it can be used for training out of the box. Check out the docs for detailed microbenchmarks, roofline analysis for shapes used in different popular models, and more.

The goal of this primitive is to achieve a net speedup over the bf16 grouped GEMM baseline. By dynamically quantizing the inputs to mxfp8, we can then use an mxfp8 scaled grouped GEMM, which achieves up to 2x higher TFLOPs/sec vs bf16. Thus, as long as our quantization kernels are sufficiently fast and don’t introduce excessive overhead, we should be able to achieve a net speedup, as shown in the diagram below:

Illustrated walkthrough of a forward and backward pass through dynamic MXFP8 quantization + scaled grouped GEMM

Let’s do a tour through the forward and backward pass, starting at the moment our input activations come in to go through the forward pass of the routed experts!

Our starting point is immediately before the routed expert computation in the MoE layer. To be clear, at this point in the execution of the MoE layer, the following steps have already happened:

  • The expert affinity scores for each token have been computed by the router (token choice routing)
  • Tokens have been “assigned” to top K experts based on the scores
  • For expert parallelism, all-to-all comms dispatch tokens to the device the target expert resides on
  • Each device does a token shuffle/permutation such that tokens are grouped by expert and the groups are sorted in the same order as the expert weights
    • This shuffle operation produces a tensor called the offsets tensor, which stores the end index of each group in the flattened 2d tensor of token groups
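As a small illustration of the offsets tensor described above, the sketch below builds it from per-expert token counts with a cumulative sum and recovers each expert's row range. The counts and the helper name `group_offsets` are hypothetical. In the real pipeline this tensor lives on the device.

```python
from itertools import accumulate


def group_offsets(tokens_per_expert):
    """End index (exclusive) of each expert's token group in the flattened tensor."""
    return list(accumulate(tokens_per_expert))


tokens_per_expert = [3, 0, 5, 2]   # hypothetical counts after the shuffle
offsets = group_offsets(tokens_per_expert)

# Recover the [start, end) row range of each expert from the offsets alone.
starts = [0] + offsets[:-1]
ranges = list(zip(starts, offsets))

print(offsets)   # [3, 3, 8, 10]
print(ranges)    # [(0, 3), (3, 3), (3, 8), (8, 10)]
```

An expert that receives no tokens simply shows up as an empty range, such as `(3, 3)` here.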

So we have our high precision input activations and weights for the routed experts as shown below. 

MXFP8 quantization

This is the key point where we dynamically quantize the input activations and weights to MXFP8, giving us float8 e4m3 data and float8 e8m0 scale factors.

Each 1×32 chunk of high precision input data shares a single e8m0 scale factor value that is used to scale the values to fill the dynamic range of the float8 e4m3 data type.
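The per-block scaling can be sketched in plain Python. This is a simplified illustration, not the TorchAO kernel. It assumes a floor-based power-of-two scale rule in the style of the OCP MX specification, whereas production kernels may use a different rounding mode. It also omits the actual cast to float8 e4m3, and all function names are hypothetical.

```python
import math

BLOCK = 32
E4M3_MAX = 448.0   # largest finite float8 e4m3 value
E4M3_EMAX = 8      # exponent of the largest e4m3 binade (2**8 <= 448 < 2**9)


def mx_block_scale_exponent(block):
    """Shared power-of-two exponent for one 1x32 block (assumed floor rounding)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0
    return math.floor(math.log2(amax)) - E4M3_EMAX


def quantize_row(row):
    """Return (scaled values clamped to the e4m3 range, per-block exponents)."""
    assert len(row) % BLOCK == 0
    scaled, exponents = [], []
    for i in range(0, len(row), BLOCK):
        block = row[i : i + BLOCK]
        e = mx_block_scale_exponent(block)
        scale = 2.0 ** e               # e8m0 stores only this exponent (with a bias)
        exponents.append(e)
        # The cast to float8 e4m3 (mantissa rounding) is omitted in this sketch.
        scaled.extend(max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in block)
    return scaled, exponents


row = [0.01 * i for i in range(64)]    # two blocks of 32 values
scaled, exponents = quantize_row(row)
print(exponents)
print(max(scaled[:BLOCK]), max(scaled[BLOCK:]))
```

Each block gets its own exponent, so a block of small values is scaled up further than a block of large ones before both are stored in the same 8-bit format.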

Write per-group scale factors to blocked layout with group boundaries along rows (M)

To do efficient grouped GEMMs on Blackwell GPUs, we’ll want our kernel to use the hardware’s 5th generation Tensor Cores, which are key to maximizing the compute throughput (TFLOPs/sec) on these chips. These Tensor Cores are programmed using the tcgen05 family of PTX instructions, the assembly-like instructions that kernel languages like CUDA are lowered to during compilation.

For MXFP8 grouped GEMMs specifically, we’ll need to use the block scaled variant of the tcgen05.mma PTX instructions. This instruction has some peculiar requirements for the layout of the e8m0 scale factors in our MXFP8 data. 

Specifically, the scale factors must reside in tensor memory (TMEM) in an unconventional blocked layout, which can be seen in the NVIDIA docs here:

Source: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout 

Therefore, in our quantization kernels, we need to write the scale factors to this layout so they are usable by the Tensor Cores. To convert the scales from a simple row-major layout to this blocked layout, we want to do 3 things:

  1. Pad each token group’s scales such that they are evenly divisible into 128×4 tiles.
  2. Within each 128×4 tile, apply a layout transformation from simple row-major to a ((32,4), 4) layout, meaning we have 4 (32,4) tiles logically arranged side by side in memory. You can think of it as a (32,16) shaped tensor where each 16-byte “super row” now contains 4 rows, one from each tile. These super rows are contiguous in memory.
  3. Write out these tiles with tile-level granularity to a “row of blocks” major layout, as shown below (i.e., we write out the full tile as contiguous memory before proceeding to the next tile along the row).

Here is a before/after diagram of this blocked layout transformation that visualizes it a bit better. For more details, you can refer to the NVIDIA documentation.
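The three steps can be sketched for a single group in plain Python. This is an illustrative reference, not the TorchAO kernel. The in-tile offset formula follows our reading of the cuBLAS block-scaling layout documentation and should be treated as an assumption, as should the zero padding and the function name `to_blocked`.

```python
def ceil_div(a, b):
    return -(-a // b)


def to_blocked(scales):
    """Rearrange a row-major 2D list of scales into the 128x4-tiled layout."""
    rows, cols = len(scales), len(scales[0])
    row_tiles, col_tiles = ceil_div(rows, 128), ceil_div(cols, 4)
    out = [0] * (row_tiles * col_tiles * 512)      # step 1: pad to whole tiles
    for r in range(rows):
        for c in range(cols):
            # Step 3: tiles are stored one after another along a row of blocks.
            tile = (r // 128) * col_tiles + (c // 4)
            rr, cc = r % 128, c % 4                # position inside the tile
            # Step 2: 16-element super rows, holding one 4-wide row
            # from each of the four 32-row sub-tiles.
            inner = (rr % 32) * 16 + (rr // 32) * 4 + cc
            out[tile * 512 + inner] = scales[r][c]
    return out


# One full 128x4 tile whose values equal their row-major index.
scales = [[r * 4 + c for c in range(4)] for r in range(128)]
blocked = to_blocked(scales)
print(blocked[:16])   # first super row: rows 0, 32, 64, and 96 of the tile
```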

When accounting for groups, the actual memory layout looks like this. Groups are along the M dimension (rows) for the 2d-3d MXFP8 grouped GEMM in the forward pass:

Transformation to this unconventional layout will need to be done very carefully to avoid incurring excessive overhead that will kill our net speedup (or worse, cause a slowdown!).

But wait, there’s more! So far we have only accounted for a single tensor of scale factors. Remember, we are doing a grouped GEMM, which parallelizes our routed expert GEMMs and executes them in a single kernel. The data and scales for each GEMM live in the same tensor, but all must be individually valid for the tcgen05.mma PTX instruction we talked about. This means that we cannot apply the layout transformation to the scale tensor as a whole; rather, we need to do this layout conversion on each individual group separately.

Furthermore, the token group sizes for each expert are dynamic, and known only on the device! This means that from our host-side PyTorch code we cannot dispatch kernels to transform each group’s memory layout separately, because the host does not have group size information. Getting it would require a host-device sync, causing a large gap (idle time) in our GPU kernel execution stream, which would be terrible for performance! Therefore, we must design a custom kernel that can perform this task efficiently, purely on the device side.

It turns out this is a very interesting and unusual kernel to write, one that warrants a separate deep dive post. In the meantime, let’s move on to the next step of our MXFP8 Grouped GEMM, after we’ve dynamically quantized our inputs and written the scale factors to the per-group blocked layout.

2d–3d MXFP8 Grouped GEMM for forward output

With our data and scales in place in the proper memory layouts, we are finally ready to do our MXFP8 Grouped GEMM between our 2d input activations and 3d weights! This gives us a 2d output where each token has been projected to the hidden dimension. Decomposed, it looks like this:
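To make the 2d-3d decomposition concrete, here is a pure-Python reference of the unquantized math, with no MXFP8 involved. It assumes the (E, N, K) weight layout given for the backward pass in this post, so each group computes X_g @ W_g^T. The helper names and the toy inputs are hypothetical.

```python
def matmul(a, b):
    """Plain (m, k) @ (k, n) matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def grouped_mm_2d_3d(x, w, offsets):
    """x: (total_M, K) rows grouped by expert; w: (E, N, K); returns (total_M, N)."""
    out, start = [], 0
    for expert, end in enumerate(offsets):
        if end > start:                      # empty groups contribute no rows
            out.extend(matmul(x[start:end], transpose(w[expert])))
        start = end
    return out


x = [[1, 2], [3, 4], [5, 6]]                 # 3 tokens, K = 2
w = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]     # E = 2 experts, each (N = 2, K = 2)
offsets = [1, 3]                             # token 0 -> expert 0, tokens 1-2 -> expert 1
print(grouped_mm_2d_3d(x, w, offsets))       # [[1, 2], [6, 8], [10, 12]]
```

The grouped GEMM kernel produces the same result as this loop, but executes all experts' GEMMs in a single launch.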

Backward pass; 2d–3d MXFP8 Grouped GEMM for input gradient computation

The gradient of the input is fairly similar to the 2d-3d scaled grouped GEMM for the forward output, so we won’t go into much detail here, as this post is already getting quite lengthy. Specifically, the formula is:

dgrad = dO @ weight

The gradient of our output has the same shape as our forward output. So it is a 2D tensor of shape (total_M, N). Our weight is still our weight, a 3D tensor of shape (E,N,K).

So we quantize our 2D output gradient the same way we quantized our input activations in the forward pass, and we quantize our weights very similarly, except that this time they are non-transposed, in row-major format, and we write them to a per-expert column-major format. The kernel is a bit different, but that’s the only difference.

Backward pass; 2d–2d MXFP8 Grouped GEMM for weight gradient computation

The gradient of the weight is more interesting to discuss. This is because it involves an entirely different type of grouped GEMM, with different challenges. The formula for calculating the gradient of the weights is:

dW = dO^T @ X

This is a 2d-2d grouped GEMM with shapes (N, total_M) @ (total_M,K). As you can see, the groups are now along the contracting dimension of the GEMMs! 
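A pure-Python reference of the unquantized 2d-2d grouped GEMM shows how the groups cut through the contracting dimension: each expert's (N, K) gradient sums over that expert's tokens only. The function name and the toy inputs are hypothetical.

```python
def grouped_mm_2d_2d(d_out, x, offsets):
    """d_out: (total_M, N), x: (total_M, K); returns (E, N, K) weight gradients."""
    grads, start = [], 0
    n, k = len(d_out[0]), len(x[0])
    for end in offsets:
        grad = [[0.0] * k for _ in range(n)]
        for m in range(start, end):          # contraction over this group's tokens only
            for i in range(n):
                for j in range(k):
                    grad[i][j] += d_out[m][i] * x[m][j]
        grads.append(grad)
        start = end
    return grads


d_out = [[1, 0], [0, 1], [1, 1]]             # (total_M = 3, N = 2)
x = [[1, 2], [3, 4], [5, 6]]                 # (total_M = 3, K = 2)
offsets = [1, 3]                             # token 0 -> expert 0, tokens 1-2 -> expert 1
print(grouped_mm_2d_2d(d_out, x, offsets))
# [[[1.0, 2.0], [0.0, 0.0]], [[5.0, 6.0], [8.0, 10.0]]]
```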

Writing per group scale factors to blocked format with groups along the contracting dimension (K)

How does this change things? Well, a kernel that converts scales to blocked format with groups along the M dimension (the rows) is no longer quite right for this case, where the groups are along the contracting (K) dimension. This is because the input scale is in a row-major layout, so slicing the tensor into groups along column boundaries fundamentally changes the strides and makes them dynamic per group. So we need a kernel that handles different strides per group. In this case, we need to calculate the number of 128×4 tiles along each row in the group, and our stride will be the stride per tile times the number of tiles along a row in that group. It sounds complicated, and it is, but here is a diagram to help visualize it and make it a bit simpler to understand:
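The per-group stride calculation can be sketched as follows. This is only an illustration of the arithmetic described above. It assumes one scale per 32 elements along the contracting dimension, group sizes that are multiples of 32, and groups stored back to back in the output buffer. The function name and the example sizes are hypothetical.

```python
def ceil_div(a, b):
    return -(-a // b)


TILE_ROWS, TILE_COLS = 128, 4
TILE_ELEMS = TILE_ROWS * TILE_COLS   # 512 scale factors per tile


def per_group_layout(num_rows, scale_cols_per_group):
    """For each group: (start offset, tiles per row of blocks, row-of-blocks stride)."""
    row_tiles = ceil_div(num_rows, TILE_ROWS)
    layout, start = [], 0
    for cols in scale_cols_per_group:
        tiles_per_row = ceil_div(cols, TILE_COLS)
        stride = TILE_ELEMS * tiles_per_row      # differs from group to group
        layout.append((start, tiles_per_row, stride))
        start += row_tiles * stride              # assumed: next group follows immediately
    return layout


# Hypothetical example: 256 rows and three groups of 96, 32, and 320 tokens,
# which is 3, 1, and 10 scale columns at one scale per 32 elements.
print(per_group_layout(256, [3, 1, 10]))
# [(0, 1, 512), (1024, 1, 512), (2048, 3, 1536)]
```

Because group sizes are only known on the device, a real kernel has to derive these values from the offsets tensor at run time rather than on the host.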

Each individual GEMM in the grouped GEMM produces a 2d output, and these outputs are stacked into the final 3d result. Decomposed, it looks like this:

And there we have it! At this point, hopefully you have a better understanding of what’s going on under the hood in both the forward and backward pass of a dynamically quantized MXFP8 Grouped GEMM! 

That’s all for now – we hope you enjoyed the read, and remember to check out the TorchAO MoE training docs which have benchmarks, examples, and more to get started with!

Future work

MXFP8 MoE training in TorchAO is still a prototype feature, and we are actively working on a few improvements before graduating it to stable, namely:

  • Unify APIs for MXFP8 training for dense and sparse/MoE models: today, TorchAO has separate APIs for converting a model’s nn.Linear layers and torch._grouped_mm ops to use MXFP8. We are working on unifying these APIs to simplify the UX.
  • MXFP8 for expert parallel comms: beyond the MXFP8 grouped GEMM, we have prototypes of autograd functions for efficient training with expert parallelism that quantize to MXFP8 earlier, before the all-to-all comms, and stay in MXFP8 through the grouped GEMMs, thus saving network bandwidth and producing a speedup. Stay tuned for a post on this soon!

Appendix

Microbenchmarks: net speedup of dynamic MXFP8 quantization + MXFP8 grouped GEMM versus bf16 grouped GEMM

Below are some microbenchmarks comparing the combined duration of the forward and backward pass of the autograd function powering MXFP8 MoE training, versus the bf16 torch._grouped_mm baseline, for shapes used in recent MoE model architectures.

M = local_batch_size * sequence_length

G = number_of_experts_on_local_rank

N, K = expert_dimensions

Llama4 Scout shapes

(M, N, K, G) | BF16 forward + backward (microseconds) | MXFP8 forward + backward (microseconds) | MXFP8 speedup vs BF16
(128000, 8192, 5120, 1) | 43140.20 | 23867.00 | 1.808x
(128000, 8192, 5120, 2) | 39487.60 | 23359.00 | 1.690x
(128000, 8192, 5120, 4) | 39189.20 | 23945.50 | 1.637x
(128000, 8192, 5120, 8) | 37700.70 | 22170.60 | 1.700x

You can refer to the documentation for commands to reproduce these benchmarks on a B200 GPU.

MoE layer benchmarks

Microbenchmarks of a single MoE layer on a single B200 also show MXFP8 achieves up to 1.43x faster MoE layer execution vs the bf16 baseline.

Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup
Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x

You can refer to the documentation for commands to reproduce these benchmarks on a B200 GPU.

KubeCon + CloudNativeCon + OpenInfra Summit + PyTorch Conference China 2026 CFP & Registration Now Open https://pytorch.org/blog/kubecon-cloudnativecon-openinfra-summit-pytorch-conference-china-2026-cfp-registration-now-open/ Thu, 12 Mar 2026 02:00:56 +0000 https://pytorch.org/?p=58484 KubeCon + CloudNativeCon + OpenInfra Summit + PyTorch Conference China 2026 - Submit to Speak and Register

Both the Call for Proposals and registration are now open for KubeCon + CloudNativeCon + OpenInfra Summit + PyTorch Conference China 2026, taking place 8-9 September in Shanghai.

This is THE place for cloud native, open infrastructure, and AI innovators to converge for the exchange of ideas, sharing production lessons, and shaping the future. Gather with fellow developers, architects, maintainers, researchers, and enterprise leaders to explore the open source technologies powering modern platforms and intelligent systems. You won’t want to miss the first ever PyTorch Conference China!

Multiple communities. One stage. Your voice.

We are seeking proposals for presentations, panels, tutorials, and lightning talks designed for real-time problem solving and deep discussion. Submit to speak on any one of the following topics:

  • AI + ML + Agentic AI + Data Systems
  • AI Infrastructure + Accelerators + Performance Engineering
  • Platform Engineering + Cloud Native Architecture
  • Application Development + Developer Experience
  • Operations + Observability + Reliability
  • Security + Privacy + Trusted Computing
  • Networking + Edge + Distributed Systems
  • Cloud Infrastructure + Virtualization + Storage
  • Emerging Technologies + Research + Advanced Topics
  • Community + Open Source + Getting Started
  • Hardware Enablement + Diversification

Learn more about all suggested topics, and mark your calendar because the Call for Proposals (CFP) closes on Sunday, 3 May. Submit to speak.

Register + Join Us

Accelerate your open source journey! There are three pass options that give you full access to all keynotes and breakouts, the Solutions Showcase, and activities:

  • Corporate: For those whose company is paying for their attendance, including for-profit companies and government employees.
  • Individuals: For those who are currently not working for a company, work for a non-profit or research institution, or are attending at their own expense.
  • Academics: Current full-time students and faculty/staff members. A valid copy of institution-issued ID required.

Register by 7 April + save with early bird pricing.

Sponsor the Event

Sponsorships are also open through Friday, 17 July. Get your company in front of a wide array of open source communities in the cloud native, OpenInfra, and AI spaces. Sponsoring this event allows you to gain valuable mindshare of a targeted audience while engaging with 1,000+ developers, architects, and technical leaders.

PyTorch at NVIDIA GTC 2026: Join Us in San Jose! https://pytorch.org/blog/pytorch-at-nvidia-gtc-2026/ Mon, 09 Mar 2026 20:19:39 +0000 https://pytorch.org/?p=58382

We’re excited to announce that PyTorch will have a strong presence at NVIDIA GTC 2026, from March 16-19, 2026 in San Jose! Whether you’re a seasoned PyTorch developer or just getting started, we invite you to join us for demos, talks, hands-on labs, and opportunities to connect with PyTorch core maintainers and community experts.

PyTorch - Meta booth at NVIDIA GTC 2026

Visit the Meta Booth — Booth #338

Meta will be presenting a booth with PyTorch. Stop by Booth #338 to meet PyTorch experts from Meta and the PyTorch Foundation. We’ll be showcasing:

  • Helion Demo: Get hands-on experience with Helion, a new PyTorch-native kernel authoring framework. Learn how to write custom kernels and autotune them for optimal performance on your target NVIDIA GPU.
  • ExecuTorch: See NVIDIA’s Parakeet speech-to-text model deployed using ExecuTorch for high-performance, Python-free, on-device inference, demonstrating an end-to-end edge pipeline from audio capture to transcription.

Featured Talk: From Kernels to Clusters — How PyTorch Powers High-Performance AI

On Monday, March 16, from 4:40 to 4:55 PM PT, join Alban Desmaison (PyTorch core maintainer at Meta) for “From Kernels to Clusters: How PyTorch Powers High-Performance AI” at GTC 2026.

In this featured GTC session, Alban Desmaison will walk through:

  • A broad overview of the PyTorch framework and its evolving ecosystem
  • How Meta and the broader community leverage PyTorch for high-performance and distributed AI workloads
  • Key performance wins and updates powering state-of-the-art research and products
  • What’s next for PyTorch: insights into the technical roadmap and upcoming features

Add this talk to your NVIDIA GTC agenda to gain a firsthand perspective from the PyTorch core team!

PyTorch Sessions & Talks

Monday, March 16

Time (PT) | Session | Speakers
3:30 PM | Open source RL and Agents, highlighting TorchForge and OpenEnv (CoreWeave Booth) | Hamid Shojanazeri (Meta), Aaron Batilo (CoreWeave)
4:40 PM | From Kernels to Clusters: How PyTorch Powers High-Performance AI | Alban Desmaison (Meta)

Tuesday, March 17

Time (PT) | Session | Speakers
1:00 PM – 1:50 PM | Turbocharge Your LLM Inference: Expert Strategies for Lightning-Fast, Scalable AI Deployment | NVIDIA Experts

Wednesday, March 18

Time (PT) | Session | Speakers
3:00 PM – 4:45 PM | Optimize PyTorch Models for High-Performance Inference With Nsight Deep Learning Designer (Hands-on Lab) | Manoj Kumar Yennapureddy (NVIDIA)
4:00 PM – 4:40 PM | Build Fault-Tolerant Distributed AI Training at Scale | Shreya Gupta (NVIDIA), Aravind Neelakantan (AWS)

Thursday, March 19

| Time (PT) | Session | Speakers |
| --- | --- | --- |
| 8:00 AM – 9:45 AM | Ultra Scale Runbook for PyTorch on NVIDIA GPUs for Training and Inference (Hands-on Lab) | Syed Ahmed (NVIDIA), Tianyu Liu & Lu Fang (Meta) |
| 2:00 PM – 2:40 PM | Parameterized CUDA Graph Launch in PyTorch: CUDA Graphs Without the Pain | Daniel Galvez (NVIDIA) |

Pre-NVIDIA GTC: Helion Hackathon — Saturday, March 14

Arriving early or not attending NVIDIA GTC? Kick off the week with the PyTorch Helion Hackathon on Saturday, March 14 in San Francisco!

Join experts from Meta and NVIDIA for a hands-on hackathon focused on kernel authoring with Helion. Whether you’re new to writing GPU kernels or looking to push the boundaries of performance, this is a great opportunity to learn, build, and connect with the community.

👉 Register for the Helion Hackathon

We can’t wait to see you at NVIDIA GTC 2026! Come find us at Booth #338, attend our sessions, and join us for the Helion Hackathon.

See you in San Jose!

]]>
KernelAgent: Hardware-Guided GPU Kernel Optimization via Multi-Agent Orchestration https://pytorch.org/blog/kernelagent-hardware-guided-gpu-kernel-optimization-via-multi-agent-orchestration/ Fri, 06 Mar 2026 20:00:50 +0000 https://pytorch.org/?p=58155 KernelAgent

Summary

Recently, the PyTorch team released KernelAgent, an open agentic system achieving 100% correctness across all 250 L1/L2/L3 KernelBench tasks. In this post, we extend that work by adding a hardware-guided optimization layer to the existing framework. Building on the previous correctness-focused pipeline, KernelAgent integrates GPU hardware-performance signals into a closed-loop multi-agent workflow to guide the optimization of Triton kernels.

We evaluate the kernels generated by KernelAgent on all 100 L1 KernelBench tasks. Overall, it achieves a 2.02x speedup over the kernels generated by the earlier correctness-only version. Compared to default torch.compile, KernelAgent achieves a 1.56x average speedup, outperforms it on 65 of the 100 KernelBench L1 tasks, and reaches 89% of the hardware roofline efficiency on the H100.

The optimization codebase is available in the KernelAgent repo, with documentation to get started. We also share a selection of end-to-end KernelAgent optimization artifacts in the open-source repo.

Introduction

Optimizing GPU kernels is increasingly critical for modern AI workloads. As models grow larger and more specialized, performance is often bounded not by high-level algorithms, but by the efficiency of the kernels that implement them. Yet manual kernel optimization remains expertise-intensive, requiring deep knowledge of GPU architectures, memory hierarchies, and performance trade-offs. This challenge compounds as the number of kernels grows and as each new GPU architecture demands rethinking optimization strategies.

In practice, experienced kernel engineers follow a systematic workflow when optimizing kernels. They profile kernels using tools such as NVIDIA Nsight Compute, examine hardware performance counters to diagnose bottlenecks, and iteratively apply targeted optimizations. Is register pressure killing occupancy? Is the tiling strategy leaving memory bandwidth on the table? Does the kernel need an architectural redesign, not just parameter tuning? This process can require reasoning through multiple distinct kernel architectures, each exposing a different bottleneck, before converging on a design that saturates the hardware. While effective, this iterative cycle typically takes days or weeks.

Modern compiler stacks have made significant progress toward automating kernel generation. torch.compile captures computation graphs and generates Triton kernels using a combination of graph transformations, pattern matching, and compiler heuristics. Similar approaches in systems such as TVM and XLA cover many common kernel patterns and deliver strong out-of-the-box performance. However, most compiler heuristics are guided by static models rather than direct measurements from real hardware execution.

KernelAgent is designed to automate this diagnosis-driven optimization loop by grounding kernel optimization in real hardware signals. It targets forward-pass (inference) kernels, where latency and throughput directly impact serving costs and user experience. It is built on three core principles:

Ground everything in hardware metrics. Both bottleneck diagnosis and optimization prescriptions must be derived from real profiling data. 

Explore optimization paths in parallel. Given the same hardware signals, multiple valid optimization strategies may exist. KernelAgent evaluates these strategies concurrently, reducing wall-clock optimization time and synthesizing previous approaches into evolved algorithmic discoveries.  

Learn across rounds through shared memory. Optimization agents reflect on what succeeded and failed in each round, summarizing insights into a shared memory that guides subsequent iterations and prevents repeated dead ends.

KernelAgent Optimization Workflow 

 

Figure 1: Overview of the optimization workflow. Optimization starts from the input kernel, which serves as the baseline. The agents are: ProfilerAgent (collects hardware signals), JudgeAgent (diagnoses bottlenecks), AnalyzeAgent (prescribes recommendations), Orchestrator Agent (synthesizes knowledge), Optimization Manager (explores different optimizations through multiple optimization agents), and BenchmarkAgent (measures performance). Arrows show data flow between agents.

KernelAgent automates the workflow that experienced kernel engineers already follow—profiling, diagnosing bottlenecks, proposing optimizations, and iterating—by decomposing it into a set of cooperating agents. Each agent is responsible for a well-defined stage of the optimization loop, and together they form a closed, hardware-guided feedback system.

Figure 1 illustrates the overall workflow. Starting from an input kernel, KernelAgent repeatedly profiles the kernel, diagnoses performance bottlenecks, prescribes architecture-aware optimizations, synthesizes optimization knowledge, explores alternative optimization paths in parallel, and measures each candidate. Arrows indicate the flow of information between agents across optimization rounds. At a high level, each optimization round consists of the following stages:

Profile → Diagnose → Prescribe → Orchestrate → Explore → Measure

Each stage produces structured outputs that feed directly into the next, enabling fast, data-driven iteration.

How Data Flows Through the System 

Profiling: Collecting Hardware Signals

The optimization loop begins with the Profiling Agent inspecting the input kernel using NVIDIA Nsight Compute (NCU). KernelAgent integrates NCU to capture hardware-level performance metrics, including but not limited to DRAM throughput and utilization, L2 cache hit rates, warp occupancy and stall reasons, compute and tensor core utilization, and Speed-of-Light (SOL) metrics. These metrics provide the empirical foundation for all downstream decisions.

Input: kernel code + input specification (shapes, dtypes)

Output: structured dictionary of hardware metrics

Sample output:

{
   "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active": 0.41,
   "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct": 5.63,
   "gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed": 48.86
  ...
}

Diagnosis: Identifying Bottlenecks with Roofline Analysis

The Diagnose Agent interprets profiling metrics to classify the kernel’s dominant performance bottleneck. It performs a roofline-style analysis using SOL metrics and combines this with LLM-based reasoning for root cause analysis.

Input: NCU metrics + current kernel code

Output: BottleneckReport with:

  • Primary bottleneck category
  • Efficiency percentage (max of compute/memory SOL)
  • Root causes with evidence (specific metrics cited)

Example diagnosis:

        "category": "memory",
        "summary": "Kernel is memory-bound at 70.3% DRAM throughput with significant long scoreboard stalls from memory latency",
        "reasoning": "The roofline analysis shows Memory SOL at 70.3% while Compute SOL is only 45.2%...",
        "root_causes": [
            { "cause": "High memory latency stalls due to long scoreboard waits blocking warp execution",
              "evidence": [
               {"metric": "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", "value": 37.69, "interpretation": "37.7% of warp stalls are due to waiting for memory operations (global/shared memory loads), indicating memory latency is a significant bottleneck"},
                    {"metric": "sm__warps_active.avg.pct_of_peak_sustained_active", "value": 30.08, "interpretation": "Only 30% warp occupancy suggests insufficient parallelism to hide memory latency"}
                ]...
            },
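The roofline-style classification described above can be sketched in a few lines. This is a minimal illustration, not KernelAgent's implementation: the names `BottleneckSummary` and `classify_bottleneck` and the simple two-way compute/memory split are assumptions. Only the rule "efficiency is the max of compute and memory SOL" comes from the text.

```python
from dataclasses import dataclass


@dataclass
class BottleneckSummary:
    # Hypothetical container; the BottleneckReport in the post carries more fields.
    category: str          # "memory" or "compute"
    efficiency_pct: float  # max of compute SOL and memory SOL


def classify_bottleneck(compute_sol_pct: float, memory_sol_pct: float) -> BottleneckSummary:
    """Pick the dominant bottleneck from two Speed-of-Light percentages."""
    if memory_sol_pct >= compute_sol_pct:
        category = "memory"
    else:
        category = "compute"
    return BottleneckSummary(category, max(compute_sol_pct, memory_sol_pct))


# Values quoted in the example diagnosis: Memory SOL 70.3%, Compute SOL 45.2%.
report = classify_bottleneck(compute_sol_pct=45.2, memory_sol_pct=70.3)
print(report.category, report.efficiency_pct)
```

In the system described here, this coarse label is only the starting point: the LLM-based root cause analysis then cites specific metrics, such as stall reasons and occupancy, as evidence.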

Prescribing Fixes: Architecture-Aware Recommendations

Given the diagnosed bottleneck, the Analyzer (Prescriber) Agent generates concrete, architecture-aware optimization recommendations. It combines bottleneck classification, GPU specifications (e.g., A100 vs H100) and the retrieved optimization patterns from a curated database. This enables KernelAgent to tailor recommendations to the target hardware.

Input: BottleneckReport + GPU specifications + Optimization Database + kernel code

Output: A list of prescribed fixes with rationale

Example prescription:

"recommended_fixes": [
         {"fix": "Increase pipeline depth with more stages (num_stages=4-5) and reduce register pressure by using smaller BLOCK_K or enabling register spilling to shared memory", 
    "rationale": "More pipeline stages help hide memory latency by overlapping loads with computation. Reducing register usage from 91 per thread would allow more concurrent warps to better hide the 37.7% long scoreboard stalls and improve the 30% warp occupancy"}
...
        ]
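To see why the rationale ties register count to occupancy, here is a rough sketch of a register-limited occupancy bound. The hardware limits (65,536 registers and 64 resident warps per SM) and the allocation granularities are assumptions for an H100-class SM, and the function name is hypothetical. Real occupancy also depends on block size and shared memory, which this sketch ignores.

```python
def register_limited_occupancy(
    regs_per_thread: int,
    regs_per_sm: int = 65536,     # assumed: 64K 32-bit registers per SM
    max_warps_per_sm: int = 64,   # assumed: 64 resident warps per SM
    reg_alloc_unit: int = 8,      # assumed: per-thread register rounding
    warp_alloc_unit: int = 4,     # assumed: warps are allocated in groups of 4
    threads_per_warp: int = 32,
) -> float:
    """Upper bound on occupancy when registers are the only limiter."""
    # Round the per-thread register count up to the allocation unit.
    rounded = -(-regs_per_thread // reg_alloc_unit) * reg_alloc_unit
    regs_per_warp = rounded * threads_per_warp
    # How many whole warps fit in the register file.
    warps = regs_per_sm // regs_per_warp
    warps -= warps % warp_alloc_unit
    warps = min(warps, max_warps_per_sm)
    return warps / max_warps_per_sm


for regs in (91, 64, 32):
    print(regs, f"{register_limited_occupancy(regs):.1%}")
```

Under these assumptions, 91 registers per thread caps occupancy at roughly a third of the SM's warp slots, which is in the same range as the 30% warp occupancy cited in the example diagnosis.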

Orchestration: Turning Analysis into Search Strategy

The Orchestrator Agent synthesizes current diagnostics with historical optimization data to formulate a concrete search strategy for the next round. It aggregates prior diagnoses, prescriptions, and outcomes, applies the configured search strategy (beam search, greedy search, etc.), and determines which fixes to explore next.

After each round, KernelAgent generates a structured self-analysis: Was the diagnosis correct? Did the fix address the root cause? What worked, and why? This information enables inference-time learning.

Input: Prescription fix + Attempt History + Reflexion 

Output: Finalized optimization prompt 

Example reflexion:

"was_diagnosis_correct": true,
    "was_fix_effective": false,
    "expected_outcome": "...should reduce memory latency stalls by allowing more in-flight memory operations, improving memory throughput and reducing warp stalls",
    "actual_outcome": "Performance degraded significantly by 37.4% (1.0910ms → 1.4996ms)....
    "reasoning": "The fix backfired because: 1) Doubling BLOCK_N (128→256) and BLOCK_K (32→64) dramatically increased shared memory and register usage per block, likely reducing occupancy significantly....",
    "lessons": [
        "Increasing BLOCK_N and BLOCK_K together with num_stages creates compound pressure on shared memory and registers",
        ...
    ],
    "avoid_patterns": [
        "Simultaneously increasing multiple tile dimensions (BLOCK_N, BLOCK_K) along with pipeline stages",
        "...
    ],
    "try_patterns": [
        "Try smaller BLOCK_K (16 or 32) with increased num_stages to reduce register pressure while improving pipelining",
      ...
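One way to picture how reflexions accumulate into shared memory is sketched below. The class names, the `absorb` and `to_prompt` methods, and the de-duplication rule are all hypothetical. Only the field names (`lessons`, `avoid_patterns`, `try_patterns`, and the two booleans) mirror the example above.

```python
from dataclasses import dataclass, field


@dataclass
class Reflexion:
    was_diagnosis_correct: bool
    was_fix_effective: bool
    lessons: list[str] = field(default_factory=list)
    avoid_patterns: list[str] = field(default_factory=list)
    try_patterns: list[str] = field(default_factory=list)


@dataclass
class SharedMemory:
    """Hypothetical cross-round store that later prompts can draw on."""
    lessons: list[str] = field(default_factory=list)
    avoid_patterns: list[str] = field(default_factory=list)
    try_patterns: list[str] = field(default_factory=list)

    def absorb(self, reflexion: Reflexion) -> None:
        # Merge each list, skipping entries that are already recorded.
        for attr in ("lessons", "avoid_patterns", "try_patterns"):
            target = getattr(self, attr)
            for item in getattr(reflexion, attr):
                if item not in target:
                    target.append(item)

    def to_prompt(self) -> str:
        # Render the memory as plain text for the next optimization prompt.
        sections = [("Lessons", self.lessons),
                    ("Avoid", self.avoid_patterns),
                    ("Try", self.try_patterns)]
        lines = []
        for title, items in sections:
            if items:
                lines.append(f"{title}:")
                lines.extend(f"- {item}" for item in items)
        return "\n".join(lines)


memory = SharedMemory()
round_reflexion = Reflexion(
    was_diagnosis_correct=True,
    was_fix_effective=False,
    lessons=["Increasing BLOCK_N and BLOCK_K together with num_stages creates "
             "compound pressure on shared memory and registers"],
    avoid_patterns=["Simultaneously increasing multiple tile dimensions "
                    "(BLOCK_N, BLOCK_K) along with pipeline stages"],
    try_patterns=["Try smaller BLOCK_K (16 or 32) with increased num_stages"],
)
memory.absorb(round_reflexion)
memory.absorb(round_reflexion)  # a repeated insight is stored only once
print(memory.to_prompt())
```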

Exploration: Parallel Optimization 

The Optimization Manager executes the exploration phase. It maintains the top-K performing kernels and spawns multiple optimization workers per kernel to explore different fixes in parallel. If one optimization path degrades performance, another worker exploring a different fix may still succeed, which keeps the search from getting stuck in a local minimum. Each worker applies a different optimization, compiles the kernel, and passes it to the Benchmarking Agent.

Input: Kernel candidates + different optimization plans

Output: Compiled, optimized kernels for evaluation

Example Result:

BeamSearch initialized: 2 kernels × 2 bottlenecks = 4 workers
---------------------------------------
Round 1: 4/4 workers succeeded
--------------------------------
Round 2: 3/4 workers succeeded
...
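The control flow of one exploration round can be sketched as a small beam search over (kernel, fix) pairs. This is an illustration under stated assumptions: `explore_round` and `toy_try_fix` are hypothetical names, threads stand in for optimization workers, and the runtimes are made up to exercise the logic.

```python
from concurrent.futures import ThreadPoolExecutor

# A candidate is (runtime in ms, kernel label); a real system would carry source code.


def explore_round(beam, fixes, try_fix, top_k):
    """Run one worker per (kernel, fix) pair and keep the top_k fastest kernels."""
    jobs = [(cand, fix) for cand in beam for fix in fixes]
    with ThreadPoolExecutor(max_workers=max(1, len(jobs))) as pool:
        results = list(pool.map(lambda job: try_fix(*job), jobs))
    # Failed workers return None; survivors compete with the current beam.
    candidates = beam + [r for r in results if r is not None]
    return sorted(candidates, key=lambda c: c[0])[:top_k]


def toy_try_fix(cand, fix):
    """Toy worker: scales the runtime by a fixed factor, or fails."""
    runtime_ms, label = cand
    scale = {"smaller_tile": 0.8, "more_stages": 1.3, "bad_compile": None}[fix]
    if scale is None:
        return None  # models a compile or correctness failure
    return (runtime_ms * scale, f"{label}+{fix}")


beam = [(10.0, "seed")]
for _ in range(2):
    beam = explore_round(
        beam, ["smaller_tile", "more_stages", "bad_compile"], toy_try_fix, top_k=2
    )
print(beam)
```

Keeping the previous beam in the candidate pool means a round in which every fix regresses cannot make the best known kernel worse.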

Measure correctness and performance 

The Benchmarking Agent validates correctness and measures real performance for each kernel variant produced during exploration. For each candidate kernel, the agent first runs correctness checks against a trusted reference implementation. Only kernels that pass verification are benchmarked. Performance measurements are conducted using a controlled benchmarking protocol to ensure stability and reproducibility.

Performance measurement:

  • Warmup iterations (default: 25) to exclude cold-start effects
  • Repeat iterations (default: 100) for stable measurements
  • Shared benchmark lock prevents GPU contention between workers

Input: Compiled kernel variant, reference implementation (PyTorch), test input shapes

Output: Correctness verdict and measured kernel runtime

Example Result:

Round 1: 4 successful, best new: 7.8000ms
Round 2: 4 successful, best new: 4.0457ms
Round 3: 4 successful, best new: 3.1118ms
...
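The gated protocol above (verify first, then warm up, then time under a shared lock) can be sketched with plain Python callables standing in for kernels. The function name, the tolerance, and the list-based comparison are assumptions. Timing a real GPU kernel would additionally need device synchronization or CUDA events, which this CPU-only sketch omits.

```python
import threading
import time
from statistics import mean

_BENCH_LOCK = threading.Lock()  # shared so only one worker times at a time


def check_and_benchmark(candidate, reference, inputs, warmup=25, repeat=100, tol=1e-6):
    """Return the mean runtime in ms, or None if the candidate is incorrect."""
    expected = reference(*inputs)
    actual = candidate(*inputs)
    mismatch = len(actual) != len(expected) or any(
        abs(a - e) > tol for a, e in zip(actual, expected)
    )
    if mismatch:
        return None  # gate: incorrect variants are never timed

    with _BENCH_LOCK:
        for _ in range(warmup):  # exclude cold-start effects
            candidate(*inputs)
        samples = []
        for _ in range(repeat):
            start = time.perf_counter()
            candidate(*inputs)
            samples.append((time.perf_counter() - start) * 1e3)
    return mean(samples)


def reference(xs):
    return [2.0 * x for x in xs]


def good_variant(xs):
    return [x + x for x in xs]


def bad_variant(xs):
    return [2.0 * x + 1.0 for x in xs]


data = [float(i) for i in range(1000)]
print("good:", check_and_benchmark(good_variant, reference, (data,)))
print("bad:", check_and_benchmark(bad_variant, reference, (data,)))
```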

Performance summary

We use triton.testing.do_bench to obtain consistent performance measurements, reporting the mean runtime over 100 repetitions (with >1s warmup) for each kernel variant on H100. Specifically, we compare KernelAgent against:

  • KernelAgent correctness-only loop generated kernels (our earlier baseline),
  • Out-of-the-box torch.compile, which refers to PyTorch Inductor with default mode, static shapes (no dynamic shape support enabled), and CUDA graphs disabled. 

Over the 100 L1 KernelBench problems, KernelAgent outperforms out-of-the-box torch.compile on 65 tasks. Overall, KernelAgent achieves a 2.02x geometric-mean speedup over the earlier correctness-only baseline and a 1.56x speedup over out-of-the-box torch.compile. It also achieves 89% of the hardware roofline efficiency on the NVIDIA H100, where roofline efficiency is the maximum of compute SOL and memory SOL as reported by NVIDIA Nsight Compute, i.e., the higher of streaming multiprocessor throughput or memory throughput as a percentage of hardware peak.
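For readers who want to reproduce this kind of summary on their own measurements, a geometric-mean speedup and a win count can be computed as follows. The function names are hypothetical and the runtimes in the example are made up for illustration; they are not KernelBench results.

```python
import math


def geometric_mean_speedup(baseline_ms, optimized_ms):
    """Geometric mean of per-task speedups (baseline / optimized)."""
    if not baseline_ms or len(baseline_ms) != len(optimized_ms):
        raise ValueError("need two non-empty lists of equal length")
    log_sum = sum(math.log(b / o) for b, o in zip(baseline_ms, optimized_ms))
    return math.exp(log_sum / len(baseline_ms))


def count_wins(baseline_ms, optimized_ms):
    """Number of tasks where the optimized kernel is strictly faster."""
    return sum(o < b for b, o in zip(baseline_ms, optimized_ms))


# Made-up runtimes for three tasks: two wins and one regression.
baseline = [2.0, 8.0, 1.0]
optimized = [1.0, 2.0, 2.0]
print(f"{geometric_mean_speedup(baseline, optimized):.3f}x",
      count_wins(baseline, optimized), "wins")
```

The geometric mean is the usual choice for averaging ratios because a 2x speedup and a 2x slowdown cancel out, which an arithmetic mean of speedups would not reflect.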

We share a selection of end-to-end KernelAgent optimization artifacts in the open source repo. We also tested a few kernels across each category on different input shapes. Across 12 kernels / 144 shapes, we observe similar speedup. 

On the effect of test time scaling: 

Kernels making improvement by round

Figure 2: KernelAgent’s performance evolves as the number of optimization rounds increases. 

While a large fraction of performance gains are realized in the first round, reflecting the effectiveness of hardware-guided diagnosis and coarse-grained fixes, the system continues to make steady progress with additional rounds. As more rounds are allowed, KernelAgent is able to hill climb beyond the initial improvements, refining earlier optimizations and exploring secondary bottlenecks that only become visible after the primary ones are addressed. This behavior highlights the importance of iterative, feedback-driven optimization. 

Below, we present one end-to-end case study to better understand what optimization techniques KernelAgent is learning and applying in different rounds. 

Case Study: Matrix–Vector Multiplication (A @ x)

Operation: C = A @ x
Shape: M=2048, K=1,048,576
DTypes: BF16 inputs, FP32 accumulate, BF16 output
Hardware: H100  

Results overview:

PyTorch Compile baseline: 2.09 ms 

KernelAgent with correctness-only pipeline: 9.52 ms 

LLM baseline (direct prompt without hardware feedback; each round takes the previous round’s output as input; sequential exploration, 8 rounds, opus-4.5): best result 3.1985 ms

KernelAgent with optimization layer (4 workers, 8 rounds, opus-4.5): best result 1.95 ms

Round-by-round kernel performance

Figure 3: Round-by-round kernel performance: KernelAgent with optimization layer vs. Direct prompt without hardware feedback for matrix-vector multiplication

Key insights

  1. The optimization heuristics encoded in an LLM, such as “bigger blocks improve bandwidth”, are effective up to a point. Without performance feedback, however, these heuristics lead the kernel into a local minimum and stop helping, because the LLM can’t perceive the performance trade-off curve it is navigating.
  2. Without structured exploration, the LLM is trapped in the trajectory of the seed kernel. It never considered switching from split-K to a simpler one-row-per-program design and could not outperform eager mode.
  3. KernelAgent’s multi-worker exploration, profiling-based approach, and reflective knowledge sharing enable it to explore different alternatives and find an optimized path.

Why the Baseline Was Slow: The initial Triton kernel used a 2D tile with a vector accumulator. Profiling showed the kernel was primarily occupancy-limited by registers, so it couldn’t issue enough concurrent memory requests to hide DRAM latency.

First improvement identified by KernelAgent: 

  • Bottleneck: Underutilized SMs due to register-pressure–limited occupancy.
  • Prescription: Replace the large vector accumulator with scalar accumulators, process a small number of rows per program, and increase grid parallelism
  • Performance: 9.52 ms → 6.80 ms. Occupancy increased 8x, Memory SOL rose from 18.5% → 25.8%.
  • Reflexion: Reducing register state is necessary before any other optimization can take effect.

# NUM_ROWS=4: four scalar accumulators instead of a vector
acc0 = 0.0
acc1 = 0.0
acc2 = 0.0
acc3 = 0.0

for k0 in range(0, K, BLOCK_K):
    # Load B vector tile once [BLOCK_K]
    b = tl.load(b_ptrs, mask=k_mask, other=0.0).to(tl.float32)

    # Process each row individually with its own 1D load
    if row_start + 0 < M:
        a0 = tl.load(a_ptr + (row_start + 0) * stride_am + offs_k * stride_ak,
                      mask=k_mask, other=0.0).to(tl.float32)
        acc0 += tl.sum(a0 * b)
    if row_start + 1 < M:
        a1 = tl.load(a_ptr + (row_start + 1) * stride_am + offs_k * stride_ak,
                      mask=k_mask, other=0.0).to(tl.float32)
        acc1 += tl.sum(a1 * b)
    # ... (acc2, acc3 similar)

# Launch config: BLOCK_K=512, NUM_ROWS=4, num_warps=4, num_stages=4
# Grid: (cdiv(M, 4),) = (512,)

Second improvement identified by KernelAgent: 

  • Bottleneck: Still dominated by memory latency; improvements plateaued.
  • Prescription: Introduce limited caching / reuse for the vector x, reducing redundant global memory traffic. Avoid increasing num_stages, which had previously increased register pressure.
  • Performance: 6.80 ms → 6.20 ms. A modest gain from shared memory caching for the B vector, reducing redundant global memory accesses
  • Reflexion: Matrix–vector multiplication behaves very differently from GEMM; tiling strategies do not transfer directly.

Third improvement identified by KernelAgent:

  • Bottleneck: After reducing register pressure, performance was limited by inefficient memory transactions, not lack of warps.
  • Prescription: Return to vectorized 2D loads for better coalescing, but with careful control of registers: Smaller tile (BLOCK_M=32), Large K tile (BLOCK_K=512), and num_stages=1 to eliminate pipeline register overhead
  • Performance: 6.20 ms → 4.03 ms. 
  • Reflexion: Reducing register state is necessary before any other optimization can take effect.

 

Before (Step 1 approach):
```python
# Sequential scalar accumulators, NUM_ROWS=4
acc0 = 0.0; acc1 = 0.0; acc2 = 0.0; acc3 = 0.0
# ...process rows one at a time with branching...
```

After (Step 3 approach):

@triton.jit
def matvec_kernel(A_ptr, x_ptr, C_ptr, M, K, stride_am, stride_ak,
                  BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(0)
    row_start = pid_m * BLOCK_SIZE_M
    row_offsets = row_start + tl.arange(0, BLOCK_SIZE_M)
    row_mask = row_offsets < M

    # Back to vector accumulator, but only 32 elements (not 128)
    acc = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)

    for k_start in range(0, K, BLOCK_SIZE_K):
        k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        k_mask = k_offsets < K

        x_vals = tl.load(x_ptr + k_offsets, mask=k_mask, other=0.0)
        a_ptrs = A_ptr + row_offsets[:, None] * stride_am + k_offsets[None, :] * stride_ak
        a_vals = tl.load(a_ptrs, mask=row_mask[:, None] & k_mask[None, :], other=0.0)
        acc += tl.sum(a_vals.to(tl.float32) * x_vals.to(tl.float32)[None, :], axis=1)

    tl.store(C_ptr + row_offsets, acc.to(tl.bfloat16), mask=row_mask)

# Launch: BLOCK_SIZE_M=32, BLOCK_SIZE_K=512, num_stages=1, num_warps=4
# Grid: (cdiv(M, 32),) = (64,)

Final improvement: one-row-per-program (architectural shift) to saturate bandwidth

  • Bottleneck: Underutilized SMs due to register-pressure–limited occupancy.
  • Prescription: Make an architectural change: assign one row per program. Scalar accumulator (minimal registers), Massive grid parallelism (2048 programs), Pure 1D streaming loads, and Large BLOCK_K to amortize loop overhead
  • Performance: 4.03 ms → 1.95 ms. Warp active ~95%
  • Reflexion: This workload is fundamentally memory-bandwidth bound. Maximizing occupancy and parallelism is more important than tiling elegance. Architectural changes are sometimes required to escape local optima.

@triton.jit
def matvec_kernel(A_ptr, x_ptr, C_ptr, M, K, stride_am, stride_ak,
                  BLOCK_SIZE_K: tl.constexpr):
    pid_m = tl.program_id(0)
    if pid_m >= M:
        return

    # Scalar accumulator --- minimal register usage
    acc = 0.0
    a_row_ptr = A_ptr + pid_m * stride_am

    num_k_blocks = tl.cdiv(K, BLOCK_SIZE_K)
    for k_block in range(num_k_blocks):
        k_start = k_block * BLOCK_SIZE_K
        k_offsets = k_start + tl.arange(0, BLOCK_SIZE_K)
        k_mask = k_offsets < K

        x_vals = tl.load(x_ptr + k_offsets, mask=k_mask, other=0.0)
        a_vals = tl.load(a_row_ptr + k_offsets * stride_ak, mask=k_mask, other=0.0)

        prod = a_vals.to(tl.float32) * x_vals.to(tl.float32)
        block_sum = tl.sum(prod, axis=0)  # Scalar reduction
        acc += block_sum

    tl.store(C_ptr + pid_m, acc.to(tl.bfloat16))

# Launch: BLOCK_SIZE_K=1024, grid=(M,) = (2048,)
# No explicit num_warps or num_stages (defaults)

Memory throughput achieved

Figure 4: Memory throughput achieved (in GB/s and % of SOL) through KernelAgent’s Improvement
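As a back-of-the-envelope check on the runtimes reported in the case study, the effective bandwidth can be estimated from the problem shape alone. This sketch assumes each element of A and x is read exactly once and C is written once, all in BF16. Profiler-reported throughput, as plotted in Figure 4, can differ because of caching and redundant traffic.

```python
M, K = 2048, 1_048_576
BYTES_BF16 = 2


def effective_bandwidth_gbps(runtime_ms: float) -> float:
    """Minimum bytes moved divided by runtime, in GB/s (1 GB = 1e9 bytes)."""
    # Read A (M x K) and x (K) once, write C (M) once.
    bytes_moved = (M * K + K + M) * BYTES_BF16
    return bytes_moved / (runtime_ms * 1e-3) / 1e9


# Runtimes taken from the case study above.
for label, ms in [
    ("correctness-only kernel", 9.52),
    ("after first improvement", 6.80),
    ("after second improvement", 6.20),
    ("after third improvement", 4.03),
    ("final kernel", 1.95),
    ("torch.compile baseline", 2.09),
]:
    print(f"{label}: {effective_bandwidth_gbps(ms):.0f} GB/s")
```

The matrix A alone is about 4.3 GB, so the x and C terms are negligible and runtime is essentially set by how fast A can be streamed from DRAM.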

Lessons Learned

We would like to share what we learned while orchestrating multiple agents to tackle complicated kernel engineering questions.

Q: How do you keep agents on track without human oversight?

The key is to have hard, verifiable constraints. In KernelAgent, correctness and performance are enforced through gated evaluation. Every kernel variant must pass numerical verification, and performance is measured using real hardware benchmarks. Agents stay on track when progress is defined by executable, measurable outcomes.

Q: How do you structure work so that multiple agents can make progress in parallel while sharing working context, so that future rounds build on the shared memory?

Parallelism alone is not sufficient; without coordination, agents quickly duplicate work or explore redundant paths. Within each round, optimization workers operate independently and in parallel, exploring different optimization strategies. After the round completes, their outcomes—successful or not—are summarized into a shared, structured context that captures what was attempted, what worked, and why. This shared memory is then broadcast to all agents in subsequent rounds.

Q: How do you prevent the agents from getting stuck in a local minimum, and what signals tell you when to stop?

Avoiding local minima requires both diversity in exploration and clear termination criteria. KernelAgent maintains a beam of top-performing kernels rather than a single incumbent. Parallel exploration further reduces the risk that early suboptimal decisions dominate the search. 

In GPU optimization specifically, the search can get stuck in sequential parameter tuning. Optimization A alone may not help and optimization B alone may not help, yet combining them can produce a performance breakthrough. KernelAgent therefore aims to maximize the breadth of ideas that are proposed and explored.

KernelAgent monitors performance deltas and hardware utilization metrics. When successive rounds fail to produce meaningful improvements in roofline efficiency or runtime, the system concludes that further optimization is unlikely to yield returns.
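A simple version of such a stopping rule is sketched below. The function name, the 1% threshold, and the two-round patience window are assumptions chosen for illustration; the post does not state the actual criteria.

```python
def should_stop(best_runtimes_ms, min_rel_gain=0.01, patience=2):
    """Stop when each of the last `patience` rounds improved the best
    runtime by less than `min_rel_gain` (relative to the previous round)."""
    if len(best_runtimes_ms) <= patience:
        return False  # not enough history to judge
    recent = best_runtimes_ms[-(patience + 1):]
    gains = [(prev - cur) / prev for prev, cur in zip(recent, recent[1:])]
    return all(gain < min_rel_gain for gain in gains)


# Best runtime per step from the case study: still improving, so keep going.
print(should_stop([9.52, 6.80, 6.20, 4.03, 1.95]))
# A made-up history that has plateaued.
print(should_stop([9.52, 6.80, 6.79, 6.79]))
```

A production rule would likely also consider roofline efficiency, since a kernel already close to the hardware peak has little headroom left regardless of recent deltas.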

Conclusion

KernelAgent demonstrates that the deep agent principles from the previous correctness-focused loop, including grounded tool use, parallel exploration, and deterministic control, extend naturally to performance optimization. By adding hardware profiling and working memory to the loop, and by letting multiple agents learn from and explore different optimization paths, we can push verified kernels from “correct” to “correct and fast.”

Try it yourself. KernelAgent is an open-source project under active development. We welcome feedback, contributions, and new use cases from the community, and we hope this work helps advance practical, scalable kernel optimization within the PyTorch ecosystem.

Acknowledgements

We would also like to thank the following people for feedback: Paulius Micikevicius, Yang Wang, Lu Fang, Jie Liu, Zacharias Fisches, Alec Hammond, Richard Li, Chris Gottbrath, Davide Italiano, Joe Spisak, and John Myles White. 

 

]]>
FlexAttention + FlashAttention-4: Fast and Flexible https://pytorch.org/blog/flexattention-flashattention-4-fast-and-flexible/ Thu, 05 Mar 2026 17:55:43 +0000 https://pytorch.org/?p=46919

TL;DR:

On Hopper and Blackwell GPUs, FlexAttention now has a FlashAttention-4 backend.

We added support in PyTorch to automatically generate CuTeDSL score/mask modification functions, and to JIT-instantiate FlashAttention-4 for custom attention variants.

This leads to performance gains of 1.2× to 3.2× over the existing Triton implementation on compute-bound workloads.

FlexAttention recap

FlexAttention is a PyTorch API that lets you implement custom attention variants in a few lines of Python, no CUDA required. You write a score_mod or mask_mod function that modifies attention scores, and the compiler handles the rest: ALiBi, sliding window, document masking, soft-capping, and combinations of these all work through the same interface.

Under the hood, it’s two extensions over vanilla FlashAttention:

  1. Pointwise modifications to pre-softmax scores, with arbitrary loads from global memory.
  2. Block-sparse iteration for both forward and backward, with a simple data structure for encoding data-dependent sparsity at runtime.

That’s it. Of course, the devil is in the details, but as we’ve shown in the original FlexAttention post and FlexAttention for inference, these two extensions cover a wide range of popular attention variants.
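To make the first extension concrete, here is a pure-Python reference of what a score_mod does semantically for a single head: it is applied to each scaled pre-softmax score along with its indices. This is a sketch for intuition only, not how the fused kernel is implemented. The function `reference_attention` is hypothetical, and batch and head indices are fixed at 0.

```python
import math


def reference_attention(q, k, v, score_mod=None):
    """Single-head attention over lists of vectors, with an optional score_mod."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, q_vec in enumerate(q):
        scores = []
        for kv_idx, k_vec in enumerate(k):
            score = scale * sum(a * b for a, b in zip(q_vec, k_vec))
            if score_mod is not None:
                # Pointwise modification of the pre-softmax score.
                score = score_mod(score, 0, 0, q_idx, kv_idx)
            scores.append(score)
        # Numerically stable softmax over the modified scores.
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_vec[d] for w, v_vec in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def causal(score, b_idx, h_idx, q_idx, kv_idx):
    # Masking expressed as a score_mod: future positions get -inf.
    return score if kv_idx <= q_idx else float("-inf")


q = k = v = [[1.0, 0.0], [0.0, 1.0]]
print(reference_attention(q, k, v, score_mod=causal))
```

With the causal score_mod, the first query can only see the first key, so its output is exactly the first value vector.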

With this release, FlexAttention now has a FlashAttention-4 (FA4) backend. Here’s how to use it:

import torch
from functools import partial

from torch.nn.attention.flex_attention import flex_attention

flex_flash = torch.compile(
    partial(flex_attention, kernel_options={"BACKEND": "FLASH"}), dynamic=False
)

def local_boost(score, b_idx, h_idx, q_idx, kv_idx):
    return torch.where(torch.abs(q_idx - kv_idx) <= 8, score * 2, score)

B, H, S, D = 2, 8, 2048, 128
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
out = flex_flash(q, k, v, score_mod=local_boost)

Set kernel_options={"BACKEND": "FLASH"} to use the FA4 backend. You’ll need a recent PyTorch nightly and a recent FlashAttention checkout; check the install docs for version compatibility. This is actively developed code; expect some breaking changes as it stabilizes.

Democratizing attention research with FlexAttention

FlexAttention was originally designed (and named) to provide flexibility to AI researchers when prototyping and experimenting with new attention variants. In practice, this has proven true: dozens of papers cite FlexAttention, and over a thousand repos have adopted it.

While Flex has been successful at enabling researchers, a constant refrain from users is that they ultimately hit a performance ceiling that’s hard to break. At the time of the original blog post, we compared against FlashAttention-3 (FA3) on Hopper GPUs and were roughly at 80% of its performance.

If you measure this today, FlexAttention achieves roughly 60% of FlashAttention-3’s throughput, despite improvements to both implementations!

A common pattern emerges: researchers experiment with Flex, find something that works, then hit a wall when performance becomes critical. At that point, experts must port it to something lower-level. FlashAttention-3 has continued to add new args that extend functionality, but every new arg/pattern requires low-level rewrites. We’ve pushed the burden from researchers to ML engineers!

The delta in performance between fully optimized versions is perhaps worth the gain in flexibility on Hopper, but it’s a different story on newer Blackwell GPUs.

Let’s look at a comparison between FlexAttention using its existing Triton implementation (with all the autotune knobs maxed out) on a Blackwell GB200 GPU (running at 1000W), versus a highly optimized implementation like cuDNN attention available through PyTorch’s SDPA:

What was once a small gap has grown to a chasm!

Blackwell: bigger tensor cores, bigger problems

On Blackwell, high-performance attention requires a deeply pipelined, warp-specialized kernel. These techniques aren’t expressible in our Triton-based implementation. We’d like to direct you to the great explanation of the FlashAttention kernel in Reverse engineering FlashAttention-4, which details how the implementation leverages new hardware capabilities on Blackwell, plus updates to how softmax is computed. The name of the game, as always, is to keep the tensor cores busy; since they’ve only gotten faster, this requires heavy use of deep async pipelines.

Blackwell introduces Tensor Memory (TMEM), a programmer-managed scratchpad close to the tensor cores for intermediate results. More importantly, both data movement and matmuls are now fully asynchronous. A warp can kick off a matmul or load and immediately move on.

Warp specialization splits work into stages: some warps handle synchronous work like softmax (which needs registers), while others orchestrate the async pipeline by issuing loads and matmuls, and coordinating synchronization. Because the orchestration warps have low register pressure, more operations can stay in flight simultaneously, hiding latency.

The tensor cores got bigger and faster, but the special-function unit (SFU), which handles operations like exponentials, didn’t keep pace. For forward attention, this shifts the bottleneck: softmax’s exp() is now as expensive as the matrix multiplies. To keep the GPU fully saturated, you need to ping-pong between two tiles, overlapping one tile’s matrix multiplies with the other’s exponentiation. The timeline below shows how these phases alternate to hide latency.

Backward is even trickier. There isn’t enough TMEM to hold all the accumulators at once, so the kernel needs careful pipelining to overlap computation with data movement while all shared, register, and tensor memory is under heavy pressure.

This is the kind of low-level choreography that a general-purpose compiler can’t easily discover. As the Gluon intro puts it: “While the Triton compiler does a good job of generating efficient code for a wide range of kernels, it can be beaten by hand-tuned low-level code. When this happens, there is little the user can do to significantly improve performance since all the details are hidden.” This is even harder for FlexAttention, which is a meta-attention implementation, so hardcoding compiler optimizations for specific patterns is difficult when the patterns are user-defined. Because of this, we started looking at lower-level implementations to see how best to improve performance on Blackwell.

FlashAttention-4 as the foundation

There was a lot of flux in attention implementations for new Blackwell hardware. cuDNN added performant attention support early, but FA3 (the existing SOTA implementation on Hopper) did not work on Blackwell. WGMMA no longer exists on SM100: it has been replaced with TCGEN05 tensor core instructions, and tensor core ops require different memory spaces (tensor memory).

Tri Dao et al. started working on FlashAttention-4, an updated version of the implementation that could take full advantage of the hardware.

One major change from FA3 to FA4 is CuTeDSL, a Python DSL recently released by the NVIDIA CUTLASS team for writing high-performance CUDA kernels using CUTLASS abstractions. SOTA attention implementations make heavy use of CUTLASS abstractions like cute.Layouts, but anyone who has tried to install FlashAttention knows how painful it can be, with long compile times. So while the idea of rewriting Flex in CUTLASS C++ has come up before, the dynamic nature of FlexAttention (and the overhead of compilation) made the premise less attractive. CuTeDSL enables authoring in Python what used to require CUTLASS C++, which makes JIT-style workflows more practical for FlexAttention.

After some early discussion with Tri Dao about this path, we decided to join forces on this implementation for both FlexAttention and FlashAttention-4.

Rather than building a separate implementation, we collaborated to extend FA4 directly, sharing the same async pipeline infrastructure and adding extension points where FlexAttention needs to inject score modifications and sparsity.

This meant adding score-modification support to both forward and backward (by inlining score mods into the FlashAttention implementation) and adding support for the block-sparse metadata used in FlexAttention.

This work roughly split into two: changes to FA4 to produce FlexAttention template instantiations, and updates to Inductor to generate the required CuTeDSL code from its PyTorch representation.

Inductor → CuTeDSL: the glue layer

So what do we need to produce for FlexAttention? Pointwise modifications and arbitrary loads. Luckily, this isn’t the first time Inductor has done this, and there’s an existing mechanism for this kind of extension. For instance, let’s take a look at a score modification that can be used to implement ALiBi (fun fact: this was the motivating example for the FlexAttention project).

Roughly speaking, torch.compile takes user code and transforms it through several IRs. These transformations produce increasingly lower-level representations. In the FX IR, you can still see familiar PyTorch operators, along with variables being set to None immediately after use. The AOTAutograd pass auto-generates the backward: since d/dx of (X + A) equals 1, the chain rule sends the gradient straight through.

Notably, no part of this stack needs to “know” what CuTeDSL code is until we get to Inductor, which ultimately produces the kernel code that gets run.

The user code below is the starting point for ALiBi; from there it is lowered through the FX IR and AOTAutograd stages to the final CuTeDSL kernel.

def alibi_mod(score, b, h, q_idx, kv_idx):
    scale = torch.exp2(-((h + 1) * 8.0 / H))
    bias = (kv_idx - q_idx) * scale
    return score + bias
User Code: A score modification that implements ALiBi, the motivating example for the flex-attention project.
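
To make the semantics concrete: a score mod is conceptually evaluated once per (batch, head, query, key) element of the score matrix. The sketch below is a plain-Python reference, not the FlexAttention implementation. The value `H = 4`, the helper `apply_score_mod`, and the all-zero input are assumptions chosen for illustration.

```python
# Number of heads; an illustrative value that the post does not specify.
H = 4

def alibi_mod(score, b, h, q_idx, kv_idx):
    # Same formula as the user code, written with plain Python floats.
    scale = 2.0 ** (-((h + 1) * 8.0 / H))
    bias = (kv_idx - q_idx) * scale
    return score + bias

def apply_score_mod(scores, score_mod, b=0, h=0):
    """Call score_mod once per element of a [seq_q][seq_kv] score matrix."""
    return [
        [score_mod(s, b, h, q_idx, kv_idx) for kv_idx, s in enumerate(row)]
        for q_idx, row in enumerate(scores)
    ]

# All-zero scores make the added bias directly visible.
zeros = [[0.0] * 3 for _ in range(3)]
biased = apply_score_mod(zeros, alibi_mod, h=0)
for row in biased:
    print(row)
```

Each row shows a bias that grows linearly with the distance between the key and query positions, scaled by a per-head slope.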

Inductor lowers pointwise IR as a define-by-run function that calls V.ops.<op>, then swaps in a handler that reinterprets those calls for a target backend. In practice, this shows up as ops_wrapper(...) and OpsWrapper, which let you map unary and binary primitives to a new language without changing the IR itself. For CuTeDSL, we plug in a CuTeDSL handler that rewrites those ops into TensorSSA expressions, so arithmetic is performed on register-backed (RMEM) CuTe tensors and expressions can be CSE’d.
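
The handler-swapping idea can be sketched in plain Python. This is a toy model, not Inductor's API: `EvalHandler`, `CodegenHandler`, and the `ops` parameter are hypothetical stand-ins for `V.ops` and its backend handlers, and the emitted text is not real CuTeDSL.

```python
class EvalHandler:
    """Interprets each op eagerly on Python numbers."""

    def constant(self, value):
        return value

    def add(self, a, b):
        return a + b

    def sub(self, a, b):
        return a - b

    def mul(self, a, b):
        return a * b


class CodegenHandler:
    """Reinterprets the same ops as source lines, with a simple CSE cache."""

    def __init__(self):
        self.lines = []
        self.cache = {}

    def _emit(self, expr):
        # Reuse the temporary if this exact expression was already emitted.
        if expr not in self.cache:
            name = f"tmp{len(self.lines)}"
            self.lines.append(f"{name} = {expr}")
            self.cache[expr] = name
        return self.cache[expr]

    def constant(self, value):
        return repr(value)

    def add(self, a, b):
        return self._emit(f"{a} + {b}")

    def sub(self, a, b):
        return self._emit(f"{a} - {b}")

    def mul(self, a, b):
        return self._emit(f"{a} * {b}")


def alibi_body(ops, score, q_idx, kv_idx, scale):
    # Define-by-run: the body only talks to `ops`, never to a concrete backend.
    bias = ops.mul(ops.sub(kv_idx, q_idx), ops.constant(scale))
    return ops.add(score, bias)


value = alibi_body(EvalHandler(), 1.0, 2, 5, 0.25)

gen = CodegenHandler()
result_name = alibi_body(gen, "score", "q_idx", "kv_idx", 0.25)
print(value)
print("\n".join(gen.lines))
print(f"return {result_name}")
```

The same `alibi_body` function produces a number under one handler and a list of source lines under the other, which is the property that lets a new backend be added without touching the IR.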

We also add a specialized load path for “arbitrary loads.” If a user writes a score/mask mod that relies on some global tensor, we materialize an RMEM fragment, and emit a load at the (possibly indirect) index. This lets us bridge from Inductor’s index expressions to CuTeDSL’s TensorSSA.

Flexifying FlashAttention-4

We made two orthogonal extensions to FA4 so it can serve as FlexAttention’s backend:

  1. score modification in forward and backward
  2. block-sparse iteration in forward and backward

Both extensions are implemented in CuTeDSL so they can be inlined into the same async pipeline that makes FA4 fast.

Think of FlashAttention as processing a queue of KV tiles per SM. Flexification adds two hooks: block sparsity controls which tiles enter the queue (skipping empty blocks, marking partial blocks), while score/mask mods are applied as pointwise operations in the softmax warp.

With that split in mind, here’s how the forward and backward hooks fit in.

Score modification

One aspect of CuTeDSL that made this project feasible is the ability to not only pass a variadic number of kernel arguments to an implementation (which then lowers into a specific instantiation), but also to pass user callables directly. If you return to the core of FlexAttention, we need the ability to inject user modifications at precise points in the FlashAttention algorithm. We build on the existing FA4 implementation, which was already written to be score-mod friendly.

In forward, we bring the S tile back from TMEM into registers so we can apply the mod, compute row-wise max/sum, and generate the P tile for the second matmul. We define a CuTeDSL interface that mirrors FlexAttention’s score_mod signature and, instead of threading N variadic captures through the kernel, pass a list of aux_tensors that represent any global-memory regions used by the mod. Inside the kernel, we reinterpret register fragments as TensorSSA views (with optional vectorization) and inline the user callable on those tiles.

We already need S in registers to compute max/sum and form the P tile, so we apply score/mask mods while the data is resident in RMEM instead of adding a separate phase. That keeps the same pipeline structure and overlap between TCGEN and SFU work. Any extra reads from aux_tensors are issued directly when needed and scheduled alongside the existing stage that consumes S.
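
A rough way to see why the mod fits into the existing phase: in a tiled (online-softmax) attention loop, the score tile is already in hand right before the running max and sum are updated. The sketch below is plain Python for a single query row, under simplifying assumptions: the three-argument `score_mod(score, q_idx, kv_idx)` signature, the `tile_n` value, and the sample data are all hypothetical, scores are assumed finite, and nothing here models TMEM, warps, or async pipelines.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def dense_row(q_row, keys, values, score_mod, q_idx):
    """Reference: materialize every score, then softmax and weight V."""
    scores = [score_mod(dot(q_row, k), q_idx, j) for j, k in enumerate(keys)]
    top = max(scores)
    probs = [math.exp(s - top) for s in scores]
    total = sum(probs)
    return [dot(probs, [v[d] for v in values]) / total
            for d in range(len(values[0]))]

def tiled_row(q_row, keys, values, score_mod, q_idx, tile_n=2):
    """One query row processed tile by tile with an online softmax."""
    running_max, running_sum = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), tile_n):
        k_tile = keys[start:start + tile_n]
        v_tile = values[start:start + tile_n]
        # First matmul: the S tile for this block of keys.
        s_tile = [dot(q_row, k) for k in k_tile]
        # Apply the mod while S is "resident", before max/sum are computed.
        s_tile = [score_mod(s, q_idx, start + j) for j, s in enumerate(s_tile)]
        new_max = max(running_max, max(s_tile))
        rescale = math.exp(running_max - new_max)  # 0.0 on the first tile
        p_tile = [math.exp(s - new_max) for s in s_tile]
        running_sum = running_sum * rescale + sum(p_tile)
        # Second matmul: accumulate P @ V, rescaling the old accumulator.
        acc = [a * rescale + dot(p_tile, [v[d] for v in v_tile])
               for d, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

def relative_bias(score, q_idx, kv_idx):
    return score + 0.25 * (kv_idx - q_idx)

q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(tiled_row(q, K, V, relative_bias, q_idx=2))
print(dense_row(q, K, V, relative_bias, q_idx=2))
```

The tiled and dense results agree up to floating-point rounding, and the mod costs one extra pointwise pass over a tile that was going to be read anyway.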

Backward follows the same interface shape with a generated score_mod_bwd callable, but the liveness story is different. In standard FA4, the S and dS tiles never need to be live at the same time, so TMEM can be shared across phases. With score mods, the backward path depends on what the user’s derivative needs.

If the gradient only depends on P (or the incoming gradient), we preserve the default schedule and still avoid S/dS overlap in TMEM. If the derivative depends on pre-softmax scores, we keep the needed S fragments in registers alongside P or dS and drop them as soon as their contribution is consumed. TMEM stays reserved for the main accumulators, and the cost is higher register pressure for those specific mods.
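
The two cases can be illustrated with a pair of toy score mods. This is a hedged sketch: the one-argument mods and the hand-written `*_bwd` functions are hypothetical simplifications of the generated `score_mod_bwd` callable, checked here against a finite-difference estimate.

```python
def bias_mod(score):
    return score + 0.5

def bias_mod_bwd(score, grad_out):
    # d(score + c)/d(score) = 1: the pre-softmax score is never read.
    return grad_out

def square_mod(score):
    return score ** 2

def square_mod_bwd(score, grad_out):
    # d(score**2)/d(score) = 2 * score: the pre-softmax score must stay live.
    return 2.0 * score * grad_out

def finite_difference(fn, x, eps=1e-6):
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

for s in (-1.5, 0.0, 2.0):
    grad_out = 3.0
    print(s, bias_mod_bwd(s, grad_out), square_mod_bwd(s, grad_out))
```

Only `square_mod_bwd` reads `score`, which is the property that forces the extra S fragments to stay in registers during backward.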

Block-sparse iteration (forward + backward)

The second change FlexAttention requires from FlashAttention is block-sparse iteration. We extend FA4's kernels to accept block-mask metadata (the row/column tiles to visit) and drive the tile scheduler from it, so the kernel only touches the (m, n) tiles present in the mask. We also made the block-sparse path work with GQA packing and broadcasted head dimensions.
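
As a rough illustration of what the block-mask metadata encodes, the sketch below classifies each tile of a small causal mask as full, partial, or empty. It is a brute-force toy: `classify_blocks`, the two-argument `mask_mod(q_idx, kv_idx)` signature, and the 4×4 block size are assumptions for readability, not the FlexAttention data structure or the block sizes used on real hardware.

```python
def classify_blocks(mask_mod, seq_q, seq_kv, block_m, block_n):
    """Map each non-empty (m, n) tile to 'full' or 'partial'."""
    blocks = {}
    for m, q_start in enumerate(range(0, seq_q, block_m)):
        for n, kv_start in enumerate(range(0, seq_kv, block_n)):
            cells = [
                mask_mod(q_idx, kv_idx)
                for q_idx in range(q_start, min(q_start + block_m, seq_q))
                for kv_idx in range(kv_start, min(kv_start + block_n, seq_kv))
            ]
            if all(cells):
                blocks[(m, n)] = "full"      # no per-element masking needed
            elif any(cells):
                blocks[(m, n)] = "partial"   # visit, but apply mask_mod inside
            # Tiles with no visible element are skipped entirely.
    return blocks

def causal(q_idx, kv_idx):
    return kv_idx <= q_idx

block_mask = classify_blocks(causal, seq_q=8, seq_kv=8, block_m=4, block_n=4)
print(block_mask)
```

For the causal case the upper-right tile never appears in the result, so a scheduler driven by this metadata would never visit it.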

One consequence of the two-tile ping-pong from earlier: the minimum sparse block size on Blackwell is 256×128, up from 128×128 on the Triton path. Because each CTA processes two M-tiles to keep the pipeline full (q_stage=2), the smallest unit of work the scheduler can skip is 256 rows, so the block-mask granularity has to match.

Backward walks the same block-mask, computing gradients only for tiles present in the forward pass. The backward kernel already uses subtile iteration along rows, so the 256-row constraint fits naturally.

What we contributed

These extensions required upstream updates across the FA4 stack:

  • Score-mod hooks in forward and backward, including SM90/SM100 correctness fixes and GQA edge cases
  • Block-sparse forward and backward paths for Blackwell and Hopper, plus pack-GQA support for broadcasted mask mods
  • Interface cleanups for contig layouts and expanded tensors in the score/mask-mod path
  • CuTeDSL bumps and TVM-FFI enablement to cut CPU dispatch overhead

With those pieces in place, let's look at the performance.

Results

Patterns supported by SDPA

For standard attention patterns like dense (noop) and causal masking, we can compare FlexAttention's new Flash backend against both the existing Triton implementation and cuDNN.

On GB200, the Flash backend achieves 1.6–3.2× speedup over Triton for forward passes and 1.85–2.3× for backward. For backward passes, Flash matches or even beats cuDNN in some cases; forward has a larger gap to cuDNN, particularly for causal attention.

You'll notice that for the forward pass, Noop matches cuDNN closely while Causal lags further behind. This gap highlights how much overhead block-sparse iteration adds compared to FA4's builtin Causal path.

Why causal lags (and how to close the gap)

After some investigation, one of the culprits is work scheduling: if you read through the FA3 code, you'll see the usage of longest processing time first (LPT) scheduling, which FA4 implements for builtin causal but FlexAttention doesn't use. If we manually specify LPT scheduling, the performance looks like this:

With LPT scheduling manually specified, forward sees up to 1.6× speedup at shorter sequences, tapering to ~1.1× at longer sequences. Backward sees minimal difference since the scheduling overhead is amortized differently. We still don't fully match the performance, but we are getting closer.

The LPT schedule works here since we know the specific sparsity pattern is causal and the schedule is optimal for this case. In general, we don't know the pattern ahead of time: block sparsity can be data-dependent, with different rows having different numbers of active KV blocks.
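
To see why ordering matters for causal work, here is a minimal list-scheduling model. It is a generic sketch of longest-processing-time-first, not FA4's scheduler: the cost model (query tile `i` costs `i + 1` KV blocks), the worker count, and the `makespan` helper are assumptions.

```python
import heapq

def makespan(costs, num_workers):
    """Assign each tile, in the given order, to the least-loaded worker."""
    heap = [(0, w) for w in range(num_workers)]
    heapq.heapify(heap)
    for cost in costs:
        load, worker = heapq.heappop(heap)
        heapq.heappush(heap, (load + cost, worker))
    return max(load for load, _ in heap)

# Causal attention: query tile i attends to i + 1 KV blocks.
causal_costs = [i + 1 for i in range(8)]

in_order = makespan(causal_costs, num_workers=3)
lpt = makespan(sorted(causal_costs, reverse=True), num_workers=3)
print(in_order, lpt)
```

In this toy setup, handing out the most expensive tiles first finishes sooner than handing them out in index order, because the cheap tiles are left to fill the gaps at the end.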

We could rely on CUDA to launch individual output tiles in a load-balanced manner, but then we would miss persistent scheduling gains from overlapping MMA & loads with epilogues and not repeating prologues. This is exactly the problem that Cluster Launch Control (CLC) solves. CLC is a Blackwell feature that enables dynamic work scheduling: instead of statically dividing tiles across SMs at launch time, workers can query for new tiles on-the-fly. When one SM finishes early (because its row had fewer blocks to process), it immediately picks up the next available tile rather than sitting idle. CuTeDSL 4.4 added support for CLC-based persistent scheduling, which lets FlexAttention transparently benefit from better work distribution for block-sparse patterns without requiring users to specify a schedule.
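
The difference between static partitioning and on-the-fly tile pickup can be modeled in a few lines. This is a simulation under stated assumptions, not CLC itself: the per-row block counts are hypothetical, and `static_makespan` and `dynamic_makespan` are illustrative helpers that ignore prologue and epilogue costs.

```python
def static_makespan(costs, num_workers):
    """Split tiles into contiguous chunks, fixed at launch time."""
    chunk = -(-len(costs) // num_workers)  # ceiling division
    return max(sum(costs[i:i + chunk]) for i in range(0, len(costs), chunk))

def dynamic_makespan(costs, num_workers):
    """Whichever worker becomes free first pulls the next tile."""
    free_at = [0] * num_workers
    for cost in costs:
        worker = free_at.index(min(free_at))
        free_at[worker] += cost
    return max(free_at)

# Hypothetical per-row-tile work: two dense rows followed by six sparse ones.
kv_blocks_per_row = [9, 8, 1, 1, 1, 1, 1, 1]
print(static_makespan(kv_blocks_per_row, 2), dynamic_makespan(kv_blocks_per_row, 2))
```

With a data-dependent pattern like this one, the static split leaves one worker idle for most of the run, while the dynamic version keeps both busy without knowing the pattern ahead of time.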

Patterns supported by FlexAttention

FlexAttention is really for patterns that SDPA doesn't support: ALiBi, document masking, sliding window and arbitrary user-defined score modifications.

For these Flex-only patterns on B200:

  • ALiBi: 1.2–2.1× forward speedup, 1.9–2.9× backward speedup
  • Document mask: Up to 2.7× forward, 3× backward at longer sequences
  • Sliding window: 1.4–2.1× forward, 1.8–2.2× backward

Hopper (H200) results

On Hopper GPUs, Flash is consistently faster across all sequence lengths.

For these Flex-only patterns on H200:

  • ALiBi: 1.30–1.54× forward speedup, 1.36–1.65× backward speedup
  • Document mask: 1.41–1.89× forward, 1.48–2.01× backward
  • Sliding window: 1.45–1.65× forward, 1.35–1.52× backward

The gains are present even at shorter sequences (2K), with larger speedups as sequence length increases.

Correctness and benchmark methodology

All benchmark numbers in this post were generated with attention-gym/benchmarks/flex_perf.py.

Correctness

We validate the Flash backend by comparing outputs against an FP32 reference (cast Q/K/V to FP32, run attention, cast back). Upstream test suites continuously exercise these checks:

  • PyTorch Inductor: a broad matrix of score_mod / mask_mod patterns (including captured buffers and views) and Flash-vs-Triton comparisons in test/inductor/test_flex_flash.py.
  • FlashAttention (CuTe): stress tests for mask_mod + block-sparsity across many (seqlen_q, seqlen_k) pairs, validating forward and backward against a flex_attention reference in tests/cute/test_mask_mod.py.

Beyond unit tests, we also validated the Flash backend in a real training setting: Llama 3 70B on 64 H100 GPUs with sequence length 8192 using torchtitan. Both runs converge to a final loss of ~3.7 over 1000 training steps:

Limitations

Block size constraints: For paged attention (like the vLLM integration), it's common to align kernel blocks to the page size. Today, the FA4 path is tuned around 128×128 blocks on Hopper and 256×128 on Blackwell (due to q_stage=2), with limited flexibility to change block sizes. As FA4 exposes more robust smaller tile_m/tile_n options, we plan to enable this feature.

Dynamic scalars: Dynamic tensor shapes are fully supported and resolve at runtime. However, scalars captured in score_mod or mask_mod are baked into the compiled kernel. If you have a soft_cap value that changes between calls, each unique value triggers a recompilation:

def tanh_softcap(score, b, h, q_idx, kv_idx):
    # soft_cap is captured from the enclosing scope, so its value is baked into the kernel
    return soft_cap * torch.tanh(score / soft_cap)
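
A toy model of the consequence: if the captured scalar is part of the compiled kernel's identity, every new value is a cache miss. The sketch below uses a hypothetical `ToyKernelCache` and `make_tanh_softcap` factory to count those misses; it does not reflect how torch.compile actually keys its caches.

```python
import math

class ToyKernelCache:
    """Toy model: a captured scalar is part of the compiled kernel's identity."""

    def __init__(self):
        self.kernels = {}
        self.compiles = 0

    def get(self, make_score_mod, captured_scalar):
        key = (make_score_mod.__name__, captured_scalar)
        if key not in self.kernels:
            self.compiles += 1  # stands in for an expensive recompilation
            self.kernels[key] = make_score_mod(captured_scalar)
        return self.kernels[key]

def make_tanh_softcap(soft_cap):
    def tanh_softcap(score, b, h, q_idx, kv_idx):
        return soft_cap * math.tanh(score / soft_cap)
    return tanh_softcap

cache = ToyKernelCache()
for soft_cap in (30.0, 30.0, 50.0, 30.0):
    kernel = cache.get(make_tanh_softcap, soft_cap)
print(cache.compiles)
```

Repeated calls with the same value reuse the cached entry; only distinct values pay the compile cost, which is why sweeping a soft cap over many values is expensive today.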

Backward for captured buffers that require gradients: Not currently supported in the Flash backend. For example, learnable bias tensors:

bias = torch.randn(seq_q, seq_kv, device='cuda', requires_grad=True)
def bias_func(score, b, h, q_idx, kv_idx):
    return score + bias[q_idx, kv_idx]

The Triton backend supports gradients for captured buffers; use it for these cases.

Deterministic backward with block-sparsity: The Flash backend's backward pass is not yet deterministic when block-sparsity is enabled (score-mod-only workloads are deterministic). We're actively working on a fix for this.

Performance limitations:

  • Loads on the KV dimension in forward can stall the pipeline, especially for pointer-chasing patterns (e.g., document masking with per-token metadata) where aux-tensor loads are hard to overlap with compute.
  • Backward with score_mods requiring pre-softmax scores almost always spills registers with current tiling. For example, the gradient of score**2 is 2 * score * grad_score, which requires keeping pre-softmax scores live during the backward pass. TMEM is fully occupied by the main attention accumulators, and current block sizes rarely leave room in SMEM for the S tile, so it stays in registers and spills heavily, causing a noticeable slowdown.

Future work

We're excited about CuTeDSL and the FA4 integration closing the gap between research and production.

On the Flash backend specifically, we're working on support for dynamic scalars captured in score mods without requiring recompilation (e.g., changing a soft_cap value between calls). Gradients for captured buffers will continue to rely on the Triton backend for the foreseeable future. We're also exploring dynamic persistent scheduling to improve work distribution across block-sparse patterns automatically.

While this article is about the FA4 implementation, the Triton implementation remains supported on a much wider range of hardware, and we plan to continue improving both backends.

Thanks

This was a cross-repo collaboration.

The FlashAttention-4 kernel work (CuTeDSL implementation, scheduling, and the extension points needed for score/mask mods and block sparsity) lives upstream in Dao-AILab/flash-attention, while the compiler + integration work (FlexAttention API behavior, Inductor lowering, and CuTeDSL codegen) lives upstream in pytorch/pytorch.

Thanks to the maintainers, reviewers, and contributors in both repos, and to the NVIDIA CUTLASS/CuTeDSL team for building the abstractions that made a JIT-style workflow practical.

  • FlashAttention / FA4 (kernel + extension points): Tri Dao, Ted Zadouri, Reuben Stern, Markus Hoehnerbach, Jay Shah
  • PyTorch / Inductor (lowering + codegen + integration): Markus Hoehnerbach
  • CuTeDSL / CUTLASS: Fung Xie

]]>
Deploying PyTorch Models to the Micro-Edge with ExecuTorch and Arm

https://pytorch.org/blog/deploying-pytorch-models-to-the-micro-edge-with-executorch-and-arm/ Thu, 05 Mar 2026 15:55:23 +0000

The world of AI is expanding beyond the cloud, reaching devices that fit in the palm of your hand. Running PyTorch models on these tiny systems, where memory is measured in kilobytes, requires a new way of thinking. That’s where ExecuTorch, the lightweight runtime for edge inference, bridges the gap between familiar PyTorch workflows and low-power Arm-based microcontrollers, using optimizations such as quantization and graph compilation to make models efficient enough for the edge.

I recently built a Tiny Rock-Paper-Scissors (RPS) demo using PyTorch and ExecuTorch on the Arm Corstone-320 platform. The goal: take a small Convolutional Neural Network (CNN) trained in PyTorch and deploy it all the way to a simulated Arm microcontroller with an Arm Ethos-U NPU (via the Arm Fixed Virtual Platform (FVP)). Here’s what that journey looks like, and why it matters for anyone building at the edge.

Why PyTorch at the Edge?

PyTorch makes model experimentation fast and intuitive, but moving from the flexibility of dynamic graphs to the rigid constraints of embedded hardware isn’t trivial. Most microcontrollers have less than 1 MB of RAM and no operating system, so traditional Python inference is off the table.

ExecuTorch solves this by compiling PyTorch models into a compact, portable format (`.pte`) that runs on devices with minimal compute, power, and memory. During this process, weights and activations are quantized from floating-point to lower-precision integer formats (typically int8), dramatically reducing both memory footprint and compute costs while maintaining model accuracy. The computation graph is also flattened, fused, and optimized, removing redundant operations and enabling smooth execution at the edge. It extends the PyTorch ecosystem all the way down to the smallest Arm Cortex-M and Ethos-U-based systems.
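
To give a feel for the float-to-int8 step, here is a generic affine quantization sketch in plain Python. It is not ExecuTorch's quantizer: the helper names, the per-tensor min/max calibration, and the sample weights are assumptions for illustration.

```python
def quantize_int8(values):
    """Affine (asymmetric) quantization of floats to the int8 range."""
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / 255.0 or 1.0   # avoid a zero scale for all-zero input
    zero_point = round(-128 - lo / scale)
    q = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize_int8(q, scale, zero_point):
    return [(x - zero_point) * scale for x in q]

weights = [-1.0, -0.25, 0.0, 0.5, 2.0]
q, scale, zero_point = quantize_int8(weights)
restored = dequantize_int8(q, scale, zero_point)
print(q, scale, zero_point)
print(restored)
```

Each value is stored in one byte instead of four, and the round trip stays within about one quantization step of the original.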

From PyTorch to the Micro-Edge

The great news is, I have built a detailed learning path to guide you through an end-to-end TinyML EdgeAI pipeline.

The Tiny RPS Game

The course’s centerpiece is the Tiny RPS game. It’s a fun and approachable way to learn about TinyML, while showing that PyTorch workflows can scale down just as easily as they scale up. It is a minimal but complete AI workflow which:

  • Generates its own dataset.
  • Trains a CNN in PyTorch.
  • Exports it via ExecuTorch.
  • Deploys it to the FVP, no need for physical hardware.
  • All you need is an x86 Linux host machine or VM running Ubuntu 22.04 or later.

The Pipeline

  1. Model Training in PyTorch

We define and train a compact CNN to classify synthetic images of “rock,” “paper,” and “scissors.” Each class is rendered as a noisy 28×28 grayscale image of its first letter (“R”, “P”, or “S”) to simulate data variation. (See Learning Path for detailed script)

```python

import torch
import torch.nn as nn

class TinyRPS(nn.Module):
    """
    Simple ConvNet:
    [B,1,28,28] -> Conv3x3(16) -> ReLU -> Conv3x3(32) -> ReLU
      -> MaxPool2d(2) -> Conv3x3(64) -> ReLU -> MaxPool2d(2)
      -> flatten -> Linear(128) -> ReLU -> Linear(3)
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        x = self.body(x)
        x = self.head(x)
        return x
```

This architecture is compact and Ethos-friendly, ideal for deployment to the micro-edge. Training uses Adam with a small synthetic dataset and achieves over 95% validation accuracy after a few epochs.
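
As a quick sanity check on the architecture, the shape and parameter arithmetic can be worked out without PyTorch. The helper functions below are illustrative, and the one-byte-per-parameter estimate assumes every weight and bias is stored as int8, which is a simplification.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(c_in, c_out, kernel=3):
    return c_in * c_out * kernel * kernel + c_out  # weights + biases

def linear_params(n_in, n_out):
    return n_in * n_out + n_out

size = 28
size = conv_out(size)      # Conv2d(1, 16): padding=1 keeps 28x28
size = conv_out(size)      # Conv2d(16, 32)
size //= 2                 # MaxPool2d(2) -> 14x14
size = conv_out(size)      # Conv2d(32, 64)
size //= 2                 # MaxPool2d(2) -> 7x7
flat_features = 64 * size * size

total_params = (
    conv_params(1, 16) + conv_params(16, 32) + conv_params(32, 64)
    + linear_params(flat_features, 128) + linear_params(128, 3)
)
print(flat_features, total_params)
# Assuming one byte per parameter after int8 quantization:
print(f"{total_params / 1024:.0f} KiB of weights")
```

The flattened size matches the `64 * 7 * 7` input of the first linear layer, and almost all of the parameters sit in that layer.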

  2. Exporting to ExecuTorch

Once trained, the model is exported to an ExecuTorch `.pte` program. This format is optimized for execution without Python on devices running tiny embedded runtimes. (See Learning Path for detailed script)

```python

import torch
import torch.nn as nn
from executorch import exir
from torch.export import export

def export_to_pte(model: nn.Module, out_path: str, img_size: int) -> None:
    model.eval()
    example = torch.zeros(
        1, 1, img_size, img_size,
        dtype=torch.float32
    )

    # Export with PyTorch’s exporter
    exported = export(model, (example,))
    edge = exir.to_edge(exported)
    prog = edge.to_executorch()

    with open(out_path, "wb") as f:
        f.write(prog.buffer)

    print(f"[export] wrote {out_path}")
```

This step effectively converts your PyTorch computation graph into a static, memory-efficient graph that can run on microcontrollers with minimal overhead.

  3. Deployment on Arm Corstone-320 FVP

The `.pte` file is deployed on the Arm Corstone-320 FVP, a software simulation of a Cortex-M CPU paired with an Ethos-U microNPU. This allows developers to run and validate their model locally before flashing it to real hardware. The RPS game lets you play interactively in the terminal, demonstrating real-time on-device inference.

Lessons Learned

Working on this demo revealed that PyTorch’s flexibility doesn’t have to stop at the data center.  ExecuTorch makes it possible to bring the same familiar PyTorch workflow to IoT sensors, wearables, and embedded devices, enabling privacy-preserving, low-power AI anywhere.

Edge AI may be small in size, but it’s huge in potential.

Try It Yourself

Learning Path: Edge AI with PyTorch & ExecuTorch – Tiny RPS on Arm
Target Audience: ML developers and embedded engineers with basic PyTorch experience.
Prerequisite: Introduction to TinyML on Arm

Acknowledgements

This learning path was a collaborative effort, and I owe a special thanks to the team that helped bring this course to life, including the valuable contributions of Annie Tallund, Zingo Andersen, George Gekov, Gemma Paris, Adrian Lundell, Madeline Underwood, Mary Bennion, and Fredrik Knutsson.

Explore other help for Edge AI development on low-power, resource-constrained devices using Arm Ethos-U NPUs   

]]>
Quantization-Aware Training in TorchAO (II)

https://pytorch.org/blog/quantization-aware-training-in-torchao-ii/ Wed, 04 Mar 2026 17:10:12 +0000

In our previous Quantization-Aware Training (QAT) blog, we introduced the initial QAT flow in TorchAO for large language models targeting edge devices with ExecuTorch. Since then, we have extended this flow to also target fast CUDA kernels, like the ones in MSLK, for fast inference in vLLM, and incorporated it into popular fine-tuning frameworks like Unsloth and Axolotl. We also explored more advanced QAT techniques like PARQ for lower-bit quantization (prototype):

  • Unsloth integration: Recover up to 66.9% of the accuracy degradation with INT4 QAT and achieve 1.73x inference speedup compared to BF16. Also check out our notebooks and HuggingFace checkpoints for an end-to-end guide.
  • Axolotl integration: Recover up to 71.6% of the accuracy degradation with NVFP4 QAT (prototype) and achieve 1.35x inference speedup with 1/4th of the HBM usage compared to BF16 on B200 GPUs.
  • PARQ: (prototype) An alternate optimizer-based QAT technique for lower-bit quantization. Achieve on-par accuracy with a 3-bit per-row model compared to a 4-bit per-group baseline, while using only ~58% of the memory footprint and decoding at ~1.57x the throughput.

Try out our latest QAT flow with a few lines of code!

from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.qat import QATConfig

# The same config to use for Post-Training Quantization (PTQ)
base_config = Int4WeightOnlyConfig(group_size=32)


# Prepare step: model is now "fake quantized" and ready for training
quantize_(model, QATConfig(base_config, step="prepare"))

train(model)

# Convert step: model is now quantized and ready for inference
quantize_(model, QATConfig(base_config, step="convert"))

Quantization-Aware Training

One well-known technique to mitigate the accuracy degradation from post-training quantization (PTQ) is QAT, which is an optional fine-tuning step that adapts the model weights towards a representation that is more “aware” that they will be quantized eventually. QAT achieves this by “fake quantizing” weights and optionally activations during training, which means mimicking PTQ numerics as closely as possible and then immediately dequantizing back to high precision values during the forward pass, while leaving the backward pass unchanged.
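
The quantize-then-dequantize round trip in the forward pass can be sketched in plain Python. This is a generic symmetric per-group INT4 scheme chosen for illustration, not TorchAO's exact numerics; the function name, group size, and sample weights are assumptions, and the backward pass (which treats the rounding as identity) is not modeled.

```python
def fake_quantize_int4(weights, group_size):
    """Quantize each group to symmetric INT4, then dequantize immediately."""
    out = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / 7.0 if max_abs > 0 else 1.0
        for w in group:
            q = max(-8, min(7, round(w / scale)))  # the integer PTQ would store
            out.append(q * scale)                  # back to high precision
    return out

weights = [0.10, -0.32, 0.70, 0.05, 2.00, -1.10, 0.40, 0.90]
fake = fake_quantize_int4(weights, group_size=4)
for w, f in zip(weights, fake):
    print(f"{w:+.2f} -> {f:+.4f}")
```

The output values are still floats, but each one sits on the grid that the quantized model will be restricted to, so training sees the rounding error it will face at inference time.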

QAT can also be combined with Low-Rank Adaptation (LoRA) to get the best of both worlds: significantly reducing storage and compute requirements during training while mitigating quantization degradation. LoRA is a popular fine-tuning technique that reduces the number of trainable parameters significantly by freezing the original model weights and instead training only a new set of LoRA “adapters” that are a small fraction of the model in terms of size. During training, fake quantization is applied to both the LoRA adapters (A and B in Figure 1) and the frozen weights (and optionally the input activations), since the adapters will eventually be merged back into the weights in the quantized model. This technique has been shown to speed up QAT by 1.89x while reducing the memory required by 36.1%.

Figure 1: Combining QAT + LoRA allows users to mitigate quantization degradation while speeding up training and reducing memory footprint during training. Fake quantization is applied dynamically to the frozen weights, the LoRA adapters, and (optionally) the input activations.

Note that QAT + LoRA differs from QLoRA: weights (and optionally activations) are fake quantized during training, rather than being quantized to real lower-bit dtypes (e.g. NF4) and stored in that representation before training. In general, the quantization scheme simulated during QAT should match the actual post-training quantization scheme as closely as possible.

As of TorchAO 0.16.0, we support the following dtype combinations:

| Weight dtype | Input activation dtype | TorchAO config |
| --- | --- | --- |
| INT4 | FP32, BF16 | Int4WeightOnlyConfig |
| INT4 | FP8 | Float8DynamicActivationInt4WeightConfig |
| INT4 | INT8 | Int8DynamicActivationIntxWeightConfig (for edge) |
| NVFP4 | NVFP4 | NVFP4DynamicActivationNVFP4WeightConfig (prototype) |

Integration with Unsloth

TorchAO’s QAT support is integrated into Unsloth’s fine-tuning workflows for both full fine-tuning (training all model parameters) and LoRA fine-tuning (training only the adapters). In our initial experiments, INT4 weight-only QAT recovered 66.9% of the accuracy degradation for Gemma3-4B on GPQA and 45.5% for Gemma3-12B on BBH when compared to the quantized baseline without QAT (Figure 2). This translates to raw accuracy improvements of 1.0% and 2.1%, respectively, just by applying fake quantization during fine-tuning (Figure 3). The resulting QAT-quantized model keeps the same model structure, so it can be used as a drop-in replacement for the non-QAT quantized model, with better accuracy.
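
The relationship between "degradation recovered" and "raw improvement" is simple arithmetic, sketched below. The helper `qat_recovery` and the accuracy values are hypothetical and are not results from this post; the formula is our reading of how the two kinds of numbers relate.

```python
def qat_recovery(baseline_acc, ptq_acc, qat_acc):
    """Return (raw improvement, fraction of the quantization loss recovered)."""
    lost = baseline_acc - ptq_acc
    gained = qat_acc - ptq_acc
    return gained, gained / lost

# Hypothetical accuracies (in percent), not results from the post.
gained, recovered = qat_recovery(baseline_acc=60.0, ptq_acc=56.0, qat_acc=58.5)
print(f"+{gained:.1f} points raw, {recovered:.1%} of the degradation recovered")
```

A small raw gain can therefore correspond to a large recovery percentage when the original quantization loss was itself small.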

Figure 2: Unsloth leverages QAT in TorchAO to recover accuracy degradation by up to 66.9% using INT4 QAT + LoRA (source). QAT recovered 45.5% of the 4.5% accuracy lost on Gemma3-12B BBH, 66.9% of the 1.5% accuracy lost on Gemma3-4B GPQA, 36.3% of the 1.11 word perplexity increase for Qwen3-4B WikiText, and 36.0% of the 4.5% accuracy lost on Llama-3.2-1B MMLU Pro.

Figure 3: Unsloth leverages QAT in TorchAO to boost the raw accuracy of quantized and fine-tuned models by up to 2.1% using INT4 QAT + LoRA (source). QAT recovered 2.1% out of 4.5% accuracy lost on Gemma3-12B BBH, 1.0% out of 1.5% accuracy lost on Gemma3-4B GPQA, 2.0% out of 5.8% accuracy lost on Qwen3-4B MMLU Pro, and 0.7% out of 2.0% accuracy lost on Llama-3.2-3B MMLU Pro.

Unsloth users can enable QAT in their fine-tuning workflows by simply specifying the extra qat_scheme flag as follows. For an end-to-end example, check out Unsloth’s free QAT notebooks or the model cards in our HuggingFace checkpoints, which were also fine-tuned with Unsloth.

from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma3-12b-it",
    max_seq_length = 2048,
    load_in_16bit = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,   
    # We support fp8-int4, fp8-fp8, int4, int8-int4, phone-deployment
    qat_scheme = "int4",
)

Unsloth also leverages TorchAO QAT to deploy models to smartphones like Pixel 8 and iPhone 15 Pro through ExecuTorch, recovering up to 70% of the accuracy degradation in the process. Users can specify qat_scheme = "phone-deployment" or "int8-int4" for this use case. For more details, please refer to this blog.

Integration with Axolotl

We also integrated TorchAO’s QAT support into Axolotl’s multi-GPU full fine-tuning workflows. Axolotl supports a wide variety of QAT schemes, including INT4 weight-only, FP8 dynamic activations + INT4 weights, and NVFP4 dynamic activations + NVFP4 weights (prototype). 

Thanks to support from Lambda Labs, we were able to demonstrate the effectiveness of QAT on a Lambda Labs Instant Cluster comprising 2xB200 nodes, connected with NVLINK intra-node and Infiniband inter-node, by training models of up to 72B parameters with QAT. In our initial experiments, NVFP4 QAT recovered a significant share of the accuracy degradation when fine-tuning and evaluating across both instruction-following and mathematical reasoning tasks and benchmarks. We recovered up to 63.2% of the accuracy degradation on Gemma3-12B, and 71.6% on Gemma3-27B, translating to +3.2% and +2.3% raw accuracy improvements, respectively (Figures 4 and 5).

Figure 4: Axolotl leverages NVFP4 QAT in TorchAO to recover accuracy degradation by up to 63.2% (+3.2% in absolute accuracy) for Gemma3-12B.

Figure 5: Axolotl leverages NVFP4 QAT in TorchAO to recover accuracy degradation by up to 71.6% (+2.4% in absolute accuracy) for Gemma3-27B across a variety of MMLU tasks.

Applying QAT to Axolotl fine-tuning workflows is straightforward. Simply add the following section in your config file (see these example configs):

# my_qat_workflow.yaml
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # NVFP4 only supports group size of 16 per specification

Then run the following to fine-tune and quantize with the same config. Note that fine-tuning with NVFP4 QAT is only supported on Blackwell GPUs or above for now:

axolotl train my_qat_workflow.yaml
axolotl quantize my_qat_workflow.yaml

For more detailed instructions, please refer to the following documentation:

Piecewise-Affine Regularized Quantization (PARQ)

PARQ is a new algorithm in TorchAO that makes doing QAT with custom lower bit schemes straightforward (see 2025 ICML paper). In particular, it supports QAT with a “stretched elastic quantization” function that spreads the output values more evenly across the quantization grid, in contrast to the usual affine grid. Recent research shows that this is essential for quality when quantizing below 4 bits. In our experiments, the 3-bit per-row model trained with PARQ achieved on par accuracy with the 4-bit per-group baseline, while using only ~58% memory footprint and decoding at ~1.57x faster throughput.

Beyond simplifying experimentation with ultra-low-bit QAT, PARQ integrates seamlessly with ExecuTorch similar to the existing TorchAO QAT flow, enabling an end-to-end path from training to efficient on-device deployment using the PyTorch ecosystem.

TorchAO Integration

Existing TorchAO QAT works by swapping modules in the forward pass. PARQ takes a different approach: it performs quantization directly inside the optimizer’s step function. Users specify which parameters to quantize, along with the quantization functions and granularities, when defining the optimizer’s param_groups. As a result, no changes to the model code are required.

PARQ quantizes weights by applying configuration classes and quantization primitives from TorchAO’s quantize_ API. Depending on the config, it can choose quantization parameters based on an affine grid or the new stretched grid suitable for low-bit QAT. This close integration between PARQ and existing QAT methods ensures a familiar setup process and numerical consistency. Furthermore, it easily composes with dynamic activation quantization from the same API.

In TorchAO, we provide simple APIs for configuring PARQ optimizers to match a wide range of quantization strategies. For example, you might quantize all linear layers to 2 bits while keeping token embeddings at 4 bits. This ability to mix quantizers and bitwidths makes it easy to explore advanced configurations and significantly speeds up experimentation.

import torch
from torchao.prototype.parq.api import QuantConfig, create_optimizer

# Define how to quantize linear weights from `model.parameters()`
def linear_filter_fn(module, fqn):
    return isinstance(module, torch.nn.Linear) and fqn.endswith("weight")

linear_config = QuantConfig(bitwidth=2, group_size=None)
quant_configs_and_filter_fns = [(linear_config, linear_filter_fn)]

# Apply `parq.optim.QuantOptimizer` to quantize in `optimizer.step()`
optimizer = create_optimizer(
    model,
    quant_configs_and_filter_fns,
    base_optimizer_cls=torch.optim.AdamW,
    base_optimizer_kwargs={"weight_decay": 1e-2},
    quant_per_channel=True,
)
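To see why quantizing inside the optimizer leaves model code untouched, consider the dependency-free toy below. It is not PARQ's algorithm: it uses plain SGD with hard rounding in place of PARQ's piecewise-affine regularization, and `ToyQuantSGD` and its parameters are hypothetical names.

```python
class ToyQuantSGD:
    """Plain SGD that updates full-precision latent weights and exposes rounded ones."""

    def __init__(self, params, lr, step_size, bits):
        self.params = params          # the list the "model" reads from
        self.latent = list(params)    # full-precision copy updated by SGD
        self.lr = lr
        self.step_size = step_size
        half = 2 ** (bits - 1)
        self.qmin, self.qmax = -half, half - 1

    def _round(self, w):
        q = round(w / self.step_size)
        q = max(self.qmin, min(self.qmax, q))
        return q * self.step_size

    def step(self, grads):
        for i, g in enumerate(grads):
            self.latent[i] -= self.lr * g
            # Quantization happens here, not in the model
            self.params[i] = self._round(self.latent[i])


def model_grads(weights, targets):
    """Gradient of sum((w - t)^2); knows nothing about quantization."""
    return [2.0 * (w - t) for w, t in zip(weights, targets)]


weights = [0.0, 0.0]
targets = [0.3, -0.9]
optimizer = ToyQuantSGD(weights, lr=0.1, step_size=0.5, bits=2)
for _ in range(50):
    optimizer.step(model_grads(weights, targets))
print(weights)
```

The "model" function never changes; only the optimizer decides which values the weights may take.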

Low-bit QAT on Phi-4-mini-instruct and deployment with ExecuTorch

To showcase PARQ in a realistic setting, we performed low-bit fine-tuning on Microsoft’s Phi-4-mini-instruct model, using a variety of quantization schemes and compared these against a 4-bit PTQ baseline for both accuracy and on-device performance.

We release HuggingFace model cards and scripts for reproducing these results on public data:

Model | Linear quantization | Embedding quantization | Finetuning task | ExecuTorch *.pte model size
2bit QAT | 2-bit per-row | 4-bit per-row, tied with lm_head | General conversation | 1.13GB
3bit QAT | 3-bit per-row | 4-bit per-row, tied with lm_head | General conversation | 1.53GB
4bit QAT | 4-bit per-row | 4-bit per-row, tied with lm_head | Grade school math problems | 1.93GB
4bit PTQ | 4-bit per-group (32), including lm_head | 8-bit per-row | N/A | 2.78GB

Despite using a far coarser per-row quantization granularity, our 3-bit QAT model performs on par with a 4-bit PTQ model on a range of reasoning benchmarks, as we see in the benchmarks below. In addition, the 4-bit QAT model we optimized for grade school math is nearly 17% more accurate than the 4-bit PTQ baseline on such problems. 

Figure 7a: Accuracies for sub-4 bit models trained with PARQ vs. 4-bit PTQ and base bf16 models.

Figure 7b: Continuation of Figure 7a.

Figure 8: PARQ accuracies on math word problem benchmarks.

With ExecuTorch, we can run the 2-, 3-, and 4-bit models above on mobile devices. Below we show performance data and screenshots from running the models on an iPhone 15 Pro: lower bitwidths lead to significant gains in decode speed and memory usage. Indeed, the exported 3-bit model takes only ~58% of the memory footprint of the 4-bit PTQ model, while decoding ~1.57x faster, even though it achieves similar accuracy across the benchmarks above.

Model | ExecuTorch backend | Memory (GB) | Decode speed (tok/sec)
2bit QAT | TorchAO lowbit kernels | 1.4 | 27
3bit QAT | TorchAO lowbit kernels | 1.8 | 22
4bit QAT | TorchAO lowbit kernels | 2.2 | 18
4bit PTQ | XNNPACK | 3.1 | 14
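The ~58% memory and ~1.57x decode figures follow from the table when the 3-bit QAT row is compared against the 4-bit PTQ row. A quick check, using only the numbers above (the helper name is hypothetical):

```python
# (memory in GB, decode speed in tok/sec) from the table above
ROWS = {
    "2bit QAT": (1.4, 27),
    "3bit QAT": (1.8, 22),
    "4bit QAT": (2.2, 18),
    "4bit PTQ": (3.1, 14),
}


def relative_to(baseline, rows=ROWS):
    """Return {model: (memory ratio, decode speedup)} relative to `baseline`."""
    base_mem, base_speed = rows[baseline]
    return {
        name: (mem / base_mem, speed / base_speed)
        for name, (mem, speed) in rows.items()
    }


for name, (mem_ratio, speedup) in relative_to("4bit PTQ").items():
    print(f"{name}: {mem_ratio:.0%} of baseline memory, {speedup:.2f}x decode speed")
```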

Figure 9: Example output of Phi-4-mini-instruct quantized with and without PARQ QAT at bit-widths of 2, 3, and 4, deployed on an iPhone 15 Pro using ExecuTorch.

Looking Ahead

In this blog, we highlighted two new TorchAO QAT integrations with Unsloth and Axolotl and presented PARQ, a novel optimizer-based QAT technique targeting lower bit settings. In the near future, we plan to explore the following directions to continue this work:

  • Reinforcement Learning. Popular RL algorithms like Proximal Policy Optimization (PPO) and Group Relative Policy Optimization (GRPO) can benefit from on-the-fly quantized rollouts. To maintain true on-policy RL training, however, QAT will be needed to match inference numerics. Applying TorchAO QAT in this setting is a promising direction to explore.
  • Leveraging GPU kernels during QAT. There are opportunities to speed up training using custom kernels designed for QAT like the ones in FP-Quant, e.g., performing MXFP4 GEMM during forward and MXFP8 GEMM during backward in this autograd function. For QAT workflows targeting integer quantized dtypes, replicating the integer numerics during training can additionally reduce potential numerical discrepancies in the end-to-end workflow.
  • Further integrations. We also plan to incorporate QAT into newer fine-tuning frameworks like TorchForge and improve our existing integrations by, for example, adding LoRA support to Axolotl QAT.
  • PARQ. We plan to extend support for experimental new optimization algorithms. These are similar to PARQ’s existing algorithm but have more precise convergence guarantees on nonconvex problems.

Acknowledgements

We are deeply grateful to our external collaborators Daniel Han, Michael Han, Datta Nimmaturi (Unsloth), Salman Mohammadi, and Wing Lian (Axolotl) for a fruitful TorchAO QAT integration. We also thank Nick Harvey from Lambda Labs for providing the infrastructure for running our NVFP4 QAT experiments. Finally, we express our gratitude to everyone who provided valuable feedback to the project and this blog, including Driss Guessous, Vasiliy Kuznetsov, Supriya Rao, Mark Saroufim, and Lin Xiao.

]]>
Kubetorch Joins the PyTorch Ecosystem Landscape: A Fast, Pythonic, Fault-Tolerant Interface into Kubernetes for ML https://pytorch.org/blog/kubetorch-joins-the-pytorch-ecosystem-landscape/ Fri, 27 Feb 2026 22:00:05 +0000 https://pytorch.org/?p=47757 Kubetorch enables ML research and development on Kubernetes, across training, inference, RL, evals, data processing, and more, in a deceptively simple and unopinionated package. 

For many teams, Kubernetes is increasingly the compute foundation for ML/AI development, due to its support for arbitrary workloads at scale, a rich open source ecosystem, and workload portability. However, ML/AI development on Kubernetes is unergonomic. Unlike regular software, you cannot test ML code before containerizing and deploying to Kubernetes, since you do not locally have 32 H100s (or a single T4). Therefore, you must develop through deployment to Kubernetes. Iterating on a distributed PyTorch training job means stopping execution, throwing away state by tearing down pods, rebuilding a Docker image, requeuing for resources, reloading artifacts, and finally resuming training. End-to-end, adding a print statement takes 10-30 minutes of wrestling with Dockerfiles and YAML.

Kubetorch solves this: it is a framework for deploying to Kubernetes painlessly. You can take any regular ML program and run it at arbitrary scale on the remote cluster, calling it as if the cluster resources were simply part of your local process pool. Subsequent iteration is extremely fast, with local code changes propagating in seconds because everything is cached in place. Kubetorch is also unopinionated, so you can use arbitrary compute, images, resource types, and cluster management.

To view the PyTorch Ecosystem, see the PyTorch Landscape. Learn more about how projects can join the PyTorch Ecosystem Landscape.

How It Works 

Just as PyTorch made it simple to command GPUs, Kubetorch makes it simple to command Kubernetes, with a similar `.to()` API that deploys code to remote and makes it callable.

import kubetorch as kt
from my_repo.training_main import train

train_compute = kt.Compute(cpus="2", gpus="1")
remote_train = kt.fn(train).to(train_compute)
result = remote_train(lr=0.05, batch_size=4)

In this example, we took our regular training entrypoint and sent it to run on Kubernetes, getting a local callable that behaves identically to the original train function. Subsequent calls are routed to the remote service, with logs and exceptions propagating back, again, as if it were regular local execution. 

You can easily scale your workload, too; assuming your training uses PyTorch DDP, scaling up the training is as simple as modifying the Compute object you send your workload to. 

import kubetorch as kt
from my_repo.training_main import train

train_compute = kt.Compute(gpus="8").distribute("pytorch", workers=8)
remote_train = kt.fn(train).to(train_compute)
result = remote_train(lr=0.05, batch_size=4)

You also unlock powerful new patterns with this remote execution. For instance, you can deploy a class that enables state across calls. Or, you can catch remote exceptions and decide what to do with that exception without the overall application falling over. 

import kubetorch as kt
from my_repo import TrainerClass

epochs = 25
batch_size = 32

train_compute = kt.Compute(gpus="8").distribute("pytorch", workers=4)
remote_train = kt.cls(TrainerClass).to(train_compute)

remote_train.load_data()

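for epoch in range(epochs):
    try:
        remote_train.train(epoch, batch_size=batch_size)
    except Exception as e:
        if "CUDA out of memory" in str(e):
            # Halve the batch size for later epochs; integer division keeps it an int
            batch_size = batch_size // 2
        else:
            # Do not silently swallow unrelated failures
            raise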

remote_train.save()
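The loop above moves on to the next epoch after an out-of-memory error. A variant that retries the same epoch with a smaller batch can be sketched locally without a cluster. `FakeTrainer` and `train_with_backoff` are hypothetical stand-ins, not part of Kubetorch; the sketch only assumes that the remote error message contains "CUDA out of memory", as in the example above.

```python
class FakeTrainer:
    """Local stand-in for a remote trainer; fails like an OOM for large batches."""

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.completed = []

    def train(self, epoch, batch_size):
        if batch_size > self.max_batch_size:
            raise RuntimeError("CUDA out of memory")
        self.completed.append((epoch, batch_size))


def train_with_backoff(trainer, epochs, batch_size, min_batch_size=1):
    """Run every epoch, halving the batch size and retrying whenever an OOM occurs."""
    for epoch in range(epochs):
        while True:
            try:
                trainer.train(epoch, batch_size=batch_size)
                break
            except Exception as e:
                out_of_memory = "CUDA out of memory" in str(e)
                if not out_of_memory or batch_size <= min_batch_size:
                    raise
                batch_size //= 2
    return batch_size


trainer = FakeTrainer(max_batch_size=10)
final_batch_size = train_with_backoff(trainer, epochs=3, batch_size=32)
print(final_batch_size, trainer.completed)
```

Because the driver program owns the loop, the retry policy lives in ordinary Python rather than in cluster configuration.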

Key Benefits 

Kubetorch is installed with a Python library and a Kubernetes operator, all within your cloud account and VPC. It can be adopted incrementally within an existing ML stack or used as a full replacement, across training, batch processing, online inference, hyperparameter optimization, and pipelining.

  • Run any Python code on Kubernetes at any scale by specifying the required resources, distribution, and scaling directly in code.
  • Iterate on that code in 1–2 seconds with magic caching and hot redeployment.
  • Execute code reproducibly by dispatching work identically to Kubernetes from any environment, including a teammate’s laptop, CI, an orchestrator, or a production application.
  • Handle hardware faults, preemptions, and OOMs programmatically from the driver program, creating robust fault tolerance.
  • Optimize workloads with observability, logging, auto-down, queueing, and more.
  • Orchestrate complex, heterogeneous workloads such as RL post-training, which requires coordinating different compute resources, images, scaling, and distributed communication within a single loop.

Get Started 

Installing and using Kubetorch (GitHub repo) is simple. If you already have a Kubernetes cluster, then all you need to do to get started is:

  • Helm install Kubetorch onto the cluster 
  • Pip install the Python client anywhere that you want to use Kubetorch (local machine, CI, orchestrator node)
]]>
Enhancing Multimodal Training and Memory Efficiency with DeepSpeed https://pytorch.org/blog/enhancing-multimodal-training-and-memory-efficiency-with-deepspeed/ Wed, 25 Feb 2026 00:45:25 +0000 https://pytorch.org/?p=47565 Overview

This blog walks through two crucial DeepSpeed updates: (1) a PyTorch-identical backward API that enables efficient training of multimodal, multi-component models (including non-scalar backward calls), and (2) low-precision model training that significantly reduces peak memory.

For multimodal workloads, like combining a vision encoder with an LLM, training loops can become complex and multi-component. The first update introduces a PyTorch-identical backward API that makes writing such loops straightforward, enabling sophisticated parallelism schemes with simple, clean code, while DeepSpeed transparently manages various performance optimizations. As one example, the flexibility of the API enabled disaggregated hybrid parallelism, achieving a 30% speedup for multimodal AI model training while making model development with DeepSpeed feel closer to “vanilla PyTorch”.

Meanwhile, for LLM fine-tuning, a new option to keep all model states (parameters, gradients, and optimizer states) in lower precision, such as BF16 or FP16, drastically reduces the memory footprint, allowing researchers to train larger models on more constrained hardware. Low-precision training is highly beneficial across a wide range of applications, including supervised fine-tuning (SFT), reinforcement learning (RL), and multimodal training. Our experiment showed a 40% peak memory reduction while keeping numerical stability (benchmarking script). The numerical stability is achieved through integration with torch.autocast, which ensures the quality of the model is maintained.

The remainder of this blog will elaborate on how these updates directly facilitate the development of cutting-edge training workloads.

1. PyTorch-identical backward API

DeepSpeed now supports PyTorch’s native backward() syntax while preserving all its optimizations. Traditionally, DeepSpeed’s training loop relied on the engine’s backward API:

loss = model_engine(batch)
model_engine.backward(loss)
model_engine.step()

The engine’s backward API was sufficient for traditional pretraining and fine-tuning pipelines. However, recent complex training pipelines require more flexibility. There were two major limitations:

  1. It only accepted a scalar loss.
  2. You had to call model_engine.backward(loss), rather than using the usual PyTorch loss.backward() style.

Due to these constraints, users could not simply implement patterns that vanilla PyTorch allows. Here are some examples:

# 1. Combine multiple models and losses
output1 = model1(batch1)
output2 = model2(batch2)
loss = criterion(output1, output2)
loss.backward()

# 2. Define a loss function separately from the main model
output = model(batch)
loss = loss_fn(output)
loss.backward()

# 3. Call backward through non-scalar tensors with custom gradients
output = model(batch)
output.backward(grad)

DeepSpeed Engine was able to handle these use cases using internal APIs; however, that required significant code changes and could easily introduce bugs. With the addition of the PyTorch-identical backward API, we can now use the same code as native PyTorch while keeping DeepSpeed’s powerful optimizations, including ZeRO and offloading.

One example use case for the PyTorch-identical backward API is disaggregated hybrid parallelism for multimodal models using Ray. In this training pipeline, two Ray Actor groups handle the vision encoder and the LLM separately. On a backward pass, the LLM passes a gradient to the vision encoder, and the vision encoder calls the backward function with that gradient. However, because the gradient is a non-scalar tensor, such a use case wasn’t officially supported by DeepSpeed APIs. Disaggregated hybrid parallelism demonstrates that the flexibility of the backward API, combined with DeepSpeed’s optimizations and DeepSpeed-Ulysses (highly efficient sequence parallelism), achieves a 30% speedup in training.

Below is the pseudo-code for the two models running on different actors. Since they run in different processes, we pass gradients via Ray actor communication. As seen here, the gradient of the vision embedding is a non-scalar tensor. Although this code is identical to the PyTorch API, it will activate various DeepSpeed optimizations based on your configuration.

# Runs on LLM actors
def text_backward_step(self):
  # ...
  self.loss.backward()
  return self.vision_embeddings.grad.detach().clone()

# Runs on Vision actors
def vision_backward_step(self, vision_embedding_grad):
  self.vision_output.backward(gradient=vision_embedding_grad)
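The gradient handoff above is an application of the chain rule. As a sketch in our own notation (the symbols are not from the DeepSpeed documentation): let $e = f_v(x; \theta_v)$ be the vision embeddings and $L = f_t(e; \theta_t)$ the scalar loss computed by the LLM.

```latex
\[
\underbrace{g = \frac{\partial L}{\partial e}}_{\text{computed on the LLM actors by } \texttt{loss.backward()}}
\qquad\Longrightarrow\qquad
\frac{\partial L}{\partial \theta_v}
  = \left( \frac{\partial e}{\partial \theta_v} \right)^{\!\top} g
\]
```

The LLM actor returns $g$, which has the same shape as $e$ and is therefore non-scalar. The vision actor's `backward(gradient=g)` call evaluates the vector-Jacobian product on the right, which is why a backward API restricted to scalar losses could not express this step.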

Check out the repository for the complete training pipeline.

2. Memory-efficient low-precision model states

You can now keep all model states (parameters, gradients, and optimizer states) in BF16 or FP16, significantly reducing memory consumption.

Traditionally, DeepSpeed’s mixed precision keeps FP32 master parameters, gradients, and optimizer states, which is technically safer but memory-intensive. While DeepSpeed has supported torch.autocast via configuration (see the API documentation), the lack of an option to bypass creating FP32 states limited the trainability of large models on constrained hardware. In practice, many training workloads converge stably without FP32 states.

With the low-precision model states option, you can easily skip creating FP32 states and combine the low-precision option with torch.autocast support (see the document and example for configuration details). This combination drastically improves memory efficiency without sacrificing convergence.

{
  ...
  "zero_optimization": {
    "stage": 3,
    ...
  },
  "bf16": {
    "enabled": true,
    "bf16_master_weights_and_grads": true,
    "bf16_optimizer_states": true
  },
  "torch_autocast": {
    "enabled": true,
    "dtype": "bfloat16"
  }
}

Our example script demonstrates the significant memory savings:

Configuration | Allocated Memory | Peak Memory | Avg Step Time
Baseline (fp32 master) | 25.74 GB | 31.38 GB | 0.6016s
BF16 low-precision (master + opt states) | 16.17 GB | 18.93 GB | 0.6427s
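The relative changes can be read off the table with a few lines of arithmetic. The sketch below uses only the measured values above; the helper name and dictionary keys are our own.

```python
def percent_change(before, after):
    """Signed change from `before` to `after`, in percent."""
    return (after - before) / before * 100.0


baseline = {"allocated_gb": 25.74, "peak_gb": 31.38, "step_s": 0.6016}
bf16 = {"allocated_gb": 16.17, "peak_gb": 18.93, "step_s": 0.6427}

for key in baseline:
    print(f"{key}: {percent_change(baseline[key], bf16[key]):+.1f}%")
```

This works out to roughly 37% less allocated memory and 40% lower peak memory, at the cost of a step time about 7% longer in this run.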

The experiment (7B model, ZeRO-3, 4 GPUs) demonstrated a 40% reduction in peak memory. To verify that BF16 low-precision training maintains numerical stability, we trained for 1000 steps on the Wikitext-103 dataset:

Loss curve comparison

Configuration | Final Loss | Mean Loss
Baseline (fp32 master) | 3.09 | 2.78
BF16 Low-Precision | 3.12 | 2.90

Related Tests

We continuously test these new APIs in our CI, and you can see various use-case patterns in the tests.

Closing Thoughts

This DeepSpeed update delivers key advancements:

  • Enabling Complex Multimodal Workloads: The new PyTorch-identical backward API enables sophisticated multi-component training loops, such as those required for multimodal models, with simple, clean code. As one example, the PyTorch-identical backward API has enabled a 30% speedup for disaggregated hybrid parallelism.
  • Scaling to Larger Models: Low-precision model states combined with torch.autocast reduce peak memory by up to 40% without sacrificing convergence, allowing you to train larger models with the same hardware.

We are excited to see how you use the new APIs and features described in this blog post in your own training setups, and we welcome feedback and issues on GitHub as you try them out.

]]>