1. Introduction: The Scale Problem

The deep learning community is engaged in an unprecedented scaling race. Model sizes have grown from millions to hundreds of billions of parameters in just a few years, with the largest models now exceeding a trillion parameters. This exponential growth has fundamentally changed how we train neural networks—what once fit comfortably on a single GPU now requires orchestrating thousands of accelerators across multiple nodes.

Understanding distributed training is no longer optional for ML practitioners. Whether you're fine-tuning a 7B parameter model or training a frontier model from scratch, you need to understand the communication primitives, parallelism strategies, and performance trade-offs that make large-scale training possible.

1.1 The Exponential Growth of Model Size

Let's trace the evolution of model scale to understand why distributed training became necessary:

Evolution of Language Model Scale
[Figure: parameter counts on a log scale (10M to 1T), 2017-2023+ — Transformer (65M), BERT-Large (340M), GPT-2 (1.5B), GPT-3 (175B), MT-NLG (530B), PaLM (540B), GPT-4 era (~1T+): roughly 10,000× growth in six years.]
Language model parameter counts have grown exponentially, roughly doubling every few months. This growth rate far exceeds Moore's Law and single-GPU memory capacity growth.

This growth isn't just about bragging rights—larger models consistently demonstrate better performance on benchmarks and real-world tasks. The scaling laws discovered by Kaplan et al. and refined by Hoffmann et al. (Chinchilla) show that model performance improves predictably with scale, following power-law relationships:

$$L(N) \propto N^{-\alpha}$$

Where $L$ is the loss, $N$ is the number of parameters, and $\alpha \approx 0.076$ for language models. This means that to halve the loss, you need to increase model size by approximately $2^{1/0.076} \approx 9000×$. The economic incentive to scale is clear, but it creates an equally clear engineering challenge.
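This arithmetic is easy to verify directly (plain Python, no dependencies):

```python
# Scaling law L(N) ∝ N^(-alpha): halving the loss requires
# multiplying N by 2^(1/alpha).
alpha = 0.076
scale_factor = 2 ** (1 / alpha)
print(f"Required growth in N: ~{scale_factor:,.0f}x")
```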

1.2 The Memory Wall

The fundamental constraint in training large models is memory. Let's break down exactly where the memory goes during training. For a model with $N$ parameters using mixed-precision training with the Adam optimizer:

Memory Breakdown for Training

Parameters (fp16): $2N$ bytes
Gradients (fp16): $2N$ bytes
Adam optimizer states:
  • Momentum (fp32): $4N$ bytes
  • Variance (fp32): $4N$ bytes
  • Master weights (fp32): $4N$ bytes
Total: $2N + 2N + 4N + 4N + 4N = 16N$ bytes

That's 16 bytes per parameter just for the model state—before we even consider activations! Let's see what this means for real models:

| Model | Parameters | Model State (16N) | A100 80GB GPUs Needed |
|---|---|---|---|
| LLaMA-7B | 7B | 112 GB | 2+ |
| LLaMA-13B | 13B | 208 GB | 3+ |
| LLaMA-65B | 65B | 1.04 TB | 13+ |
| GPT-3 | 175B | 2.8 TB | 35+ |
| PaLM | 540B | 8.6 TB | 108+ |
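The 16-bytes-per-parameter rule is easy to turn into a calculator. A minimal sketch, assuming fp16 parameters/gradients plus fp32 Adam states as in the breakdown above (and ignoring activations entirely):

```python
import math

def model_state_bytes(num_params: int) -> int:
    """Mixed-precision Adam: 2 (fp16 params) + 2 (fp16 grads)
    + 4 + 4 + 4 (fp32 momentum, variance, master weights) = 16 B/param."""
    return 16 * num_params

def gpus_needed(num_params: int, gpu_mem_gb: float = 80) -> int:
    """Lower bound on GPUs just to hold the model state."""
    return math.ceil(model_state_bytes(num_params) / (gpu_mem_gb * 1e9))

print(gpus_needed(7_000_000_000))    # LLaMA-7B: 112 GB -> 2 GPUs
print(gpus_needed(175_000_000_000))  # GPT-3: 2.8 TB -> 35 GPUs
```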

And we haven't even counted activation memory yet! Activations are the intermediate tensors computed during the forward pass and kept for the backward pass. For a transformer with sequence length $s$, batch size $b$, hidden dimension $h$, $a$ attention heads, and $L$ layers, the activation memory scales approximately as:

$$\text{Activations} \approx sbh \cdot (34 + 5\frac{as}{h}) \cdot L \text{ bytes (fp16)}$$

For a GPT-3 scale model with $s=2048$, $b=1$, $h=12288$, $a=96$, and $L=96$ layers:

$$\text{Activations} \approx 2048 \times 12288 \times (34 + 5 \times \frac{96 \times 2048}{12288}) \times 96 \approx 275\text{ GB per sample}$$
The Activation Memory Crisis

A single training sample for GPT-3 requires ~275GB just for activations. This is why activation checkpointing (recomputation) is essential—it trades compute for memory by recomputing activations during the backward pass instead of storing them all.
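Plugging the GPT-3-scale numbers into the formula above confirms the figure. A sketch ($b$ is the batch size; the result is fp16 bytes before any checkpointing):

```python
def activation_bytes(s, b, h, a, L):
    """Approximate transformer activation memory (fp16):
    s*b*h*(34 + 5*a*s/h) bytes per layer, times L layers."""
    per_layer = s * b * h * (34 + 5 * a * s / h)
    return per_layer * L

total = activation_bytes(s=2048, b=1, h=12288, a=96, L=96)
print(f"{total / 1e9:.0f} GB per sample")  # ~275 GB
```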

Memory Breakdown: Training a 7B Parameter Model
[Figure: params 14 GB (fp16) + grads 14 GB (fp16) + Adam momentum 28 GB + Adam variance 28 GB + master weights 28 GB (all fp32) = 112 GB of model state (16 bytes × 7B parameters), plus ~50-200 GB of activations depending on batch size and sequence length.]
For a 7B model, the optimizer state alone (84GB) exceeds a single A100's memory. Activations add substantial additional memory that varies with batch size and sequence length.

1.3 Training Bottlenecks: Compute, Memory, and Communication

Distributed training introduces a new dimension of complexity: communication. When we spread training across multiple devices, we must constantly exchange data between them. The three fundamental bottlenecks in distributed training are:

Compute Bound

Bottleneck: GPU arithmetic throughput
Metric: TFLOPS utilization
Solution: Maximize arithmetic intensity, use tensor cores

Memory Bound

Bottleneck: GPU memory capacity
Metric: GB used vs available
Solution: Shard model, checkpoint activations

Communication Bound

Bottleneck: Inter-GPU bandwidth
Metric: GB/s, messages/sec
Solution: Overlap, compress, minimize volume

The interplay between these bottlenecks determines which distributed training strategy is optimal for a given model, hardware configuration, and batch size. This is why understanding performance modeling—which we'll cover in depth—is crucial for efficient distributed training.

Modern AI Cluster Communication Hierarchy
[Figure: two DGX H100 nodes, each with 8× H100 80GB GPUs connected internally by NVLink/NVSwitch at 900 GB/s, linked by an InfiniBand/RoCE network at 400 Gb/s (50 GB/s) per NIC × 8 NICs = 400 GB/s aggregate. Bandwidth hierarchy: intra-node (NVLink) 900 GB/s; inter-node (IB) 50-400 GB/s; cross-datacenter 10-100 GB/s. Inter-node bandwidth is 2-18× lower than intra-node, so communication pattern design is critical.]
Modern AI clusters have a hierarchical bandwidth structure. Efficient distributed training strategies exploit this hierarchy by minimizing cross-node communication.

1.4 The Distributed Training Landscape

To address these challenges, the community has developed a rich toolkit of parallelism strategies. Each addresses different bottlenecks and has different trade-offs:

| Strategy | What's Parallelized | Memory Savings | Communication Pattern | When to Use |
|---|---|---|---|---|
| Data Parallel | Samples across batches | None (replicated) | AllReduce on gradients | Model fits in GPU |
| FSDP/ZeRO | Model state | Linear in #GPUs | AllGather + ReduceScatter | Model state too large |
| Tensor Parallel | Layer operations | Linear in TP degree | AllReduce + AllGather | Layers too large |
| Pipeline Parallel | Layers across GPUs | Linear in PP degree | Point-to-point | Many layers, hide latency |
| Expert Parallel | MoE experts | Linear in #experts/GPU | All-to-All | MoE architectures |

In practice, training frontier models requires combining multiple strategies—known as 3D parallelism or even 4D/5D parallelism when including expert parallelism and sequence parallelism. Understanding how these strategies compose and when to use each is the key to efficient large-scale training.

What This Guide Covers

This guide takes a first-principles approach to distributed training. We start with the collective operations that underpin all distributed training—understanding not just what they do, but how they're implemented internally (ring, tree, hierarchical). We then develop a rigorous performance analysis framework using the α-β model that lets us predict and optimize communication costs. Finally, we apply this framework to analyze each parallelism strategy, understanding exactly when each is optimal and how to combine them effectively.

1.5 A Word on Hardware

Throughout this guide, we'll reference specific hardware capabilities. Here's a quick reference for modern training hardware:

NVIDIA H100 SXM

80 GB
HBM3 Memory
• 3.35 TB/s memory bandwidth
• 989 TFLOPS TF32
• 1979 TFLOPS FP16/BF16
• 900 GB/s NVLink

NVIDIA A100 SXM

80 GB
HBM2e Memory
• 2.0 TB/s memory bandwidth
• 312 TFLOPS TF32
• 624 TFLOPS FP16/BF16
• 600 GB/s NVLink

InfiniBand NDR

400 Gb/s
Per Port
• 50 GB/s per direction
• ~1-2μs latency
• RDMA support
• 8 ports per DGX node

Understanding these hardware specifications is essential for performance modeling. Notice the significant gap between intra-node bandwidth (NVLink: 600-900 GB/s) and inter-node bandwidth (InfiniBand: 50-400 GB/s). This asymmetry fundamentally shapes how we design distributed training systems.

2. Collective Operations Fundamentals

Before diving into parallelism strategies, we must understand the fundamental building blocks of distributed communication: collective operations. These are communication patterns that involve all processes in a group, enabling coordinated data exchange. Every distributed training algorithm is ultimately composed of these primitives.

The Message Passing Interface (MPI) standard defines these operations, and libraries like NCCL (NVIDIA Collective Communications Library), Gloo, and others implement them with hardware-specific optimizations. Understanding these operations—and how they're implemented—is essential for reasoning about distributed training performance.

2.1 Point-to-Point Operations: The Foundation

Before collective operations, let's briefly cover point-to-point communication—the simplest form of message passing where one process sends data directly to another:

Python (PyTorch) point_to_point.py
import torch
import torch.distributed as dist

# Send operation (on rank 0)
if dist.get_rank() == 0:
    tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda')
    dist.send(tensor, dst=1)

# Receive operation (on rank 1)
if dist.get_rank() == 1:
    tensor = torch.zeros(3, device='cuda')
    dist.recv(tensor, src=0)
    # tensor is now [1.0, 2.0, 3.0]

Point-to-point operations have straightforward performance characteristics:

$$T_{\text{p2p}} = \alpha + \beta \cdot n$$

Where $\alpha$ is the latency (time to initiate communication) and $\beta$ is the inverse bandwidth (time per byte). We'll formalize this model in Section 5.

2.2 Broadcast

Broadcast sends data from one process (the root) to all other processes in the group. After a broadcast, every process has an identical copy of the data.

Broadcast Operation
[Figure: before — only rank 0 (root) holds A; after — every rank holds a copy of A. Data volume sent: n bytes × (P-1) copies (naive) or n bytes (optimal).]
Python (PyTorch)
import torch
import torch.distributed as dist

# Initialize tensor on all ranks
if dist.get_rank() == 0:
    tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda')
else:
    tensor = torch.zeros(3, device='cuda')

# Broadcast from rank 0 to all ranks
dist.broadcast(tensor, src=0)
# Now all ranks have tensor = [1.0, 2.0, 3.0]
Use Case in Distributed Training

Broadcast is used to distribute model parameters from rank 0 to all workers at initialization, and to synchronize hyperparameters or random seeds across processes.

2.3 Reduce

Reduce combines data from all processes using a reduction operation (sum, max, min, product, etc.) and stores the result on a single root process. It's the inverse of broadcast.

Reduce Operation (Sum)
[Figure: before — rank i holds Aᵢ; after reduce(sum) — rank 0 (root) holds ΣAᵢ = A₀+A₁+A₂+A₃, other ranks unchanged. Total compute: (P-1) × n operations; only the root receives the result.]
Python (PyTorch)
# Each rank has local gradients
local_gradients = compute_gradients(model, batch)

# Reduce sum to rank 0
dist.reduce(local_gradients, dst=0, op=dist.ReduceOp.SUM)

# Only rank 0 has the sum of all gradients
if dist.get_rank() == 0:
    local_gradients /= dist.get_world_size()  # Average

2.4 Scatter

Scatter distributes different chunks of data from the root process to all processes. If the root has data of size $n \times P$, each process receives a chunk of size $n$.

Scatter Operation
[Figure: before — rank 0 (root) holds chunks A₀..A₃; after scatter — rank i holds chunk Aᵢ. Total data moved: n × (P-1) bytes from the root.]
Python (PyTorch)
# Root prepares data chunks for each rank
if dist.get_rank() == 0:
    scatter_list = [torch.tensor([i * 10.0], device='cuda') 
                    for i in range(dist.get_world_size())]
else:
    scatter_list = None

# Each rank receives its chunk
output = torch.zeros(1, device='cuda')
dist.scatter(output, scatter_list, src=0)
# Rank 0: [0.0], Rank 1: [10.0], Rank 2: [20.0], ...

2.5 Gather

Gather is the inverse of scatter—it collects data from all processes and concatenates it on the root process.

Gather Operation
[Figure: before — rank i holds Aᵢ; after gather — rank 0 (root) holds the concatenation A₀ A₁ A₂ A₃, other ranks unchanged. Total data moved: n × (P-1) bytes to the root.]
Python (PyTorch)
# Each rank has local data
local_tensor = torch.tensor([dist.get_rank() * 10.0], device='cuda')

# Gather to rank 0
if dist.get_rank() == 0:
    gather_list = [torch.zeros(1, device='cuda') 
                   for _ in range(dist.get_world_size())]
else:
    gather_list = None

dist.gather(local_tensor, gather_list, dst=0)
# Rank 0: gather_list = [[0.0], [10.0], [20.0], [30.0]]

2.6 Summary of Basic Collectives

These four operations—broadcast, reduce, scatter, and gather—form the foundation. Notice the duality relationships:

| Operation | Data Flow | Root Has | Others Have | Dual Operation |
|---|---|---|---|---|
| Broadcast | Root → All | Original data | Copy of root's data | Reduce |
| Reduce | All → Root | Combined result | Original (unchanged) | Broadcast |
| Scatter | Root → All (different) | One chunk | Unique chunk each | Gather |
| Gather | All → Root | All chunks concatenated | Original (unchanged) | Scatter |
The "All" Variants

The operations we'll cover next—AllReduce, AllGather, ReduceScatter, and AllToAll—are "all" variants where the result is distributed to all processes, not just a root. These are the workhorses of distributed training because they eliminate the root bottleneck.

2.7 Visualizing the Collective Operation Family

The Collective Operations Family Tree
[Figure: family tree of collectives. Root-based: Broadcast (1 → all, same data), Reduce (all → 1, combined), Scatter (1 → all, different chunks), Gather (all → 1, concatenated). Rootless "all" variants: AllReduce (combined result everywhere), AllGather (concatenation everywhere), ReduceScatter (reduce + scatter), All-to-All (distributed transpose).]

Key compositional relationships:

AllReduce = Reduce + Broadcast
AllGather = Gather + Broadcast
ReduceScatter = Reduce + Scatter
AllReduce = ReduceScatter + AllGather ← most efficient!
All collective operations can be understood as compositions of basic operations. The "All" variants distribute results to all processes, eliminating root bottlenecks.
Why "All" Operations Matter

Root-based operations create a serialization bottleneck—all data must flow through one process. In distributed training with hundreds of GPUs, this becomes catastrophic. The "All" variants (AllReduce, AllGather, ReduceScatter) distribute the work and achieve bandwidth-optimal scaling when implemented correctly. This is why NCCL focuses heavily on optimizing these operations.
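The decomposition AllReduce = ReduceScatter + AllGather can be illustrated with a single-process simulation, where each "rank" is just a list of chunks (a toy sketch, not a distributed implementation):

```python
P = 4
# ranks[r][c]: chunk c held by rank r before the operation.
ranks = [[(r + 1) * 10 + c for c in range(P)] for r in range(P)]

# ReduceScatter: rank r ends up owning the fully reduced chunk r.
shards = [sum(ranks[r][c] for r in range(P)) for c in range(P)]

# AllGather: every rank collects all P reduced shards.
allreduce_result = list(shards)

# Reference AllReduce: elementwise sum across all ranks.
expected = [sum(col) for col in zip(*ranks)]
assert allreduce_result == expected
print(allreduce_result)  # [100, 104, 108, 112]
```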

3. AllReduce Implementations

AllReduce is the most important collective operation in distributed training. It combines data from all processes using a reduction operation and distributes the result back to everyone. After an AllReduce, every process has an identical copy of the reduced data.

AllReduce Operation (Sum)
[Figure: before — rank i holds Aᵢ; after AllReduce(sum) — every rank holds the identical reduced result ΣAᵢ = A₀+A₁+A₂+A₃.]

In distributed training, AllReduce is used to synchronize gradients across all workers. Each worker computes gradients on its local batch, then AllReduce sums them (after which you divide by the world size to get the average gradient).

Python (PyTorch) allreduce_gradients.py
import torch.distributed as dist

# Each rank computes local gradients
loss = model(inputs).sum()
loss.backward()

# AllReduce to sum gradients across all ranks
for param in model.parameters():
    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
    param.grad /= dist.get_world_size()  # Average

# Now all ranks have identical averaged gradients
optimizer.step()

The implementation of AllReduce dramatically affects training throughput. Let's examine four increasingly sophisticated implementations.

3.1 Naive AllReduce (Reduce + Broadcast)

The simplest AllReduce implementation is just a Reduce to a root followed by a Broadcast from that root. While conceptually simple, this approach has terrible scaling properties.

Naive AllReduce: Reduce then Broadcast
[Figure: phase 1 — all ranks send to the root, which becomes a bottleneck receiving (P-1) messages; phase 2 — the root broadcasts ΣAᵢ back out, again (P-1) messages. Naive AllReduce: T = 2(P-1)α + 2(P-1)nβ; the root's bandwidth saturates, so it does NOT scale with P.]
Naive AllReduce Cost Analysis

Using the α-β model (latency α, inverse bandwidth β):

Reduce phase: Root receives from (P-1) processes: $(P-1)\alpha + (P-1)n\beta$
Broadcast phase: Root sends to (P-1) processes: $(P-1)\alpha + (P-1)n\beta$

Total: $T_{\text{naive}} = 2(P-1)\alpha + 2(P-1)n\beta$

The $2(P-1)n\beta$ bandwidth term is catastrophic—root must handle all data volume!

3.2 Tree-Based AllReduce

Tree-based algorithms reduce the bottleneck at the root by organizing processes into a tree structure. In a binary tree reduction, each node combines data from two children, then passes the result to its parent.

Binary Tree Reduce (Logarithmic Latency)
[Figure: binary-tree reduction over 8 ranks — step 1 forms pairwise sums A₀+A₁, A₂+A₃, ...; step 2 forms A₀+...+A₃ and A₄+...+A₇; step 3 yields ΣAᵢ at the root. Latency: log₂(P) steps, but total bandwidth is still (P-1)nβ — the full data volume flows at every level.]
Tree-based reduce completes in log₂(P) steps, improving latency. However, the full data volume still flows through the root, so bandwidth is not improved.
Tree AllReduce Complexity

Tree Reduce: $T = \log_2(P) \cdot (\alpha + n\beta)$
Tree Broadcast: $T = \log_2(P) \cdot (\alpha + n\beta)$
Total Tree AllReduce: $T_{\text{tree}} = 2\log_2(P)\alpha + 2\log_2(P)n\beta$

Better latency scaling (logarithmic), but the bandwidth term $2\log_2(P)n\beta$ is still suboptimal—we're sending the full data at each level.

3.3 Ring AllReduce

Ring AllReduce is the breakthrough algorithm that achieves bandwidth-optimal scaling. Instead of aggregating all data at a root, it organizes processes in a ring and pipelines the communication so that each link carries only $n/P$ data at a time.

Ring AllReduce is implemented as two phases:

  1. ReduceScatter phase: Each process ends up with 1/P of the fully reduced data
  2. AllGather phase: Distribute the reduced chunks so everyone has the full result
Ring AllReduce: ReduceScatter Phase
[Figure: ring ReduceScatter — chunks circulate around the ring, each rank adding the chunk it receives to its local chunk; after (P-1) steps each rank holds 1/P of the fully reduced result, and after the AllGather phase all ranks are identical. Key insight: at each step every process sends exactly n/P bytes; after 2(P-1) steps all data has been exchanged, so total traffic per link is 2n(P-1)/P ≈ 2n bytes — independent of P.]
Ring AllReduce Complexity (Bandwidth Optimal!)

ReduceScatter phase: $(P-1)$ steps, each sending $n/P$ bytes
$T_{\text{reduce-scatter}} = (P-1)\alpha + (P-1) \cdot \frac{n}{P} \cdot \beta = (P-1)\alpha + \frac{(P-1)n}{P}\beta$

AllGather phase: $(P-1)$ steps, each sending $n/P$ bytes
$T_{\text{all-gather}} = (P-1)\alpha + \frac{(P-1)n}{P}\beta$

Total Ring AllReduce:
$$T_{\text{ring}} = 2(P-1)\alpha + \frac{2(P-1)n}{P}\beta \approx 2(P-1)\alpha + 2n\beta$$

The bandwidth term is $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ — constant regardless of P!

Python ring_allreduce_pseudocode.py
import torch
import torch.distributed as dist

def ring_allreduce(tensor, rank, world_size):
    """Ring AllReduce = ReduceScatter phase + AllGather phase.
    Assumes tensor.numel() is divisible by world_size."""
    chunks = list(tensor.chunk(world_size))
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    def exchange(send_chunk, recv_buf):
        # Simultaneous send to the next rank / receive from the previous one.
        ops = [dist.P2POp(dist.isend, send_chunk.contiguous(), next_rank),
               dist.P2POp(dist.irecv, recv_buf, prev_rank)]
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    # Phase 1: ReduceScatter — after P-1 steps, each rank holds
    # one fully reduced chunk.
    for step in range(world_size - 1):
        send_idx = (rank - step) % world_size
        recv_idx = (rank - step - 1) % world_size
        recv_buf = torch.empty_like(chunks[recv_idx])
        exchange(chunks[send_idx], recv_buf)
        chunks[recv_idx] = chunks[recv_idx] + recv_buf  # accumulate

    # Phase 2: AllGather — circulate the reduced chunks for P-1 more steps.
    for step in range(world_size - 1):
        send_idx = (rank - step + 1) % world_size
        recv_idx = (rank - step) % world_size
        recv_buf = torch.empty_like(chunks[recv_idx])
        exchange(chunks[send_idx], recv_buf)
        chunks[recv_idx] = recv_buf  # just copy, no reduction

    return torch.cat(chunks)

3.4 Recursive Halving-Doubling (Rabenseifner's Algorithm)

For small messages where latency dominates, the recursive halving-doubling algorithm (also known as Rabenseifner's algorithm) achieves better latency than ring while maintaining bandwidth efficiency. It's optimal when $P$ is a power of 2.

Recursive Halving-Doubling AllReduce
[Figure: phase 1, recursive halving (ReduceScatter) — each of 8 ranks exchanges half its data with a partner at distance 4, then 2, then 1, leaving each rank with 1/P of the fully reduced result; phase 2, recursive doubling (AllGather) — partners at distance 1, 2, 4 exchange reduced chunks until every rank holds the full result. Complexity: T = 2·log₂(P)·α + 2·((P-1)/P)·n·β — bandwidth optimal like ring, but with log₂(P) latency steps instead of 2(P-1), better for small messages or many processes.]

3.5 Comparison: When to Use Each Algorithm

| Algorithm | Latency Term | Bandwidth Term | Best For |
|---|---|---|---|
| Naive (Reduce+Bcast) | $2(P-1)\alpha$ | $2(P-1)n\beta$ | Never use at scale |
| Tree | $2\log_2(P)\alpha$ | $2\log_2(P)n\beta$ | Very small messages |
| Ring | $2(P-1)\alpha$ | $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ | Large messages |
| Recursive H-D | $2\log_2(P)\alpha$ | $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ | Medium messages, power-of-2 P |
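The trade-offs in the table above can be explored numerically. A sketch using illustrative values (α = 5 μs and a 50 GB/s link are assumptions, not measurements):

```python
import math

def allreduce_times(n, P, alpha=5e-6, beta=1 / 50e9):
    """Predicted AllReduce time (seconds) for n bytes across P processes,
    per the latency/bandwidth terms in the comparison table."""
    return {
        "naive": 2 * (P - 1) * alpha + 2 * (P - 1) * n * beta,
        "tree": 2 * math.log2(P) * alpha + 2 * math.log2(P) * n * beta,
        "ring": 2 * (P - 1) * alpha + 2 * (P - 1) / P * n * beta,
        "recursive_hd": 2 * math.log2(P) * alpha + 2 * (P - 1) / P * n * beta,
    }

big = allreduce_times(1e9, 64)    # 1 GB of gradients: bandwidth-bound
small = allreduce_times(4e3, 64)  # 4 KB message: latency-bound
print({k: f"{v * 1e3:.2f} ms" for k, v in big.items()})
print({k: f"{v * 1e6:.1f} us" for k, v in small.items()})
```

For the 1 GB message, ring and recursive halving-doubling come out roughly 60× faster than naive; for the 4 KB message, the log-depth algorithms win on latency.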
NCCL's Adaptive Strategy

NCCL automatically selects the best algorithm based on message size, number of processes, and hardware topology. For large gradient tensors (MBs to GBs) typical in deep learning, ring-based algorithms dominate. For small control messages or when P is large relative to message size, tree or recursive halving-doubling may be chosen.

3.6 Hierarchical AllReduce

Modern clusters have non-uniform communication costs—intra-node (NVLink) is much faster than inter-node (InfiniBand). Hierarchical AllReduce exploits this by performing local reductions within nodes first, then a global AllReduce across nodes.

Hierarchical AllReduce: Exploiting Network Topology
[Figure: three 4-GPU nodes joined internally by NVLink (900 GB/s) and between nodes by InfiniBand (50 GB/s per link). Step 1 — local AllReduce within each node over NVLink: 2(GPUs/node − 1)α + 2n·β_NVLink. Step 2 — global AllReduce over InfiniBand among one GPU per node: 2(Nodes − 1)α + 2n·β_IB. Step 3 — local broadcast within each node over NVLink: log₂(GPUs/node)·α + n·β_NVLink.]
Hierarchical AllReduce minimizes slow inter-node communication by doing local reductions first. Only one representative per node participates in the global AllReduce.
Why Hierarchical Wins

Consider 8 nodes × 8 GPUs = 64 GPUs total. A flat ring AllReduce takes 2×63 steps, each paced by the slowest link it crosses. Hierarchical: 2×7 fast NVLink steps inside each node, 2×7 slow IB steps for the global AllReduce, plus a local broadcast. The slow IB links see only 2×7 = 14 steps instead of sitting on the critical path of ~126, dramatically reducing the impact of the slower interconnect.

4. AllGather, ReduceScatter, and All-to-All

Beyond AllReduce, three other "all" collective operations are essential for distributed training. Each has distinct communication patterns and use cases in different parallelism strategies.

4.1 AllGather

AllGather collects data from all processes and distributes the concatenated result to everyone. If each process starts with $n$ bytes, every process ends with $n \times P$ bytes containing all the data.

AllGather Operation
[Figure: before — rank i holds Aᵢ (n bytes); after AllGather — every rank holds the identical concatenation [A₀, A₁, A₂, A₃] (n×P bytes).]
Python (PyTorch) allgather_example.py
import torch
import torch.distributed as dist

# Each rank has a local tensor (e.g., model shard)
local_tensor = torch.randn(1000, device='cuda')  # n = 1000 elements

# Prepare output list for gathered tensors
world_size = dist.get_world_size()
gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]

# AllGather: collect all shards
dist.all_gather(gathered, local_tensor)

# Result: list of P tensors, each 1000 elements
full_tensor = torch.cat(gathered)  # Shape: [P * 1000]

# Alternative: all_gather_into_tensor (more efficient)
output = torch.zeros(world_size * 1000, device='cuda')
dist.all_gather_into_tensor(output, local_tensor)  # Direct into contiguous buffer
AllGather Use Cases

FSDP/ZeRO-3: Gather sharded model weights before forward pass
Tensor Parallelism: Gather partial activations after column-parallel linear layers
Sequence Parallelism: Reconstruct full sequence after parallel attention

Ring AllGather Implementation

Like Ring AllReduce, Ring AllGather achieves bandwidth-optimal performance by pipelining communication around a ring. Each step, processes send their chunk to the next neighbor and receive from the previous neighbor.

Ring AllGather (P-1 Steps)
[Figure: ring AllGather — each rank starts with its own chunk Aᵢ and forwards chunks around the ring; after P-1 steps every rank holds [A₀, A₁, A₂, A₃]. Complexity: T = (P-1)α + (P-1)nβ for per-rank chunks of n bytes.]

4.2 ReduceScatter

ReduceScatter is the inverse of AllGather: it reduces data from all processes and distributes different chunks of the result to each process. If each process starts with $n \times P$ bytes, every process ends with a different $n$-byte chunk of the reduced result.

ReduceScatter Operation (Sum)
[Figure: before — rank i holds a full vector (aᵢ, bᵢ, cᵢ, dᵢ) of n×P bytes; after ReduceScatter(sum) — rank 0 holds Σaᵢ, rank 1 holds Σbᵢ, rank 2 holds Σcᵢ, rank 3 holds Σdᵢ (n bytes each, all different). Key: AllReduce = ReduceScatter + AllGather — the most efficient decomposition.]
Python (PyTorch) reduce_scatter_example.py
import torch
import torch.distributed as dist

# Each rank has gradients for ALL parameters (before sharding)
full_gradients = torch.randn(4000, device='cuda')  # n*P elements

# Output: each rank gets 1/P of the reduced gradients
world_size = dist.get_world_size()
output_size = full_gradients.size(0) // world_size
reduced_shard = torch.zeros(output_size, device='cuda')

# ReduceScatter: sum gradients across ranks, distribute shards
dist.reduce_scatter_tensor(reduced_shard, full_gradients, op=dist.ReduceOp.SUM)

# Rank 0 gets sum of elements [0:1000] from all ranks
# Rank 1 gets sum of elements [1000:2000] from all ranks
# etc.
ReduceScatter Use Cases

FSDP/ZeRO-2+3: After backward pass, reduce gradients and scatter to shard owners
Ring AllReduce: First phase—produces the partially reduced chunks
Tensor Parallelism: After row-parallel linear, reduce-scatter activations

4.3 All-to-All (AllToAll)

All-to-All is the most general collective: each process sends different data to every other process. It's essentially a distributed transpose operation. Process $i$ sends chunk $j$ to process $j$, and receives chunk $i$ from process $j$.

All-to-All Operation (Distributed Transpose)
[Figure: before — process i holds row i of a chunk grid (Aᵢ₀ Aᵢ₁ Aᵢ₂ Aᵢ₃); after All-to-All — process j holds column j (A₀ⱼ A₁ⱼ A₂ⱼ A₃ⱼ). Each process sends chunk j to process j and receives chunk i from process i: a distributed transpose.]
Python (PyTorch) alltoall_example.py
import torch
import torch.distributed as dist

# Each rank has P chunks to send (one to each rank)
world_size = dist.get_world_size()
rank = dist.get_rank()

# Input: chunks destined for each rank
send_chunks = [torch.randn(100, device='cuda') for _ in range(world_size)]

# Output: chunks received from each rank
recv_chunks = [torch.zeros(100, device='cuda') for _ in range(world_size)]

# All-to-All: exchange chunks
dist.all_to_all(recv_chunks, send_chunks)

# recv_chunks[j] now contains what rank j sent to us (their chunk for us)

# More efficient: all_to_all_single for contiguous tensors
input_tensor = torch.randn(world_size * 100, device='cuda')  # [P * chunk_size]
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor, input_tensor)
All-to-All Use Cases

Expert Parallelism (MoE): Route tokens to experts across GPUs—each token goes to its assigned expert, and All-to-All redistributes them
Sequence Parallelism: Transpose between sequence-sharded and head-sharded layouts in attention
FFT/Convolutions: Distributed FFT requires All-to-All for global transpose
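The "distributed transpose" view is easy to check with a toy single-process sketch (each "rank" is a row of chunks):

```python
P = 4
# send[i][j]: the chunk rank i sends to rank j.
send = [[f"A{i}{j}" for j in range(P)] for i in range(P)]

# All-to-All: rank j receives chunk j from every rank i.
recv = [[send[i][j] for i in range(P)] for j in range(P)]

# The result is exactly the transpose of the chunk grid.
assert recv == [list(col) for col in zip(*send)]
print(recv[1])  # ['A01', 'A11', 'A21', 'A31']
```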

4.4 Collective Operations Complexity Summary

Here's a comprehensive table summarizing the time complexity of all collective operations using optimal (bandwidth-efficient) algorithms:

Here, $n$ denotes the total data size (the larger of a rank's input/output buffers) and $P$ the number of processes:

| Operation | Latency Term | Bandwidth Term | Input → Output (per rank) | Algorithm |
|---|---|---|---|---|
| Broadcast | $\log_2(P)\alpha$ | $n\beta$ | n (root) → n (all) | Binomial tree |
| Reduce | $\log_2(P)\alpha$ | $n\beta$ | n (each) → n (root) | Binomial tree |
| Scatter | $\log_2(P)\alpha$ | $\frac{(P-1)n}{P}\beta$ | n (root) → n/P (each) | Binomial tree |
| Gather | $\log_2(P)\alpha$ | $\frac{(P-1)n}{P}\beta$ | n/P (each) → n (root) | Binomial tree |
| AllReduce | $2\log_2(P)\alpha$ | $\frac{2(P-1)n}{P}\beta$ | n (each) → n (all, same) | Recursive HD / Ring |
| AllGather | $\log_2(P)\alpha$ | $\frac{(P-1)n}{P}\beta$ | n/P (each) → n (all) | Ring / Recursive doubling |
| ReduceScatter | $\log_2(P)\alpha$ | $\frac{(P-1)n}{P}\beta$ | n (each) → n/P (each, different) | Ring / Recursive halving |
| All-to-All | $(P-1)\alpha$ | $\frac{(P-1)n}{P}\beta$ | n (each) → n (each, shuffled) | Direct / Pairwise |
Understanding the Bandwidth Terms

The bandwidth term represents how much data flows through each link in the optimal case. Notice these key relationships:

AllReduce bandwidth = 2 × ReduceScatter bandwidth (because AllReduce = ReduceScatter + AllGather)
AllGather and ReduceScatter both move $\frac{(P-1)n}{P} \approx n$ bytes per rank: each rank's $n/P$ shard must reach the other $P-1$ ranks
All-to-All has the worst latency term, $(P-1)\alpha$, because every process must exchange a distinct message with every other process

Bandwidth Optimality: Why Ring/Recursive Algorithms Win
[Figure: naive AllReduce funnels (P-1)×n bytes through the root — O(Pn) bandwidth that does not scale — while ring AllReduce sends n/P per link per step, for O(n) per link regardless of P. Scaling comparison for gradient size n on P GPUs: naive 2(P-1)nβ grows with P; ring ≈ 2nβ, constant.]
Ring and recursive halving-doubling algorithms achieve bandwidth optimality by distributing communication load evenly. Each link carries the same amount of data regardless of cluster size.

4.5 Choosing the Right Collective

Different distributed training strategies use different collective operations:

| Training Strategy | Primary Collectives | Communication Pattern |
|---|---|---|
| Data Parallel (DDP) | AllReduce | Synchronize gradients after backward |
| FSDP/ZeRO-3 | AllGather, ReduceScatter | Gather weights before forward, scatter gradients after backward |
| Tensor Parallel | AllReduce, AllGather | Reduce activations after parallel matmuls |
| Pipeline Parallel | Point-to-Point (Send/Recv) | Pass activations between pipeline stages |
| Expert Parallel (MoE) | All-to-All | Route tokens to/from experts |
| Sequence Parallel | AllGather, ReduceScatter, All-to-All | Redistribute sequence chunks for attention |
Communication Volume Matters

When analyzing distributed training efficiency, always consider the total communication volume per iteration. For a model with $N$ parameters:

DDP: 2N bytes (AllReduce gradients once)
FSDP: 3N bytes per layer × num_layers (AllGather forward + AllGather backward + ReduceScatter gradients)
Tensor Parallel: Proportional to activation size, not model size

Understanding these patterns is key to optimizing distributed training configurations!
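As a rough sketch of the first two bullet points (assuming FP16 gradients; the function names are illustrative):

```python
def ddp_comm_bytes(n_params, dtype_bytes=2):
    # One ring AllReduce over the gradients moves ~2x the buffer size per GPU.
    return 2 * n_params * dtype_bytes

def fsdp_comm_bytes(n_params, dtype_bytes=2):
    # AllGather (forward) + AllGather (backward) + ReduceScatter (grads),
    # summed over all layers, is ~3x the parameter bytes.
    return 3 * n_params * dtype_bytes

print(ddp_comm_bytes(7e9) / 1e9)   # a 7B model in FP16: 28.0 GB per iteration
print(fsdp_comm_bytes(7e9) / 1e9)  # 42.0 GB per iteration
```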

5. The α-β Performance Model

To analyze and optimize collective communication, we need a principled model. The α-β model (also called the postal model or Hockney model) decomposes communication time into two fundamental components: a fixed startup cost and a data-dependent transfer cost.

5.1 Model Fundamentals

The time to send a message of $n$ bytes between two processes is:

$$T(n) = \alpha + n \cdot \beta$$

Where:

α (alpha) is the latency: the fixed startup cost of a message (software overhead, NIC processing, propagation delay), independent of message size.
β (beta) is the inverse bandwidth: the transfer time per byte, so the $n\beta$ term grows linearly with message size.

α-β Model: Communication Time Breakdown
[Figure] T(n) = α + nβ: small messages are latency-bound (α dominates), large messages are bandwidth-bound (nβ dominates; slope = β = 1/bandwidth). Typical latencies: NVLink α ≈ 1-5 μs, InfiniBand α ≈ 1-2 μs, Ethernet α ≈ 50-100 μs.
Converting β to Bandwidth

If a link has bandwidth $B = 450$ GB/s (like NVLink 4.0):

$\beta = \frac{1}{B} = \frac{1}{450 \times 10^9 \text{ B/s}} = 2.22 \times 10^{-12} \text{ s/byte} = 2.22 \text{ ps/byte}$

For a 1 GB message: $T = \alpha + 1 \times 10^9 \times 2.22 \times 10^{-12} = \alpha + 0.00222\text{s} \approx \alpha + 2.2\text{ms}$
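The point-to-point formula is a one-liner, reproducing the NVLink example above:

```python
def p2p_time(n_bytes, alpha_s, bandwidth_bytes_per_s):
    """alpha-beta point-to-point time: fixed startup cost plus per-byte transfer."""
    return alpha_s + n_bytes / bandwidth_bytes_per_s

# 1 GB over NVLink 4.0 (450 GB/s) with alpha = 1 us
print(p2p_time(1e9, 1e-6, 450e9))  # ~0.00222 s, i.e. alpha + 2.2 ms
```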

5.2 Extending to Collective Operations

For collective operations involving $P$ processes, the model extends to account for multiple communication steps and the algorithm's structure:

$$T_{\text{collective}}(n, P) = k(P) \cdot \alpha + m(n, P) \cdot \beta$$

Where:

$k(P)$ is the number of communication steps (each paying the latency α)
$m(n, P)$ is the number of bytes each process sends along the critical path (each paying β)

Both depend on the algorithm used, not just on the operation.

α-β Model for Common Collectives (n = per-process message size)

Operation | Latency term | Bandwidth term | Total time
Broadcast | ⌈log₂P⌉ · α | n · β | ⌈log₂P⌉α + nβ
Reduce | ⌈log₂P⌉ · α | n · β | ⌈log₂P⌉α + nβ
AllReduce | 2⌈log₂P⌉ · α | 2(P-1)/P · n · β | 2⌈log₂P⌉α + 2(P-1)/P · nβ
AllGather | ⌈log₂P⌉ · α | (P-1) · n · β | ⌈log₂P⌉α + (P-1)nβ
ReduceScatter | ⌈log₂P⌉ · α | (P-1)/P · n · β | ⌈log₂P⌉α + (P-1)/P · nβ
All-to-All | (P-1) · α | (P-1)/P · n · β | (P-1)α + (P-1)/P · nβ

Key insight: ring and recursive halving-doubling algorithms achieve the bandwidth-optimal lower bound of ≈ n(P-1)/P bytes per link.
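The table rows translate directly into small cost estimators (a sketch under the same assumptions: power-of-two P, no congestion):

```python
import math

def allreduce_time(n, P, alpha, beta):
    return 2 * math.ceil(math.log2(P)) * alpha + 2 * (P - 1) / P * n * beta

def allgather_time(n, P, alpha, beta):
    return math.ceil(math.log2(P)) * alpha + (P - 1) * n * beta

def reduce_scatter_time(n, P, alpha, beta):
    return math.ceil(math.log2(P)) * alpha + (P - 1) / P * n * beta

# Sanity check: AllReduce = ReduceScatter + AllGather of the n/P shards
n, P, a, b = 1e8, 16, 2e-6, 1 / 200e9
combined = reduce_scatter_time(n, P, a, b) + allgather_time(n / P, P, a, b)
print(abs(allreduce_time(n, P, a, b) - combined) < 1e-9)  # True
```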

5.3 The Crossover Point

A critical question: when does latency matter vs. bandwidth? The crossover point $n^*$ is where both terms contribute equally:

$$n^* = \frac{\alpha}{\beta} = \alpha \cdot B$$

For messages smaller than $n^*$, latency dominates. For larger messages, bandwidth dominates.

Crossover Points for Different Interconnects
[Figure] T(n) on log-log axes for each interconnect, with the crossover $n^*$ marked where the latency and bandwidth terms are equal. With the typical values used below: NVLink 4.0 $n^*$ ≈ 450 KB, InfiniBand HDR $n^*$ ≈ 300 KB, TCP Ethernet $n^*$ ≈ 625 KB.
Because $n^* = \alpha \cdot B$, bandwidth and latency offset each other: NVLink's enormous bandwidth is paired with microsecond latency, while Ethernet's ~50 μs software latency compensates for its low bandwidth, so the crossover points all land within an order of magnitude of each other. Below $n^*$ you are latency-bound (many small messages); above it, bandwidth-bound (few large messages).
Python calculate_crossover.py
def calculate_crossover(alpha_us, bandwidth_gbps):
    """
    Calculate the crossover point where latency = bandwidth term.
    
    Args:
        alpha_us: Latency in microseconds
        bandwidth_gbps: Bandwidth in GB/s
    
    Returns:
        Crossover point in bytes
    """
    alpha_s = alpha_us * 1e-6  # Convert to seconds
    beta = 1 / (bandwidth_gbps * 1e9)  # seconds per byte
    n_star = alpha_s / beta  # crossover in bytes
    return n_star

# Example calculations
interconnects = {
    "NVLink 4.0": (1.0, 450),      # 1 μs latency, 450 GB/s
    "InfiniBand HDR": (1.5, 200),  # 1.5 μs, 200 GB/s
    "100GbE RoCE": (5.0, 12.5),    # 5 μs, 12.5 GB/s
    "TCP/IP Ethernet": (50, 12.5), # 50 μs (with kernel), 12.5 GB/s
}

print("Interconnect Crossover Points:")
print("-" * 45)
for name, (alpha, bw) in interconnects.items():
    n_star = calculate_crossover(alpha, bw)
    if n_star < 1024:
        print(f"{name:20}: n* = {n_star:8.0f} B")
    elif n_star < 1024**2:
        print(f"{name:20}: n* = {n_star/1024:8.1f} KB")
    else:
        print(f"{name:20}: n* = {n_star/(1024**2):8.1f} MB")

# Output:
# Interconnect Crossover Points:
# ---------------------------------------------
# NVLink 4.0          : n* =    439.5 KB
# InfiniBand HDR      : n* =    293.0 KB
# 100GbE RoCE         : n* =     61.0 KB
# TCP/IP Ethernet     : n* =    610.4 KB

5.4 Practical Implications for Distributed Training

The α-β model has profound implications for how we design distributed training systems:

Optimization Strategies Based on Message Size
Latency-bound (n < n*) strategies: batch small messages together; use async operations to hide latency; reduce the number of collectives (fuse ops); use tree-based algorithms (low latency).

Bandwidth-bound (n > n*) strategies: use ring/bandwidth-optimal algorithms; compress gradients (reduces n); pipeline communication with compute; use FP16/BF16 (halves the bytes sent).

Real-world training scenarios:
ResNet-50 (25M params): gradients ~100 MB, typical layers 1-10 MB. A borderline case; latency matters for the small layers.
GPT-3 (175B params): gradients ~350 GB in FP16, per-layer 100 MB to 10 GB. Clearly bandwidth-bound; ring algorithms are essential.
Tensor-parallel attention: activations of size batch × seq × hidden, e.g. 32 × 4096 × 12288 × 2 B ≈ 3 GB per AllReduce. Bandwidth is everything.

5.5 Extended α-β-γ Model

For operations that involve computation (like Reduce operations), we can extend the model to include a computation term:

$$T(n, P) = k(P) \cdot \alpha + m(n, P) \cdot \beta + c(n, P) \cdot \gamma$$

Where γ (gamma) is the time to perform one computational operation (e.g., one floating-point addition). For modern GPUs, γ is typically negligible: a reduction performs just one add per received element, and GPU arithmetic throughput exceeds network bandwidth by orders of magnitude, so the β term dominates.

When γ Matters

The computation term γ becomes significant in these cases:

CPU-based training: CPU arithmetic is much slower than network bandwidth
Custom reduction operations: Complex user-defined reductions
Compression/decompression: Gradient compression trades γ for reduced β
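A back-of-the-envelope check shows why γ vanishes on GPUs but not CPUs (hedged: the throughput figures below are assumed round numbers for illustration, not measured specs):

```python
def reduce_compute_fraction(n_bytes, bandwidth_bytes_per_s, adds_per_s):
    """Fraction of transfer time spent on reduction arithmetic (one FP32 add per element)."""
    t_compute = (n_bytes / 4) / adds_per_s        # n/4 FP32 elements to add
    t_transfer = n_bytes / bandwidth_bytes_per_s
    return t_compute / t_transfer

# GPU: ~30 TFLOP/s of FP32 adds vs. a 450 GB/s NVLink
print(reduce_compute_fraction(1e9, 450e9, 30e12))  # ~0.4% of transfer time: gamma is noise
# CPU: ~100 GFLOP/s vs. a 12.5 GB/s NIC
print(reduce_compute_fraction(1e9, 12.5e9, 100e9))  # ~3%: starting to matter
```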

5.6 Measuring α and β in Practice

To optimize your distributed training, you need to know the actual α and β values for your hardware. Here's how to measure them:

Python (PyTorch) measure_alpha_beta.py
import torch
import torch.distributed as dist
import time
import numpy as np

def measure_allreduce_alpha_beta(sizes_bytes, num_warmup=5, num_iters=20):
    """
    Measure α and β for AllReduce by fitting T = α + n*β to timing data.
    
    Run with: torchrun --nproc_per_node=N measure_alpha_beta.py
    """
    times = []
    
    for size in sizes_bytes:
        numel = size // 4  # FP32 elements
        tensor = torch.randn(numel, device='cuda')
        
        # Warmup
        for _ in range(num_warmup):
            dist.all_reduce(tensor)
            torch.cuda.synchronize()
        
        # Timed runs
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            dist.all_reduce(tensor)
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / num_iters
        
        times.append(elapsed)
    
    # Linear regression: T = α + n*β
    # For AllReduce: T = 2*log2(P)*α + 2*(P-1)/P * n * β
    sizes = np.array(sizes_bytes)
    times = np.array(times)
    
    # Simple linear fit (ignoring the exact formula for illustration)
    # slope = β, intercept ≈ α
    coeffs = np.polyfit(sizes, times, 1)
    beta = coeffs[0]
    alpha = coeffs[1]
    
    bandwidth_gbps = 1 / beta / 1e9
    
    return alpha, beta, bandwidth_gbps

if __name__ == "__main__":
    dist.init_process_group()
    
    # Test various message sizes from 1KB to 1GB
    sizes = [1024 * (2 ** i) for i in range(20)]  # 1 KB to 512 MB
    
    alpha, beta, bw = measure_allreduce_alpha_beta(sizes)
    
    if dist.get_rank() == 0:
        print(f"Measured Parameters:")
        print(f"  α (latency)   = {alpha*1e6:.2f} μs")
        print(f"  β (inv. bw)   = {beta*1e12:.2f} ps/byte")
        print(f"  Bandwidth     = {bw:.1f} GB/s")
        print(f"  Crossover n*  = {alpha/beta/1024:.1f} KB")
    
    dist.destroy_process_group()
Measured vs Theoretical Performance
[Figure] Measured AllReduce times plotted against the theoretical T = α + nβ line, for message sizes from 1 KB to 100 MB.
Real measurements typically follow the theoretical α-β model closely for large messages, with some variance for small messages due to system jitter and scheduling overhead.
Hidden Overheads

The simple α-β model doesn't capture everything. Real systems have additional overheads:

GPU kernel launch: ~10-50 μs per collective call
Memory allocation: First call may allocate buffers
Synchronization: Barrier costs at collective boundaries
Congestion: Shared network links reduce effective bandwidth

Always measure on your actual hardware configuration!

6. The Communication Roofline Model

The classic Roofline Model relates computational performance to memory bandwidth through arithmetic intensity. We can adapt this powerful framework to analyze distributed training by creating a Communication Roofline that relates training throughput to network bandwidth through the computation-to-communication ratio.

6.1 From Compute Roofline to Communication Roofline

Recall the traditional roofline model: performance is bounded by either peak compute or memory bandwidth, depending on arithmetic intensity (FLOPS/byte). For distributed training, we create an analogous model:

Communication Roofline: Adapting the Classic Model
[Figure] The classic roofline (TFLOPS vs. arithmetic intensity, capped by memory bandwidth and peak compute) maps onto a communication roofline as follows:

Aspect | Classic Roofline | Communication Roofline
Y-axis | Compute throughput (FLOPS) | Training throughput (samples/s)
X-axis | Arithmetic intensity (FLOPS/byte) | Computation-to-communication ratio (FLOPS/byte)
Ceiling | Peak FLOPS or memory bandwidth | Peak GPU throughput or network bandwidth

6.2 Defining the Computation-to-Communication Ratio

For distributed training, we define the computation-to-communication ratio (CCR) as the ratio of computation performed to bytes communicated:

$$\text{CCR} = \frac{\text{FLOPs per iteration}}{\text{Bytes communicated per iteration}}$$

For standard data-parallel training with gradient AllReduce:

$$\text{CCR}_{\text{DDP}} = \frac{\text{FLOPs}_{\text{forward}} + \text{FLOPs}_{\text{backward}}}{2 \cdot N \cdot \text{sizeof(dtype)}}$$

Where $N$ is the number of model parameters and the factor of 2 accounts for AllReduce sending/receiving approximately $2N$ bytes (ring algorithm).

Python calculate_ccr.py
def calculate_ccr_ddp(model_params, batch_size, seq_len, dtype_bytes=2):
    """
    Calculate computation-to-communication ratio for DDP training.
    
    For transformers: FLOPs ≈ 6 * N * tokens (forward + backward)
    Communication: 2 * N * dtype_bytes (AllReduce gradients)
    """
    tokens = batch_size * seq_len
    flops = 6 * model_params * tokens  # Approximate transformer FLOPs
    comm_bytes = 2 * model_params * dtype_bytes  # AllReduce volume
    
    ccr = flops / comm_bytes
    return ccr

# Example: GPT-3 175B with different batch sizes
model_params = 175e9
seq_len = 2048

print("GPT-3 175B CCR Analysis:")
print("-" * 50)
for batch_size in [1, 4, 16, 64, 256]:
    ccr = calculate_ccr_ddp(model_params, batch_size, seq_len)
    print(f"Batch size {batch_size:4}: CCR = {ccr:8.1f} FLOPS/byte")

# Output:
# GPT-3 175B CCR Analysis:
# --------------------------------------------------
# Batch size    1: CCR =   3072.0 FLOPS/byte
# Batch size    4: CCR =  12288.0 FLOPS/byte
# Batch size   16: CCR =  49152.0 FLOPS/byte
# Batch size   64: CCR = 196608.0 FLOPS/byte
# Batch size  256: CCR = 786432.0 FLOPS/byte

6.3 The Communication Roofline Diagram

Now we can construct the full communication roofline. The y-axis shows achieved throughput (e.g., samples/second), and we have two ceilings:

Communication Roofline for Distributed Training
[Figure] Log-log plot of training throughput vs. computation-to-communication ratio. A slanted network-bandwidth ceiling bounds the communication-bound regime (small batch, network bottleneck) and a flat GPU compute ceiling (Peak FLOPS / FLOPs per sample) bounds the compute-bound regime (large batch, GPU bottleneck). They meet at the ridge point CCR* = Peak_FLOPS / Network_BW, e.g. 312 TFLOPS / 900 GB/s ≈ 350 FLOPS/byte.
Computing the Ridge Point

The ridge point CCR* tells you the minimum computation-to-communication ratio needed to fully utilize your GPUs:

$$\text{CCR}^* = \frac{\text{Peak GPU FLOPS}}{\text{Aggregate Network Bandwidth}}$$

Example (8× H100 node):
• Peak FP16: 8 × 1979 TFLOPS = 15.8 PFLOPS = 15.8 × 10¹⁵ FLOPS
• NVLink bandwidth: 900 GB/s per GPU = 7.2 TB/s aggregate
• CCR* = 15.8 × 10¹⁵ / 7.2 × 10¹² ≈ 2,200 FLOPS/byte

You need at least CCR > 2,200 to be compute-bound on this hardware!
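Inverting the ridge condition gives the batch size at which DDP becomes compute-bound (a sketch using the CCR definition from section 6.2; `min_batch_for_ridge` is an illustrative helper):

```python
def min_batch_for_ridge(peak_flops, network_bytes_per_s, n_params, flops_per_sample, dtype_bytes=2):
    """Smallest batch whose CCR reaches the ridge CCR* = peak_flops / network bandwidth."""
    ccr_star = peak_flops / network_bytes_per_s
    comm_bytes = 2 * n_params * dtype_bytes   # DDP AllReduce volume per iteration
    flops_needed = ccr_star * comm_bytes      # total FLOPs required to reach the ridge
    return flops_needed / flops_per_sample

# GPT-3 175B (seq 2048) on one H100 with NVLink 4.0
b = min_batch_for_ridge(1979e12, 900e9, 175e9, 6 * 175e9 * 2048)
print(b)  # < 1: DDP over NVLink is compute-bound even at batch size 1
```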

6.4 Multi-Level Communication Roofline

Modern GPU clusters have multiple levels of interconnect with vastly different bandwidths. This creates a staircase roofline with multiple ceilings:

Multi-Level Roofline: Intra-Node vs Inter-Node
[Figure] Staircase roofline with one bandwidth ceiling per network tier, NVLink (900 GB/s), InfiniBand HDR (200 GB/s), and 100GbE (12.5 GB/s), below the GPU compute ceiling. Each tier has its own CCR threshold.
Hierarchical clusters hit different ceilings depending on whether communication stays within a node (NVLink) or crosses nodes (IB/Ethernet). The effective ceiling depends on your parallelism strategy and which links are used.

6.5 Scaling Efficiency from the Roofline

The communication roofline directly tells us the scaling efficiency of distributed training. If you're communication-bound:

$$\text{Efficiency} = \frac{\text{Achieved Throughput}}{\text{Peak Throughput}} = \frac{\text{CCR} \cdot \text{BW}}{\text{Peak FLOPS}} = \frac{\text{CCR}}{\text{CCR}^*}$$

When $\text{CCR} < \text{CCR}^*$, efficiency scales linearly with batch size (since CCR ∝ batch size for fixed model). When $\text{CCR} > \text{CCR}^*$, you've hit the compute ceiling—efficiency is 100% and won't improve with larger batches.
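This piecewise behavior fits in two lines (a sketch: min(1, CCR/CCR*) ignores the soft transition around the ridge; the per-sample CCR is an illustrative number):

```python
def scaling_efficiency(ccr, ccr_star):
    # Linear in CCR below the ridge (communication-bound), saturated at 100% above it.
    return min(1.0, ccr / ccr_star)

# CCR grows linearly with batch size for a fixed model
ccr_star = 2200.0
for batch in (64, 256, 1024, 4096):
    eff = scaling_efficiency(batch * 3.0, ccr_star)  # assume 3 FLOPS/byte per sample
    print(f"batch {batch:5d}: {eff:6.1%}")
```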

Scaling Efficiency vs Batch Size
[Figure] Scaling efficiency vs. global batch size: efficiency grows in proportion to batch size in the communication-bound regime, then saturates near 100% past the critical batch size (≈1K for this configuration).

6.6 Using the Roofline for Optimization

The communication roofline provides a systematic framework for optimization decisions:

Current Position | Diagnosis | Optimization Strategy
Far below comm ceiling (CCR ≪ CCR*) | Severely communication-bound | Increase batch size; compress gradients; better interconnect; reduce communication volume
On comm ceiling (CCR < CCR*) | Communication-bound, scaling linearly | Increase batch size (best ROI); overlap comm with compute; use hierarchy-aware algorithms
Near ridge (CCR ≈ CCR*) | Transitioning to compute-bound | Small batch increases still help; focus shifts to compute optimizations; best scaling sweet spot
On compute ceiling (CCR > CCR*) | Compute-bound (ideal!) | GPU kernel optimizations; mixed precision training; operator fusion; better batch/GPU utilization
Python roofline_analysis.py
def analyze_roofline_position(
    model_flops_per_sample: float,
    model_params: float,
    batch_size: int,
    gpu_peak_tflops: float,
    network_bw_gbps: float,
    dtype_bytes: int = 2,
    comm_factor: float = 2.0,  # AllReduce ≈ 2x model size
):
    """
    Determine where a workload sits on the communication roofline.
    
    Returns analysis with optimization recommendations.
    """
    # Calculate CCR
    total_flops = model_flops_per_sample * batch_size
    comm_bytes = comm_factor * model_params * dtype_bytes
    ccr = total_flops / comm_bytes
    
    # Calculate ridge point
    ccr_star = (gpu_peak_tflops * 1e12) / (network_bw_gbps * 1e9)
    
    # Theoretical throughput limits
    compute_limit = (gpu_peak_tflops * 1e12) / model_flops_per_sample  # samples/s
    comm_limit = (network_bw_gbps * 1e9) / (comm_bytes / batch_size)  # samples/s
    
    # Effective throughput (min of both limits)
    effective_throughput = min(compute_limit, comm_limit)
    efficiency = effective_throughput / compute_limit * 100
    
    # Determine regime
    if ccr < ccr_star * 0.2:
        regime = "SEVERELY_COMM_BOUND"
        recommendation = "Increase batch size by 4x+ or use gradient compression"
    elif ccr < ccr_star * 0.8:
        regime = "COMM_BOUND"
        recommendation = "Increase batch size or overlap communication"
    elif ccr < ccr_star * 1.2:
        regime = "NEAR_RIDGE"
        recommendation = "Near optimal! Small batch increase may help"
    else:
        regime = "COMPUTE_BOUND"
        recommendation = "Excellent! Focus on GPU optimizations"
    
    return {
        "ccr": ccr,
        "ccr_star": ccr_star,
        "regime": regime,
        "efficiency": efficiency,
        "recommendation": recommendation,
        "critical_batch": batch_size * (ccr_star / ccr),  # Batch to reach ridge
    }

# Example analysis
result = analyze_roofline_position(
    model_flops_per_sample=6 * 175e9 * 2048,  # GPT-3 175B, seq=2048
    model_params=175e9,
    batch_size=32,
    gpu_peak_tflops=1979,  # H100 FP16
    network_bw_gbps=900,  # NVLink-4
)

print(f"CCR: {result['ccr']:.0f} FLOPS/byte")
print(f"Ridge CCR*: {result['ccr_star']:.0f} FLOPS/byte")
print(f"Regime: {result['regime']}")
print(f"Efficiency: {result['efficiency']:.1f}%")
print(f"Recommendation: {result['recommendation']}")

6.7 Practical Roofline Numbers

Here are typical ridge point values for common GPU configurations:

Configuration | Peak TFLOPS (FP16) | Network BW | CCR* (Ridge) | Min Batch for 90% Eff
Single H100 (NVSwitch) | 1,979 | 900 GB/s | ~2,200 | ~16-32 (model dependent)
8× H100 DGX (intra-node) | 15,832 | 7.2 TB/s | ~2,200 | ~128-256
Multi-node H100 (IB NDR) | 15,832 × N | 400 GB/s/node | ~40,000 | ~2,000+
8× A100 (intra-node) | 2,496 | 4.8 TB/s | ~520 | ~64-128
Cloud GPUs (Ethernet) | varies | 25-100 Gb/s | ~200,000+ | very large
The Cloud Training Challenge

Notice the dramatic difference in CCR* between on-prem DGX clusters (CCR* ≈ 2,000) and cloud instances with Ethernet (CCR* ≈ 200,000). This is why:

• Cloud training often requires 10-100× larger batch sizes for efficiency
• Gradient compression and async methods are essential for cloud
• Many teams use cloud for development but on-prem for production training

The communication roofline makes this tradeoff quantitatively clear!

7. Data Parallelism: The Foundation

Data Parallelism (DP) is the simplest and most widely used distributed training strategy. Each GPU holds a complete copy of the model, processes different data samples, and synchronizes gradients after the backward pass. Let's explore its mechanics, its optimizations, and PyTorch's sophisticated DDP implementation.

7.1 Basic Data Parallelism

Data Parallel Training Flow
[Figure] Four stages: (1) data distribution: the global batch of B samples is split into B/P samples per GPU; (2) parallel forward/backward: each GPU runs its own model copy, producing local gradients ∇Wᵢ; (3) gradient synchronization: AllReduce computes ∇W = (∇W₀ + ∇W₁ + ... + ∇W_{P-1})/P, leaving the same ∇W on every GPU; (4) identical weight updates: W ← W − η·∇W on every GPU, so the models stay synchronized.

The key insight: since all GPUs apply the same averaged gradient to identical model copies, the models remain synchronized without explicit parameter broadcasting.

Python (PyTorch) basic_ddp.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(rank, world_size):
    """Initialize distributed process group."""
    dist.init_process_group(
        backend="nccl",  # NVIDIA Collective Communication Library
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size, model, dataloader, optimizer, criterion):
    setup_ddp(rank, world_size)
    
    # Wrap model with DDP
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
    
    for batch in dataloader:
        inputs, labels = batch
        inputs, labels = inputs.to(rank), labels.to(rank)
        
        optimizer.zero_grad()
        
        # Forward pass (independent per GPU)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass (DDP hooks trigger AllReduce)
        loss.backward()  # ← Gradients synchronized here!
        
        # Optimizer step (identical on all GPUs)
        optimizer.step()
    
    dist.destroy_process_group()

# Launch with: torchrun --nproc_per_node=8 basic_ddp.py

7.2 PyTorch DDP: Under the Hood

PyTorch's DistributedDataParallel is highly optimized. It doesn't wait for the full backward pass to complete—instead, it overlaps gradient communication with computation using gradient bucketing.

DDP Gradient Bucketing and Overlap
[Figure] Timeline comparison. Naive: run the full backward pass, then one AllReduce over all gradients (total time T₁). DDP: gradients are grouped into buckets (e.g., layers 1-10, 11-20, 21-30); each bucket's AllReduce launches as soon as all of its gradients are computed, overlapping with the remaining backward compute (total time T₂ < T₁, often 20-40% faster).
DDP groups gradients into buckets (default 25MB each). As soon as a bucket is full (all its gradients computed), the AllReduce for that bucket starts immediately while the backward pass continues computing other gradients.

7.3 DDP Configuration Options

Python (PyTorch) ddp_advanced.py
from torch.nn.parallel import DistributedDataParallel as DDP

# Advanced DDP configuration
model = DDP(
    model,
    device_ids=[local_rank],
    
    # Bucket size (bytes) - smaller = more overlap, more latency overhead
    bucket_cap_mb=25,  # Default: 25 MB per bucket
    
    # Find unused parameters (for models with conditional execution)
    find_unused_parameters=False,  # Set True if some params don't get gradients
    
    # Gradient as bucket view (memory optimization)
    gradient_as_bucket_view=True,  # Reduces memory by 1 copy
    
    # Static graph optimization (PyTorch 1.11+)
    static_graph=True,  # Enable if computation graph doesn't change
)

# Tune bucket size based on your model and network
# - Smaller buckets: Better overlap, more latency overhead
# - Larger buckets: Less overhead, less overlap opportunity
# - Rule of thumb: bucket_size > α/β (crossover point)

7.4 Communication Volume Analysis

Let's analyze the communication requirements for DDP precisely:

$$\text{Comm Volume per Iteration} = 2 \cdot N \cdot \text{sizeof(dtype)}$$

The factor of 2 comes from the ring AllReduce algorithm: each gradient element is sent once during ReduceScatter and once during AllGather (approximately).

Model | Parameters (N) | FP32 Comm | FP16/BF16 Comm | FP16 Time @ 200 GB/s
ResNet-50 | 25M | 200 MB | 100 MB | ~0.5 ms
BERT-Large | 340M | 2.7 GB | 1.4 GB | ~7 ms
GPT-2 XL | 1.5B | 12 GB | 6 GB | ~30 ms
Llama-7B | 7B | 56 GB | 28 GB | ~140 ms
Llama-70B | 70B | 560 GB | 280 GB | ~1.4 s
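Every row follows from the same formula (a sketch; 200 GB/s is taken as the effective AllReduce bandwidth and only the bandwidth term is counted):

```python
def ddp_allreduce_seconds(n_params, dtype_bytes, bandwidth_bytes_per_s):
    """Ring AllReduce time for one gradient sync, bandwidth term only."""
    return 2 * n_params * dtype_bytes / bandwidth_bytes_per_s

for name, n in [("ResNet-50", 25e6), ("Llama-7B", 7e9), ("Llama-70B", 70e9)]:
    ms = ddp_allreduce_seconds(n, 2, 200e9) * 1e3   # FP16 gradients
    print(f"{name}: {ms:.1f} ms")
# Matches the table: 0.5 ms, 140.0 ms, 1400.0 ms
```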
DDP Memory Requirement

DDP requires each GPU to hold:

Model parameters: N × sizeof(dtype)
Gradients: N × sizeof(dtype)
Optimizer states: 2N (Adam: m, v) or more

Total for FP32 Adam: ~16N bytes per GPU
For mixed precision: ~18-20N bytes (FP16 params + FP32 master weights + FP32 optimizer)

A 7B model needs ~140 GB per GPU for full DDP training—exceeds a single 80GB H100!
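The box above can be checked numerically (a sketch; the 16N and 18-20N byte counts are the hedged per-parameter figures from the text, excluding activations):

```python
def ddp_memory_gb(n_params, mixed_precision=True):
    """Approximate (low, high) per-GPU training state for Adam under DDP, in GB."""
    if mixed_precision:
        # fp16 params (2) + fp16/fp32 grads (2-4) + fp32 master (4) + Adam m, v (8)
        return (18 * n_params / 1e9, 20 * n_params / 1e9)
    # fp32 params (4) + fp32 grads (4) + Adam m, v (8) = 16 bytes/param
    return (16 * n_params / 1e9, 16 * n_params / 1e9)

print(ddp_memory_gb(7e9))  # (126.0, 140.0): a 7B model blows past an 80 GB H100
```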

7.5 Gradient Compression

For bandwidth-limited scenarios, gradient compression can reduce communication volume at the cost of some compute overhead:

Gradient Compression Techniques
Quantization: FP32 → FP16 (2× reduction), INT8 (4×), INT4 (8×). Simple, low overhead.
Sparsification: Top-K (keep the K largest gradients), Random-K (random subset), or thresholding (send only |g| > τ). Needs error feedback to converge.
Low-rank (PowerSGD): approximate the gradient matrix as G ≈ P·Qᵀ with P ∈ ℝ^(m×r), Q ∈ ℝ^(n×r); 10-100× compression is possible at a higher compute cost.

Method | Compression | Compute Cost | Convergence
FP16 quantization | 2× | very low | minimal impact
Top-1% sparsification | 100× | medium (sorting) | ~5-10% more iterations
Python (PyTorch) gradient_compression.py
import torch
import torch.distributed as dist

class TopKCompressor:
    """Top-K gradient sparsification with error feedback."""
    
    def __init__(self, ratio=0.01):
        self.ratio = ratio
        self.error_feedback = {}  # Accumulated residuals
    
    def compress(self, name, grad):
        # Add error feedback from previous iteration
        if name in self.error_feedback:
            grad = grad + self.error_feedback[name]
        
        # Select Top-K elements
        k = max(1, int(grad.numel() * self.ratio))
        values, indices = torch.topk(grad.abs().flatten(), k)
        
        # Store residual for next iteration
        mask = torch.zeros_like(grad.flatten())
        mask[indices] = 1
        self.error_feedback[name] = grad.flatten() * (1 - mask)
        self.error_feedback[name] = self.error_feedback[name].view_as(grad)
        
        # Return sparse representation
        return indices, grad.flatten()[indices]
    
    def decompress(self, indices, values, shape):
        grad = torch.zeros(shape, device=values.device).flatten()
        grad[indices] = values
        return grad.view(shape)

# Usage in training loop
compressor = TopKCompressor(ratio=0.01)  # Keep top 1%

for name, param in model.named_parameters():
    if param.grad is not None:
        idx, vals = compressor.compress(name, param.grad)
        
        # AllGather compressed gradients (much smaller!)
        all_idx = [torch.zeros_like(idx) for _ in range(world_size)]
        all_vals = [torch.zeros_like(vals) for _ in range(world_size)]
        dist.all_gather(all_idx, idx)
        dist.all_gather(all_vals, vals)
        
        # Decompress and average
        param.grad.zero_()
        for i, v in zip(all_idx, all_vals):
            param.grad.flatten()[i] += v / world_size

7.6 DDP Best Practices

DDP Optimization Checklist

1. Use NCCL backend for GPU training (fastest for NVIDIA GPUs)

2. Pin memory in DataLoader:
DataLoader(..., pin_memory=True, num_workers=4)

3. Use DistributedSampler:
sampler = DistributedSampler(dataset, shuffle=True)

4. Tune bucket size based on model and interconnect

5. Enable gradient_as_bucket_view (saves memory)

6. Use static_graph=True if computation graph is fixed

7. Consider mixed precision: Halves communication volume!

7.7 When DDP Isn't Enough

DDP has a fundamental limitation: each GPU must hold the entire model. This becomes impossible for large models:

DDP Memory Wall
[Figure] DDP memory per GPU grows linearly with model size, crossing the 40 GB (A100) ceiling around ~5B parameters and the 80 GB (H100) ceiling around ~10B, depending on optimizer state and precision. Beyond that, the solution is FSDP/ZeRO: shard the model state across GPUs.

For models larger than ~5-10B parameters (depending on GPU memory), we need model parallelism strategies that distribute the model itself across GPUs. The next section covers FSDP/ZeRO, which elegantly extends DDP with parameter sharding.

8. FSDP and ZeRO: Sharded Data Parallelism

ZeRO (Zero Redundancy Optimizer) from DeepSpeed and FSDP (Fully Sharded Data Parallel) from PyTorch solve DDP's memory problem by sharding model state across GPUs rather than replicating everything. This trades communication for memory, enabling training of models that don't fit on a single GPU.

8.1 The Redundancy Problem

In standard DDP, every GPU holds an identical copy of the parameters, gradients, and optimizer states:

Memory Redundancy in DDP vs ZeRO
[Figure] DDP replicates parameters, gradients, and optimizer state on every GPU (P× total redundancy), while ZeRO-3/FSDP gives GPU i only the shards Pᵢ, Gᵢ, Oᵢ, so the cluster holds exactly one copy of the full model state. For a 7B model (FP16 + FP32 optimizer): DDP needs ~140 GB per GPU (impossible on one device); ZeRO-3 needs ~18 GB per GPU.

8.2 ZeRO Stages Explained

ZeRO introduces three progressive stages of sharding, each reducing memory further at the cost of more communication:

ZeRO Optimization Stages
Stage | What's Sharded | Memory per GPU | Communication Volume
DDP (baseline) | nothing | full params + grads + optimizer state | 2N (AllReduce)
ZeRO-1 | optimizer states | params + grads + optimizer/P | 2N (same as DDP)
ZeRO-2 | optimizer + gradients | params + (grads + optimizer)/P | 2N (ReduceScatter + AllGather)
ZeRO-3 / FSDP | optimizer + gradients + parameters | (params + grads + optimizer)/P | ~3N per iteration

Memory breakdown for a 7B model with mixed precision and P = 8 (FP16 params 14 GB, FP16 grads 14 GB, optimizer state 112 GB: FP32 master weights, Adam m and v, plus an FP32 gradient copy, ≈16 bytes/param):

DDP: 14 + 14 + 112 = 140 GB
ZeRO-1: 14 + 14 + 112/8 = 42 GB
ZeRO-2: 14 + 14/8 + 112/8 ≈ 30 GB
ZeRO-3: (14 + 14 + 112)/8 ≈ 18 GB
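The stage-by-stage savings can be sketched as follows (assumption: mixed-precision Adam with ≈16 bytes/param of optimizer-side state, matching the 7B example above; `zero_memory_per_gpu_gb` is an illustrative helper):

```python
def zero_memory_per_gpu_gb(n_params, P, stage):
    """Per-GPU training-state memory (GB) for ZeRO stages 0 (DDP) through 3 (FSDP)."""
    p = 2 * n_params / 1e9    # FP16 parameters
    g = 2 * n_params / 1e9    # FP16 gradients
    o = 16 * n_params / 1e9   # FP32 master weights + Adam m, v (+ FP32 grad copy)
    if stage == 0:
        return p + g + o
    if stage == 1:
        return p + g + o / P
    if stage == 2:
        return p + g / P + o / P
    return (p + g + o) / P    # stage 3: everything sharded

for s in range(4):
    print(f"ZeRO-{s}: {zero_memory_per_gpu_gb(7e9, 8, s):.1f} GB")
# 140.0, 42.0, 29.8, 17.5: matching the breakdown above
```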

8.3 FSDP/ZeRO-3 Execution Flow

ZeRO-3/FSDP requires careful orchestration of parameter gathering before computation and gradient scattering after:

FSDP Forward + Backward Flow
[Figure] Forward pass: for each layer i, AllGather Wᵢ just-in-time, compute the forward, then free Wᵢ to reclaim memory. Backward pass (reverse order): AllGather Wᵢ again, compute the backward, then immediately ReduceScatter ∇Wᵢ so each GPU retains only its gradient shard. Communication per layer: one AllGather in the forward, plus one AllGather and one ReduceScatter in the backward, ≈3× the layer's parameter size in total.

8.4 Communication Cost Analysis

FSDP trades memory for communication. Let's analyze the cost:

$$\text{FSDP Comm per Iteration} = \sum_{\text{layers}} \left( \underbrace{W_i}_{\text{fwd gather}} + \underbrace{W_i}_{\text{bwd gather}} + \underbrace{W_i}_{\text{grad scatter}} \right) = 3N$$

Compared to DDP's 2N, FSDP moves 1.5× the data. However, the memory savings usually outweigh this cost, and the extra communication can be overlapped with compute.

Strategy | Memory per GPU | Communication | Best For
DDP | ~18N bytes | 2N bytes | small models that fit in GPU memory
ZeRO-1 | ~6N + 12N/P bytes | 2N bytes (same) | medium models; free memory savings
ZeRO-2 | ~4N + 14N/P bytes | 2N bytes (same) | larger models needing gradient sharding
ZeRO-3/FSDP | ~18N/P bytes | 3N bytes (+50%) | very large models; memory-constrained
Python (PyTorch) fsdp_example.py
import functools

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Define sharding strategy
sharding_strategy = ShardingStrategy.FULL_SHARD  # ZeRO-3 equivalent
# Options: FULL_SHARD, SHARD_GRAD_OP (ZeRO-2), NO_SHARD (DDP)

# Mixed precision policy
mixed_precision = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# Auto-wrap policy for transformers
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},  # Your layer class
)

# Wrap model with FSDP
model = FSDP(
    model,
    sharding_strategy=sharding_strategy,
    mixed_precision=mixed_precision,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
    
    # Performance optimizations
    use_orig_params=True,  # Better compatibility with optimizers
    limit_all_gathers=True,  # Prevent OOM from too many concurrent gathers
)

# Training loop (same as regular PyTorch!)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).loss
    loss.backward()
    optimizer.step()

8.5 Prefetching and Overlap

Modern FSDP implementations hide communication latency by prefetching the next layer's parameters while computing the current layer:

FSDP Prefetching: Overlapping Communication
[Figure] Without prefetching, each layer alternates serially: gather W₁, compute L₁, gather W₂, compute L₂, and so on (total T₁). With prefetching, the gather of Wᵢ₊₁ runs concurrently with layer i's compute, hiding communication behind computation (total T₂ < T₁).
Prefetching starts gathering the next layer's parameters while the current layer is still computing. With sufficient compute time per layer, communication can be fully hidden.
Python (PyTorch) fsdp_prefetch.py
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    BackwardPrefetch,
)

# Configure prefetching
model = FSDP(
    model,
    # Forward prefetching (gather next layer during current layer's compute)
    forward_prefetch=True,
    
    # Backward prefetching strategy
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # or BACKWARD_POST
    # BACKWARD_PRE: Prefetch before backward (default, best for most cases)
    # BACKWARD_POST: Prefetch after backward (use if memory-tight)
)

# For maximum overlap, also enable:
# - async_op=True in NCCL operations
# - CUDA streams for concurrent execution

8.6 Hybrid Sharding (HSDP)

For multi-node training, Hybrid Sharded Data Parallel (HSDP) combines FSDP within nodes and DDP across nodes. This leverages the fast intra-node NVLink for sharding while minimizing slow inter-node communication:

[Figure: Hybrid Sharding — each node forms one FSDP group: GPUs 0-3 on node 0 and GPUs 4-7 on node 1 each hold shards 0-3, connected by fast NVLink; matching shards on different nodes are AllReduced over the slower InfiniBand link.]
HSDP communication pattern — intra-node: FSDP (AllGather/ReduceScatter); inter-node: DDP (AllReduce over matching shards).
When to Use Each ZeRO Stage

ZeRO-1: Model fits in memory with DDP, but optimizer states are too large. Free memory savings, no extra communication.

ZeRO-2: Model fits, but gradients + optimizer don't. Small models with large batch sizes. Still no extra communication!

ZeRO-3/FSDP: Model itself doesn't fit. Essential for 10B+ parameter models. Accept 1.5× communication overhead.

HSDP: Multi-node training where inter-node bandwidth is limited. Minimizes cross-node traffic while enabling large models.
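The memory arithmetic behind this guide fits in a few lines. A sketch assuming the standard 16-bytes-per-parameter mixed-precision Adam breakdown (2 bytes fp16 params + 2 bytes fp16 grads + 12 bytes fp32 master params, momentum, and variance); the function name and stage labels are illustrative:

```python
def per_gpu_memory_gb(num_params: float, num_gpus: int, stage: str) -> float:
    """Approximate per-GPU training memory in GB for mixed-precision Adam."""
    N, P = num_params, num_gpus
    params, grads, optim = 2 * N, 2 * N, 12 * N  # bytes
    if stage == "ddp":
        total = params + grads + optim            # everything replicated
    elif stage == "zero1":
        total = params + grads + optim / P        # shard optimizer states only
    elif stage == "zero2":
        total = params + grads / P + optim / P    # shard grads + optimizer states
    elif stage == "zero3":
        total = (params + grads + optim) / P      # shard everything (FSDP FULL_SHARD)
    else:
        raise ValueError(stage)
    return total / 1e9

# A 7B model on 8 GPUs
for stage in ("ddp", "zero1", "zero2", "zero3"):
    print(f"{stage:>5}: {per_gpu_memory_gb(7e9, 8, stage):.1f} GB/GPU")
```

With these assumptions a 7B model needs 112 GB/GPU under plain DDP but only 14 GB/GPU under ZeRO-3, which is exactly why sharding is essential beyond a few billion parameters.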

9. Tensor Parallelism: Sharding Within Layers

While FSDP shards parameters between layers (gather before compute, scatter after), Tensor Parallelism (TP) shards parameters within individual layers. Each GPU computes a portion of every matrix multiplication, enabling parallelism at the finest granularity. This technique was pioneered by Megatron-LM for training massive transformer models.

9.1 The Core Idea: Matrix Partitioning

Consider a linear layer Y = XW where X is [B×K] and W is [K×N]. We can partition this computation across GPUs in two fundamental ways:

[Figure: Two ways to partition Y = XW across 2 GPUs.]
  • Column parallel (split output dim): each GPU holds Wᵢ [K×N/2] and computes Yᵢ = XWᵢ [B×N/2]. No communication before the matmul; results live on separate GPUs.
  • Row parallel (split input dim): each GPU holds Xᵢ [B×K/2] and Wᵢ [K/2×N] and computes a partial Yᵢ' [B×N]; an AllReduce sums the partials, Y = Σ Yᵢ'.
Key insight: chaining a column-parallel layer into a row-parallel layer cancels the intermediate AllGather and AllReduce. X (replicated) → column-parallel W₁ → local GeLU → row-parallel W₂ → AllReduce: only ONE AllReduce per MLP block, not two.
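The equivalence of the two partitionings is easy to verify numerically. A stdlib-only sketch with 2-way splits; the `matmul` helper and the tiny matrices are illustrative:

```python
def matmul(A, B):
    """Plain nested-list matrix multiply."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]                     # [B=2, K=4]
W = [[1, 0, 2, 0], [0, 1, 0, 2], [1, 1, 0, 0], [0, 0, 1, 1]]  # [K=4, N=4]

# Column parallel: split W by output columns; each "GPU" produces a slice of Y
W0 = [row[:2] for row in W]
W1 = [row[2:] for row in W]
Y_col = [r0 + r1 for r0, r1 in zip(matmul(X, W0), matmul(X, W1))]  # concat slices

# Row parallel: split X by columns and W by rows; partial results are SUMMED
# (this sum is what the AllReduce performs)
X0, X1 = [row[:2] for row in X], [row[2:] for row in X]
W_top, W_bot = W[:2], W[2:]
Y0, Y1 = matmul(X0, W_top), matmul(X1, W_bot)
Y_row = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(Y0, Y1)]

assert Y_col == Y_row == matmul(X, W)  # both partitionings recover the full product
```

The column-parallel path needs no reduction (outputs are disjoint slices), while the row-parallel path needs exactly one sum, which is why the column-then-row chaining costs a single AllReduce.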

9.2 Megatron-LM Self-Attention Parallelism

In transformer self-attention, we can parallelize across the attention heads. Each GPU computes a subset of heads independently, then results are combined:

[Figure: Tensor-parallel self-attention with TP=2 — X [B, S, H] is replicated on all GPUs; column-parallel QKV projections give each GPU half the heads (Q₀K₀V₀ on GPU 0, Q₁K₁V₁ on GPU 1); attention Aᵢ = softmax(QᵢKᵢᵀ/√d)Vᵢ is computed locally with no communication; the row-parallel output projection produces partials Oᵢ' = Aᵢ × W_o slices, and a single AllReduce sums O = O₀' + O₁' into the output [B, S, H].]
Self-attention with tensor parallelism: only ONE AllReduce per attention block. QKV projections are column-parallel; the output projection is row-parallel.

9.3 Megatron-LM MLP Parallelism

The feedforward network (MLP) in transformers follows the same pattern:

$$\text{MLP}(X) = \text{GeLU}(X W_1) W_2$$
[Figure: Tensor-parallel MLP with TP=2 — W₁ is column-parallel (no input communication), GeLU is applied locally on each shard, W₂ is row-parallel with one AllReduce at the end: Y = Σ Yᵢ.]

9.4 Communication Analysis

Tensor parallelism has very different communication characteristics from data parallelism:

Aspect | Data Parallel (DDP) | Tensor Parallel (TP)
Communication pattern | AllReduce gradients (backward only) | AllReduce activations (forward + backward)
Frequency | Once per iteration | 2× per transformer layer (Attn + MLP)
Message size | All gradients (~N parameters) | Activation tensor (B × S × H)
Bandwidth need | Can overlap with backward | Critical path (compute blocked!)
Latency sensitivity | Low (large messages) | High (many small messages)
TP Requires High-Bandwidth Links

Since tensor parallelism has communication in the critical path of every forward and backward pass, it requires extremely high bandwidth between GPUs. This is why:

  • TP within a node: Use NVLink (600-900 GB/s on H100)
  • TP across nodes: Generally avoided! IB is too slow.
  • Typical TP degree: 2, 4, or 8 (matching GPUs per node)
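A rough back-of-envelope makes the point concrete. The effective bandwidth figures below (~450 GB/s NVLink, ~25 GB/s InfiniBand per GPU) are assumptions for illustration only:

```python
def allreduce_time_ms(message_bytes: float, bandwidth_gb_s: float, tp: int) -> float:
    """Time for one ring AllReduce, ignoring latency terms."""
    # Ring AllReduce moves 2(T-1)/T of the message over each link
    volume = 2 * (tp - 1) / tp * message_bytes
    return volume / (bandwidth_gb_s * 1e9) * 1e3

# One fp16 activation tensor for S=2048, H=12288 (GPT-3-sized layer, B=1)
act_bytes = 2048 * 12288 * 2

nvlink_ms = allreduce_time_ms(act_bytes, 450, tp=8)  # intra-node TP
ib_ms = allreduce_time_ms(act_bytes, 25, tp=8)       # hypothetical cross-node TP
print(f"NVLink: {nvlink_ms:.2f} ms  |  IB: {ib_ms:.2f} ms")
```

With four such AllReduces per layer on the critical path, the ~18× gap between the two links compounds across every layer of every step, which is why TP stays inside the node.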

9.5 Communication Volume per Layer

For a transformer layer with tensor parallelism degree T:

$$\text{TP Comm per Layer} = 4 \times \frac{2(T-1)}{T} \times B \times S \times H$$

Where the factor of 4 comes from: 2 AllReduces (attention + MLP) × 2 for forward and backward. The $\frac{2(T-1)}{T}$ term is the AllReduce cost in the ring algorithm.

Python tp_communication.py
def tensor_parallel_comm_per_layer(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    tp_degree: int,
    dtype_bytes: int = 2  # FP16
) -> float:
    """Calculate communication volume per transformer layer with TP."""
    
    # Activation size
    activation_size = batch_size * seq_len * hidden_dim * dtype_bytes
    
    # AllReduce cost (ring algorithm)
    allreduce_factor = 2 * (tp_degree - 1) / tp_degree
    
    # 4 AllReduces per layer (attn fwd, attn bwd, mlp fwd, mlp bwd)
    comm_volume = 4 * allreduce_factor * activation_size
    
    return comm_volume

# Example: GPT-3 175B style layer
comm_bytes = tensor_parallel_comm_per_layer(
    batch_size=1,
    seq_len=2048,
    hidden_dim=12288,
    tp_degree=8
)
print(f"Comm per layer: {comm_bytes / 1e6:.1f} MB")
# Output: ~352 MB per layer

# With 96 layers: ~34 GB total TP communication per iteration!

9.6 Sequence Parallelism

A refinement to tensor parallelism is Sequence Parallelism (SP), which distributes the LayerNorm and Dropout operations across the sequence dimension. This reduces activation memory further:

[Figure: Sequence parallelism — with standard TP, LayerNorm, Dropout, and residuals are replicated, so each GPU holds the full B×S×H activation; with TP + SP these ops are sharded along the sequence dimension (B×S×H/TP per GPU), with an AllGather (AG) entering and a ReduceScatter (RS) leaving each tensor-parallel region.]
Sequence parallelism replaces each AllReduce with an AllGather→compute→ReduceScatter pair, but keeps activations sharded across the sequence dimension in LayerNorm/Dropout.

9.7 Implementation: Megatron-Style TP

Python (PyTorch) megatron_tp.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Linear layer with column-parallel weight distribution.
    
    Input X is replicated, weight W is split along columns (output dim).
    Output Y is split along columns (each GPU has different output features).
    """
    
    def __init__(self, in_features, out_features, tp_group, gather_output=False):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.gather_output = gather_output
        
        # Each GPU gets out_features/tp_size columns
        assert out_features % self.tp_size == 0
        self.out_features_per_partition = out_features // self.tp_size
        
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_partition, in_features)
        )
        self.bias = nn.Parameter(torch.empty(self.out_features_per_partition))
        
    def forward(self, x):
        # Local matmul: no communication!
        output = F.linear(x, self.weight, self.bias)
        
        if self.gather_output:
            # Optionally gather to get full output
            output = _gather_along_last_dim(output, self.tp_group)
        
        return output


class RowParallelLinear(nn.Module):
    """Linear layer with row-parallel weight distribution.
    
    Input X is split along last dim, weight W is split along rows (input dim).
    Outputs are partial sums that need AllReduce.
    """
    
    def __init__(self, in_features, out_features, tp_group, input_is_parallel=True):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.input_is_parallel = input_is_parallel
        
        # Each GPU gets in_features/tp_size rows
        assert in_features % self.tp_size == 0
        self.in_features_per_partition = in_features // self.tp_size
        
        self.weight = nn.Parameter(
            torch.empty(out_features, self.in_features_per_partition)
        )
        self.bias = nn.Parameter(torch.empty(out_features))
        
    def forward(self, x):
        if not self.input_is_parallel:
            # Split input if it's replicated
            x = _split_along_last_dim(x, self.tp_group)
        
        # Local matmul (partial result)
        output_partial = F.linear(x, self.weight)
        
        # AllReduce to sum partial results
        dist.all_reduce(output_partial, group=self.tp_group)
        
        # Bias is replicated, so add it AFTER the reduction on every rank
        # (adding it before the AllReduce would sum it tp_size times)
        output = output_partial + self.bias
        
        return output


class TensorParallelMLP(nn.Module):
    """Megatron-style tensor parallel MLP.
    
    Column-parallel first layer, row-parallel second layer.
    Only ONE AllReduce for the entire MLP!
    """
    
    def __init__(self, hidden_size, ffn_hidden_size, tp_group):
        super().__init__()
        
        # Column-parallel: split FFN hidden dim
        self.fc1 = ColumnParallelLinear(
            hidden_size, ffn_hidden_size, tp_group, gather_output=False
        )
        
        # Row-parallel: recombine to hidden_size
        self.fc2 = RowParallelLinear(
            ffn_hidden_size, hidden_size, tp_group, input_is_parallel=True
        )
        
    def forward(self, x):
        # x: [B, S, H] - replicated across TP ranks
        x = self.fc1(x)      # [B, S, FFN/TP] - split output
        x = F.gelu(x)        # Local activation
        x = self.fc2(x)      # [B, S, H] - AllReduce inside
        return x

9.8 When to Use Tensor Parallelism

TP Decision Guide

Use Tensor Parallelism when:

  • Single layers are too large for GPU memory (very wide models)
  • You have fast intra-node interconnect (NVLink, NVSwitch)
  • Model has many attention heads (divisible by TP degree)
  • Latency per step matters more than throughput

Avoid Tensor Parallelism when:

  • Training across nodes without high-speed interconnect
  • Small models where FSDP/DDP suffice
  • TP degree doesn't divide attention heads evenly

Best Practices:

  • TP=8 for 8-GPU nodes (H100, A100 DGX)
  • Combine with Pipeline Parallelism across nodes
  • Use Sequence Parallelism to reduce activation memory
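When planning a run, valid TP degrees are those that both divide the attention heads and fit within a node. A small hypothetical helper (the function name is not from any library):

```python
def valid_tp_degrees(num_heads: int, gpus_per_node: int) -> list:
    """TP degrees that evenly divide the node's GPUs and the attention heads."""
    return [t for t in range(1, gpus_per_node + 1)
            if gpus_per_node % t == 0 and num_heads % t == 0]

# GPT-3-style: 96 heads on 8-GPU nodes -> every power-of-two degree works
print(valid_tp_degrees(96, 8))   # [1, 2, 4, 8]

# A model with 14 heads cannot use TP=4 or TP=8 on the same node
print(valid_tp_degrees(14, 8))   # [1, 2]
```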

10. Pipeline Parallelism: Partitioning Across Layers

Pipeline Parallelism (PP) partitions the model vertically by assigning different layers to different GPUs. Unlike tensor parallelism (which splits within layers), pipeline parallelism keeps each layer intact but distributes them across devices. This approach is ideal for spanning multiple nodes with slower inter-node connections.

10.1 The Naive Approach and the Bubble Problem

The simplest pipeline would process one micro-batch at a time through all stages. But this creates massive pipeline bubbles where most GPUs sit idle:

[Figure: Naive pipeline with 4 stages (GPU 0: L1-L6 … GPU 3: L19-L24) — a single micro-batch's forward then backward sweeps through the stages one at a time, leaving every other GPU idle. Bubble fraction = (P-1)/P = 75% idle time with 4 stages.]

With P pipeline stages and 1 micro-batch, each GPU is active for only 1/P of the time (one forward and one backward slot out of 2P total). The rest is bubble time:

$$\text{Bubble Fraction} = \frac{P - 1}{P} \xrightarrow{P \to \infty} 100\%$$

10.2 GPipe: Micro-Batching to Reduce Bubbles

GPipe (Google, 2019) introduced micro-batching: split each mini-batch into M smaller micro-batches and pipeline them through the stages. This fills the pipeline and dramatically reduces bubble fraction:

[Figure: GPipe schedule with 4 stages and 4 micro-batches — all forwards F₀…F₃ flow through the stages first, then all backwards B₀…B₃. Bubble = (P-1)/(M+P-1) ≈ 43% with P=4, M=4. Problem: ALL M activation checkpoints must be stored until the backward pass.]

GPipe reduces bubbles but has a critical problem: activations from all M micro-batches must be stored until the backward pass begins. Memory scales as O(M × activations_per_microbatch).

10.3 1F1B: Interleaving Forward and Backward

The 1F1B (One Forward One Backward) schedule from PipeDream solves the memory problem. After the warmup phase, each stage alternates between one forward and one backward, maintaining a constant number of in-flight micro-batches:

[Figure: 1F1B schedule with 4 stages and 8 micro-batches — after a warmup of forward-only steps, each stage alternates one forward with one backward (steady state), then drains remaining backwards in a cooldown phase. Memory advantage: GPipe keeps M micro-batches of activations in flight, 1F1B at most P. Bubble = (P-1)/(M+P-1), the same as GPipe, but with constant memory.]
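The memory contrast can be captured in a tiny model of in-flight micro-batches (a sketch; real schedulers differ in bookkeeping details):

```python
def peak_inflight_microbatches(schedule: str, num_stages: int, num_micro: int,
                               stage_id: int = 0) -> int:
    """Peak number of micro-batch activation sets a stage must hold."""
    if schedule == "gpipe":
        # All forwards complete before any backward starts
        return num_micro
    if schedule == "1f1b":
        # Warmup depth (num_stages - stage_id - 1) plus the current micro-batch
        return min(num_micro, num_stages - stage_id)
    raise ValueError(schedule)

print(peak_inflight_microbatches("gpipe", 4, 32))  # 32
print(peak_inflight_microbatches("1f1b", 4, 32))   # 4
```

With M = 32 micro-batches, GPipe's first stage holds 32 activation sets while 1F1B holds at most 4, independent of M.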

10.4 Interleaved 1F1B (Virtual Stages)

Interleaved 1F1B further reduces bubbles by assigning multiple non-consecutive "virtual stages" to each GPU. With V virtual stages per GPU:

$$\text{Bubble Fraction} = \frac{P - 1}{M \cdot V + P - 1}$$
[Figure: Interleaved stages with V=2 virtual stages per GPU — with 4 GPUs and 8 virtual stages, GPU 0 holds L1-L3 and L22-L24, GPU 1 holds L4-L6 and L19-L21, GPU 2 holds L7-L9 and L16-L18, GPU 3 holds L10-L12 and L13-L15; the forward pass visits GPU 0 → 1 → 2 → 3 → 3 → 2 → 1 → 0. With V=2 this roughly halves bubble time and also reduces point-to-point message size.]
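The bubble formulas for the naive, GPipe/1F1B, and interleaved schedules collapse into one expression (V=1, M=1 gives the naive schedule; V=1 gives GPipe/1F1B; V>1 interleaved):

```python
def bubble_fraction(num_stages: int, num_micro: int, virtual: int = 1) -> float:
    """Pipeline bubble fraction: (P-1) / (M*V + P-1)."""
    p, m, v = num_stages, num_micro, virtual
    return (p - 1) / (m * v + p - 1)

print(f"Naive       P=4, M=1:       {bubble_fraction(4, 1):.0%}")      # 75%
print(f"GPipe/1F1B  P=4, M=4:       {bubble_fraction(4, 4):.0%}")      # 43%
print(f"1F1B        P=4, M=32:      {bubble_fraction(4, 32):.0%}")     # 9%
print(f"Interleaved P=4, M=32, V=2: {bubble_fraction(4, 32, 2):.0%}")  # 4%
```

The fix for bubbles is always the same: raise the denominator, either with more micro-batches (M) or more virtual stages (V).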

10.5 Communication Patterns in Pipeline Parallelism

Unlike data parallelism and tensor parallelism, which are both built on AllReduce collectives, pipeline parallelism uses point-to-point (P2P) communication:

Communication | Direction | Data | Size
Forward activations | Stage i → Stage i+1 | Hidden states | B × S × H bytes
Backward gradients | Stage i+1 → Stage i | Activation gradients | B × S × H bytes
[Figure: Pipeline P2P communication — Stage 0 (layers 0-5) → Stage 1 (layers 6-11) → Stage 2 (layers 12-17) → Stage 3 (layers 18-23), with activations flowing forward and gradients flowing back. These are Send/Recv operations, not collectives, so they can use slower inter-node links effectively.]

10.6 Pipeline Communication Volume

Total PP communication per micro-batch per stage boundary:

$$\text{PP Comm} = 2 \times B \times S \times H \times \text{dtype\_bytes}$$

(Factor of 2: one forward send, one backward send)

Python pp_communication.py
def pipeline_comm_per_iteration(
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    num_micro_batches: int,
    num_stages: int,
    dtype_bytes: int = 2
) -> dict:
    """Calculate pipeline parallel communication."""
    
    micro_batch_size = batch_size // num_micro_batches
    
    # Activation size per micro-batch
    act_size = micro_batch_size * seq_len * hidden_dim * dtype_bytes
    
    # P2P comm per stage boundary
    # Forward: send activations
    # Backward: send gradients
    p2p_per_boundary = 2 * act_size  # fwd + bwd
    
    # Total boundaries = num_stages - 1
    # Total micro-batches = num_micro_batches
    total_p2p = p2p_per_boundary * (num_stages - 1) * num_micro_batches
    
    # Bubble analysis
    bubble_fraction = (num_stages - 1) / (num_micro_batches + num_stages - 1)
    
    return {
        "p2p_per_microbatch_boundary": p2p_per_boundary,
        "total_p2p_bytes": total_p2p,
        "bubble_fraction": bubble_fraction,
        "efficiency": 1 - bubble_fraction
    }

# Example: 8 stages, 32 micro-batches
result = pipeline_comm_per_iteration(
    batch_size=512,
    seq_len=2048,
    hidden_dim=8192,
    num_micro_batches=32,
    num_stages=8
)
print(f"Efficiency: {result['efficiency']*100:.1f}%")  # ~82%
print(f"Total P2P: {result['total_p2p_bytes']/1e9:.1f} GB")

10.7 Implementation with PyTorch

Python (PyTorch) pipeline_stage.py
import torch
import torch.distributed as dist

class PipelineStage:
    """Simple 1F1B pipeline stage implementation."""
    
    def __init__(self, module, stage_id, num_stages, process_group):
        self.module = module
        self.stage_id = stage_id
        self.num_stages = num_stages
        self.group = process_group
        
        # Determine neighbors
        self.prev_rank = stage_id - 1 if stage_id > 0 else None
        self.next_rank = stage_id + 1 if stage_id < num_stages - 1 else None
        
    def recv_forward(self, shape, dtype):
        """Receive activations from previous stage."""
        if self.prev_rank is None:
            return None
        
        tensor = torch.empty(shape, dtype=dtype, device='cuda')
        dist.recv(tensor, src=self.prev_rank, group=self.group)
        return tensor.requires_grad_()
    
    def send_forward(self, tensor):
        """Send activations to next stage."""
        if self.next_rank is None:
            return
        dist.send(tensor, dst=self.next_rank, group=self.group)
    
    def recv_backward(self, shape, dtype):
        """Receive gradients from next stage."""
        if self.next_rank is None:
            return None
        
        tensor = torch.empty(shape, dtype=dtype, device='cuda')
        dist.recv(tensor, src=self.next_rank, group=self.group)
        return tensor
    
    def send_backward(self, tensor):
        """Send gradients to previous stage."""
        if self.prev_rank is None:
            return
        dist.send(tensor, dst=self.prev_rank, group=self.group)

    def forward_step(self, input_tensor):
        """Execute one forward micro-batch."""
        output = self.module(input_tensor)
        self.send_forward(output.detach())
        return output
    
    def backward_step(self, input_tensor, output_tensor, output_grad):
        """Execute one backward micro-batch and propagate grads upstream."""
        # On the last stage, output_tensor is the scalar loss and output_grad is None
        torch.autograd.backward(output_tensor, grad_tensors=output_grad)
        
        # Send the gradient w.r.t. this stage's INPUT to the previous stage
        # (recv_forward marked the input requires_grad, so .grad is populated)
        if input_tensor is not None and input_tensor.grad is not None:
            self.send_backward(input_tensor.grad)


def schedule_1f1b(stage, num_micro_batches, input_tensors):
    """1F1B schedule for one pipeline stage."""
    
    num_warmup = stage.num_stages - stage.stage_id - 1
    num_microbatches_remaining = num_micro_batches - num_warmup
    
    # Store (input, output) pairs for the backward passes
    outputs = []
    
    # Warmup: only forward passes
    # (compare prev_rank against None explicitly: rank 0 is falsy!)
    for i in range(num_warmup):
        inp = stage.recv_forward(...) if stage.prev_rank is not None else input_tensors[i]
        out = stage.forward_step(inp)
        outputs.append((inp, out))
    
    # Steady state: 1F1B
    for i in range(num_microbatches_remaining):
        # Forward
        inp = stage.recv_forward(...) if stage.prev_rank is not None else input_tensors[num_warmup + i]
        out = stage.forward_step(inp)
        outputs.append((inp, out))
        
        # Backward for the oldest in-flight micro-batch
        old_inp, old_out = outputs.pop(0)
        grad = stage.recv_backward(...) if stage.next_rank is not None else None
        stage.backward_step(old_inp, old_out, grad)
    
    # Cooldown: only backward passes
    for _ in range(num_warmup):
        old_inp, old_out = outputs.pop(0)
        grad = stage.recv_backward(...) if stage.next_rank is not None else None
        stage.backward_step(old_inp, old_out, grad)

10.8 When to Use Pipeline Parallelism

Pipeline Parallelism Decision Guide

Ideal for:

  • Multi-node training where inter-node bandwidth is limited
  • Very deep models (many layers to distribute)
  • When combined with TP within nodes (3D parallelism)
  • Large batch training where many micro-batches hide bubbles

Challenges:

  • Pipeline bubbles reduce efficiency (need M >> P)
  • Complex scheduling logic
  • Load balancing across stages (unequal layer costs)
  • Memory imbalance (first/last stages often need more)

Best Practices:

  • Use micro-batches M ≥ 4×P for reasonable efficiency
  • Consider interleaved schedules (V > 1) for lower bubbles
  • Place PP across nodes, TP within nodes
  • Balance layers to equalize compute per stage

11. 3D Parallelism and Mixture of Experts

Training the largest models (100B+ parameters) requires combining multiple parallelism strategies. 3D Parallelism combines Data, Tensor, and Pipeline parallelism to scale across thousands of GPUs. Mixture of Experts (MoE) adds another dimension by dynamically routing tokens to specialized sub-networks.

11.1 The 3D Parallelism Framework

Each parallelism strategy has different strengths and communication patterns. 3D parallelism assigns each dimension to the appropriate network topology:

3D Parallelism: Combining DP × TP × PP
[Figure: 3D parallelism grid — each GPU has a [dp, tp, pp] rank; total GPUs = TP × PP × DP = 8 × 4 × 4 = 128.]
  • Tensor Parallel (TP=8): within a single node (NVLink); AllReduce on activations (critical path); highest bandwidth requirement
  • Pipeline Parallel (PP=4): across nodes (InfiniBand); P2P send/recv (not collective); tolerates higher latency
  • Data Parallel (DP=4): across pipeline replicas; AllReduce on gradients (overlapped); can use the slowest network
Communication hierarchy (bandwidth required, decreasing): TP 600+ GB/s (NVLink) → PP 50-200 GB/s (IB) → DP 25-100 GB/s (IB).

11.2 Process Group Configuration

In 3D parallelism, GPUs are organized into process groups for each parallelism dimension. A GPU participates in three groups simultaneously:

Python (PyTorch) process_groups_3d.py
import torch.distributed as dist

def initialize_3d_parallelism(
    tp_size: int,
    pp_size: int, 
    dp_size: int
):
    """Initialize process groups for 3D parallelism.
    
    GPU layout (example: TP=2, PP=2, DP=2 = 8 GPUs):
    
    Node 0: GPU 0,1 (TP group 0, PP stage 0)
    Node 1: GPU 2,3 (TP group 1, PP stage 1)
    Node 2: GPU 4,5 (TP group 2, PP stage 0)  # DP replica
    Node 3: GPU 6,7 (TP group 3, PP stage 1)  # DP replica
    """
    
    world_size = tp_size * pp_size * dp_size
    rank = dist.get_rank()
    
    # Calculate coordinates in 3D grid
    tp_rank = rank % tp_size
    pp_rank = (rank // tp_size) % pp_size
    dp_rank = rank // (tp_size * pp_size)
    
    # Create Tensor Parallel groups
    # GPUs in same TP group: same pp_rank, same dp_rank
    tp_groups = []
    for dp in range(dp_size):
        for pp in range(pp_size):
            ranks = [
                dp * (tp_size * pp_size) + pp * tp_size + tp
                for tp in range(tp_size)
            ]
            group = dist.new_group(ranks)
            if rank in ranks:
                tp_group = group
    
    # Create Pipeline Parallel groups
    # GPUs in same PP group: same tp_rank, same dp_rank
    pp_groups = []
    for dp in range(dp_size):
        for tp in range(tp_size):
            ranks = [
                dp * (tp_size * pp_size) + pp * tp_size + tp
                for pp in range(pp_size)
            ]
            group = dist.new_group(ranks)
            if rank in ranks:
                pp_group = group
    
    # Create Data Parallel groups
    # GPUs in same DP group: same tp_rank, same pp_rank
    for pp in range(pp_size):
        for tp in range(tp_size):
            ranks = [
                dp * (tp_size * pp_size) + pp * tp_size + tp
                for dp in range(dp_size)
            ]
            group = dist.new_group(ranks)
            if rank in ranks:
                dp_group = group
    
    return {
        'tp_group': tp_group,
        'pp_group': pp_group,
        'dp_group': dp_group,
        'tp_rank': tp_rank,
        'pp_rank': pp_rank,
        'dp_rank': dp_rank,
    }

# Example: 128 GPUs with TP=8, PP=4, DP=4
groups = initialize_3d_parallelism(tp_size=8, pp_size=4, dp_size=4)
print(f"GPU {dist.get_rank()}: TP={groups['tp_rank']}, PP={groups['pp_rank']}, DP={groups['dp_rank']}")
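The rank arithmetic that `initialize_3d_parallelism` relies on can be checked offline with plain Python, no process groups required:

```python
def rank_to_coords(rank: int, tp_size: int, pp_size: int):
    """Decompose a global rank into (tp_rank, pp_rank, dp_rank)."""
    return (rank % tp_size,
            (rank // tp_size) % pp_size,
            rank // (tp_size * pp_size))

def coords_to_rank(tp: int, pp: int, dp: int, tp_size: int, pp_size: int) -> int:
    """Flatten 3D coordinates back to a global rank (TP varies fastest)."""
    return dp * (tp_size * pp_size) + pp * tp_size + tp

TP, PP, DP = 8, 4, 4  # the 128-GPU example above
for rank in range(TP * PP * DP):
    t, p, d = rank_to_coords(rank, TP, PP)
    assert coords_to_rank(t, p, d, TP, PP) == rank  # round-trips exactly
```

Because TP varies fastest, the 8 ranks of a TP group are contiguous and land on one node, while DP peers are maximally far apart, matching the bandwidth hierarchy above.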

11.3 Communication Analysis for 3D Parallelism

Let's analyze the total communication volume for one training iteration:

Dimension | Operation | Volume per Iteration | When
TP | AllReduce activations | 4 × L × B × S × H × 2(T-1)/T | Every layer, fwd+bwd
PP | P2P send/recv | 2 × M × (P-1) × B/M × S × H | Per micro-batch boundary
DP | AllReduce gradients | 2 × N/P × 2(D-1)/D | Once at end (overlapped)
Key Insight: Communication Overlap

In well-optimized 3D parallelism:

  • TP communication is on the critical path (cannot hide)
  • PP communication is partially hidden during micro-batch processing
  • DP communication can be fully overlapped with backward pass

This is why TP needs NVLink, PP can use IB, and DP can tolerate the slowest links.

11.4 Mixture of Experts (MoE)

Mixture of Experts is an orthogonal scaling technique that increases model capacity without proportionally increasing compute. Each token is routed to only a subset of "expert" sub-networks:

Mixture of Experts Layer
[Figure: Sparse MoE layer — input X [B×S, H] passes through a router network G(x) = softmax(x·Wg) with top-K routing (typically K=1 or 2); each token is dispatched to its selected experts FFNᵢ(x), one expert per GPU, and the output is the gate-weighted sum Y = Σᵢ gᵢ · Expertᵢ(x).]

11.5 Expert Parallelism Communication

MoE introduces a unique communication pattern: All-to-All to route tokens to their assigned experts, then All-to-All again to return results:

MoE All-to-All Communication
[Figure: All-to-All routing — tokens start on their source GPUs (colored by destination expert), an All-to-All moves each token to the GPU hosting its assigned expert, each GPU processes its expert's tokens, and a second All-to-All returns results to their original positions.]
MoE communication volume per layer: forward All-to-All dispatch + All-to-All combine = 2 × B × S × H × (E-1)/E; the backward pass repeats the pattern, for a total of 4 × B × S × H × (E-1)/E per MoE layer.
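As a sketch, the volume formula above translates directly into code (the Mixtral-like shapes in the example are for illustration):

```python
def moe_alltoall_bytes_per_layer(batch: int, seq_len: int, hidden: int,
                                 num_expert_ranks: int, dtype_bytes: int = 2) -> float:
    """Total MoE All-to-All traffic per layer: dispatch + combine, fwd + bwd."""
    # Only tokens destined for a DIFFERENT rank cross the network
    cross_rank_fraction = (num_expert_ranks - 1) / num_expert_ranks
    return 4 * cross_rank_fraction * batch * seq_len * hidden * dtype_bytes

# Mixtral-like shapes: S=2048, H=4096, 8 expert ranks, fp16
vol = moe_alltoall_bytes_per_layer(1, 2048, 4096, 8)
print(f"{vol / 1e6:.1f} MB per MoE layer")  # ~58.7 MB
```

Unlike a TP AllReduce, every byte here is different data per destination, so this traffic cannot be compressed by smarter collectives, only by routing tokens to local experts more often.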

11.6 Expert Parallel vs Other Dimensions

Aspect | Tensor Parallel | Expert Parallel
What's distributed | A single layer's parameters | Different expert FFNs
Operation | AllReduce (sum partials) | All-to-All (route tokens)
Message pattern | Everyone sends the same data to everyone | Everyone sends different data to everyone
Compute | All GPUs compute every token | Each GPU computes a subset of tokens
Load balance | Perfect (same compute) | Can be imbalanced (routing-dependent)

11.7 Load Balancing in MoE

A critical challenge in MoE is load balancing. If most tokens route to the same experts, some GPUs are overloaded while others sit idle:

Python (PyTorch) moe_load_balance.py
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, num_experts, top_k=2):
    """Auxiliary loss to encourage balanced expert usage.
    
    From Switch Transformer / GShard papers.
    """
    # router_logits: [batch_size * seq_len, num_experts]
    
    # Get routing probabilities
    routing_probs = F.softmax(router_logits, dim=-1)
    
    # Get top-k expert assignments
    _, top_k_indices = torch.topk(router_logits, k=top_k, dim=-1)
    
    # Create one-hot masks for selected experts
    expert_mask = torch.zeros_like(routing_probs)
    expert_mask.scatter_(-1, top_k_indices, 1.0)
    
    # Fraction of tokens routed to each expert
    tokens_per_expert = expert_mask.mean(dim=0)  # [num_experts]
    
    # Average routing probability per expert
    router_prob_per_expert = routing_probs.mean(dim=0)  # [num_experts]
    
    # Load balancing loss: encourage uniform distribution
    # Minimize product of (fraction routed) × (avg probability)
    aux_loss = num_experts * (tokens_per_expert * router_prob_per_expert).sum()
    
    return aux_loss


class MoELayer(torch.nn.Module):
    def __init__(self, hidden_size, ffn_size, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # Router network
        self.router = torch.nn.Linear(hidden_size, num_experts, bias=False)
        
        # Expert FFN networks
        self.experts = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(hidden_size, ffn_size),
                torch.nn.GELU(),
                torch.nn.Linear(ffn_size, hidden_size)
            )
            for _ in range(num_experts)
        ])
        
    def forward(self, x):
        batch_seq, hidden = x.shape
        
        # Compute routing scores
        router_logits = self.router(x)  # [B*S, E]
        
        # Get top-k experts and weights
        routing_weights, selected_experts = torch.topk(
            F.softmax(router_logits, dim=-1), 
            self.top_k, 
            dim=-1
        )
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        
        # Compute expert outputs (simplified - real impl uses All-to-All)
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            # Find tokens routed to this expert
            expert_mask = (selected_experts == i).any(dim=-1)
            if expert_mask.any():
                expert_input = x[expert_mask]
                expert_output = expert(expert_input)
                
                # Weight by routing probability
                weights = routing_weights[expert_mask]
                weight_mask = (selected_experts[expert_mask] == i)
                expert_weight = (weights * weight_mask).sum(dim=-1, keepdim=True)
                
                output[expert_mask] += expert_weight * expert_output
        
        # Compute auxiliary loss for load balancing
        self.aux_loss = load_balancing_loss(router_logits, self.num_experts, self.top_k)
        
        return output

11.8 Combining MoE with 3D Parallelism

Modern systems like Megatron-LM and DeepSpeed combine MoE with 3D parallelism, creating 4D+ parallelism:

Full Parallelism Stack for Trillion-Parameter Models
  • Data Parallelism (DP): replicate the entire model across groups
  • Pipeline Parallelism (PP): distribute layers across stages
  • Tensor Parallelism (TP): split attention/MLP within layers
  • Expert Parallelism (EP): distribute MoE experts across GPUs
Example: DP=8 × PP=8 × TP=8 × EP=8 = 4096 GPUs for a 1T+ parameter MoE model.
MoE Scaling Properties

Advantages:

  • Parameter efficiency: 8× more parameters with only 2× compute (top-2 routing)
  • Specialization: Experts can learn different aspects of data
  • Scalability: Add experts without increasing per-token compute

Challenges:

  • All-to-All communication: Expensive on multi-node systems
  • Load imbalance: Requires auxiliary losses and capacity factors
  • Training instability: Experts can "collapse" if not careful
  • Inference complexity: Batching across experts is tricky

12. Case Studies: How the Giants Train

Let's examine how state-of-the-art models configure their distributed training systems. These case studies illustrate how the parallelism strategies we've covered combine in practice.

12.1 GPT-3 (175B Parameters)

OpenAI's GPT-3 (2020) was trained on a cluster of V100 GPUs using a combination of model and data parallelism:

GPT-3 Training Configuration

GPT-3 175B specifications:
  • Parameters: 175 billion
  • Layers: 96 transformer blocks
  • Hidden dim: 12,288 | Heads: 96
  • Context: 2,048 tokens
  • Training tokens: 300B

Parallelism configuration:
  • Tensor Parallel: TP = 8 (within node)
  • Pipeline Parallel: PP = ~8 stages
  • Data Parallel: DP = varies
  • Micro-batches: ~32 per iteration
  • Total GPUs: ~1,000+ V100s

Communication analysis:
  • TP (per layer): 4 AllReduces of B×S×H ≈ 4 × 2×2048×12288 × 2 bytes ≈ 400 MB/layer
  • PP (per micro-batch): P2P activations ≈ B×S×H × 2 bytes ≈ 50 MB per stage boundary
  • DP (per iteration): AllReduce of 175B params / PP ≈ 22B params ≈ 44 GB (overlapped with backward)

12.2 LLaMA / LLaMA-2 (7B - 70B)

Meta's LLaMA models demonstrate efficient training with modern optimizations:

Model        Parameters  TP  PP  DP        GPUs         Training Time
LLaMA-7B     6.7B        1   1   Variable  ~1,000 A100  ~21 days
LLaMA-13B    13B         2   1   Variable  ~1,000 A100  ~29 days
LLaMA-33B    32.5B       4   1   Variable  ~1,000 A100  ~53 days
LLaMA-65B    65.2B       8   1   Variable  ~2,000 A100  ~21 days
LLaMA-2 70B  70B         8   1   Variable  ~2,000 A100  ~1.7M GPU-hrs
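The table mixes wall-clock days and GPU-hours; a one-line helper (a sketch using the table's approximate GPU counts) puts the rows on a common scale:

```python
def gpu_hours(num_gpus, days):
    """Convert a (GPU count, wall-clock days) pair into GPU-hours so rows
    quoted in different units can be compared directly."""
    return num_gpus * days * 24

# ~2,000 A100s for ~21 days is on the order of 1M GPU-hours, the same
# scale as the ~1.7M GPU-hours reported for LLaMA-2 70B
llama_65b = gpu_hours(2000, 21)
```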
LLaMA Efficiency Tricks
  • No PP needed for most variants: Models fit with TP + FSDP
  • Grouped-Query Attention (GQA): Reduces KV cache memory in LLaMA-2
  • RMSNorm + SwiGLU: Efficient normalization and activation functions
  • Rotary embeddings: No absolute position embeddings to store

12.3 Mixtral 8x7B (MoE)

Mistral's Mixtral demonstrates efficient MoE training with expert parallelism:

Mixtral 8x7B Architecture
Mixtral 8x7B Configuration
  • 8 experts, each ~7B params (46.7B total)
  • Top-2 routing → 12.9B active params/token
  • 32 layers, 4096 hidden, 32 heads

Parallelism Strategy
  • Expert Parallel: EP = 8 (one expert/GPU)
  • Tensor Parallel: TP = 1-2 per expert group
  • Data Parallel: FSDP across replicas

MoE Communication Pattern (per layer)
  1. All-to-All dispatch: tokens → experts (~B×S×H bytes)
  2. Expert compute: each GPU processes its assigned tokens locally
  3. All-to-All combine: results → original positions
  4. Total: ~4×B×S×H per MoE layer (fwd+bwd)
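The dispatch/combine pattern can be simulated in plain Python to see exactly what travels where. This is a toy sketch with one expert per rank and top-1 routing for clarity (Mixtral uses top-2); the `all_to_all` function mimics the collective without any real communication:

```python
def all_to_all(buckets_per_rank):
    """Simulated All-to-All: buckets_per_rank[src][dst] is what src sends
    to dst; the result is indexed [dst][src]."""
    n = len(buckets_per_rank)
    return [[buckets_per_rank[src][dst] for src in range(n)] for dst in range(n)]

def moe_layer(tokens_per_rank, route, expert_fn):
    """One MoE layer: (1) bucket tokens by destination expert, (2) dispatch
    All-to-All, (3) local expert compute, (4) combine All-to-All back to
    the owning rank, restoring original token order."""
    n = len(tokens_per_rank)
    # 1. bucket tokens, tagging each with its original position
    send = [[[(i, t) for i, t in enumerate(toks) if route(t) == dst]
             for dst in range(n)] for toks in tokens_per_rank]
    # 2. dispatch: buckets travel to the expert's rank
    recv = all_to_all(send)
    # 3. each rank applies its local expert, keeping position tags
    processed = [[[(i, expert_fn(rank, t)) for i, t in bucket]
                  for bucket in recv[rank]] for rank in range(n)]
    # 4. combine: results return to the origin rank and original slots
    back = all_to_all(processed)
    out = []
    for rank, toks in enumerate(tokens_per_rank):
        result = [None] * len(toks)
        for bucket in back[rank]:
            for i, t in bucket:
                result[i] = t
        out.append(result)
    return out

# toy example: 4 ranks; token value t is routed to expert t % 4,
# and expert r tags its output as t * 10 + r
route = lambda t: t % 4
expert = lambda rank, t: t * 10 + rank
tokens = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
out = moe_layer(tokens, route, expert)
```

In a real system the buckets are padded to a fixed capacity per expert so the All-to-All exchanges equal-sized tensors, which is where capacity factors and dropped tokens come from.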

12.4 DeepSeek-V2 (236B MoE)

DeepSeek's V2 model showcases advanced MoE techniques with innovative communication optimization:

Configuration: deepseek_v2_config.yaml
# DeepSeek-V2 236B Configuration
model:
  total_params: 236B
  active_params: 21B  # Per token
  
  architecture:
    layers: 60
    hidden_dim: 5120
    num_experts: 160
    experts_per_token: 6  # Top-6 routing
    shared_experts: 2     # Always-active experts
    
  innovations:
    - Multi-head Latent Attention (MLA)  # Reduces KV cache
    - DeepSeekMoE with shared experts
    - Fine-grained expert segmentation

training:
  parallelism:
    expert_parallel: 8-16
    tensor_parallel: 1-4
    pipeline_parallel: 1
    data_parallel: FSDP
    
  hardware:
    gpus: ~2000 H800
    interconnect: NVLink + IB
    
  efficiency:
    tokens_trained: 8.1T
    training_cost: ~$5.5M  # Remarkably efficient!
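The config's active_params figure can be sanity-checked from the routing setup. This is only a rough proxy (a per-token forward also touches attention and embeddings, not just expert FFNs):

```python
def moe_active_fraction(total_params_b, active_params_b):
    """Fraction of the weights a single token's forward pass touches."""
    return active_params_b / total_params_b

def active_expert_slots(experts_per_token, shared_experts):
    """Expert FFNs evaluated per token in each MoE layer: the top-k routed
    experts plus the always-on shared experts."""
    return experts_per_token + shared_experts

frac = moe_active_fraction(236, 21)   # roughly 9% of weights per token
slots = active_expert_slots(6, 2)     # 8 expert FFNs active per MoE layer
```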

12.5 Practical Decision Framework

Use this decision tree to choose your parallelism strategy:

Parallelism Strategy Decision Tree
  • Model size < 10B: use DDP or FSDP (single node or multi-node)
  • Model size 10B-100B:
      • Fits on a single node with FSDP: FSDP + TP within node (TP=2-8, FSDP across nodes)
      • Does not fit: add pipeline parallelism (TP=8 + PP=2-4 + FSDP)
  • Model size > 100B:
      • Dense: full 3D parallelism (TP=8 + PP=8-16 + DP) on 1000s of GPUs
      • MoE: 3D + expert parallelism (EP + TP + PP + DP), with experts treated as model parallelism

Key Considerations

Bandwidth priority:
  1. TP: needs NVLink (600+ GB/s)
  2. PP: can use InfiniBand (200 GB/s)
  3. DP: lowest priority (overlap with compute)
  4. EP: All-to-All needs care

Memory efficiency:
  • Activation checkpointing
  • Mixed precision (BF16/FP8)
  • Optimizer state sharding
  • Sequence parallelism

Efficiency targets:
  • MFU > 40% (good), > 50% (excellent)
  • Pipeline bubble < 10%
  • Communication overlap > 80%
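The decision tree can be sketched as a function. The thresholds (10B, 100B) and degree ranges follow the tree above; they are rules of thumb, not hard limits:

```python
def choose_parallelism(params_b, fits_on_node_with_fsdp=True, is_moe=False):
    """Return a rough parallelism recipe for a model of `params_b` billion
    parameters, following the decision tree in the text."""
    if params_b < 10:
        # small models: pure data parallelism is simplest and fastest
        return {"dp": "DDP or FSDP"}
    if params_b <= 100:
        if fits_on_node_with_fsdp:
            return {"tp": "2-8 within node", "dp": "FSDP across nodes"}
        # layers too large for one node: add pipeline stages
        return {"tp": "8", "pp": "2-4", "dp": "FSDP"}
    if is_moe:
        return {"ep": "experts across GPUs", "tp": "within node",
                "pp": "across nodes", "dp": "outermost"}
    # dense 100B+: full 3D parallelism
    return {"tp": "8", "pp": "8-16", "dp": "outermost"}
```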

12.6 Summary: Communication Patterns at a Glance

Strategy           Collective                 When                  Volume    Overlap?
DDP                AllReduce (gradients)      End of backward       2N        ✓ Yes
FSDP/ZeRO-3        AllGather + ReduceScatter  Every layer           3N        Partial
Tensor Parallel    AllReduce (activations)    Every layer, fwd+bwd  4×L×BSH   ✗ Critical path
Pipeline Parallel  P2P Send/Recv              Stage boundaries      2×M×BSH   Partial
Expert Parallel    All-to-All                 MoE layers            4×BSH     ✗ Critical path
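The formulas in the Volume column can be wrapped in a small calculator. A sketch: the constant factors follow the table, a 2-byte dtype is assumed, and EP volume is per MoE layer:

```python
def comm_volume_bytes(strategy, *, N=0, L=0, M=0, B=1, S=1, H=1, dtype_bytes=2):
    """Per-iteration communication volume per rank for each strategy.
    N = parameter count, L = layers, M = micro-batches,
    B/S/H = batch, sequence length, hidden dimension."""
    bsh = B * S * H
    formulas = {
        "ddp":  2 * N,        # ring AllReduce moves ~2N elements
        "fsdp": 3 * N,        # AllGather (fwd + bwd) + ReduceScatter
        "tp":   4 * L * bsh,  # 2 AllReduces forward + 2 backward per layer
        "pp":   2 * M * bsh,  # activations fwd + gradients bwd per boundary
        "ep":   4 * bsh,      # All-to-All dispatch + combine, fwd + bwd
    }
    return formulas[strategy] * dtype_bytes
```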

13. Conclusion

Distributed training has evolved from simple data parallelism to sophisticated multi-dimensional strategies that enable training models with hundreds of billions of parameters. The key insights from this guide:

Key Takeaways

1. Understand your collectives:

  • AllReduce, AllGather, ReduceScatter, and All-to-All are the building blocks
  • Ring and tree algorithms trade off latency vs bandwidth
  • The α-β model helps predict communication costs
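The latency/bandwidth trade-off between ring and tree AllReduce falls directly out of the α-β model (α = per-message latency, β = seconds per byte). A sketch with the standard step counts and illustrative link parameters:

```python
import math

def ring_allreduce_time(n_bytes, p, alpha, beta):
    """Ring AllReduce on p ranks: 2(p-1) steps (reduce-scatter then
    all-gather), each moving n/p bytes:
        T = 2(p-1)*alpha + 2*(p-1)/p * n * beta
    Bandwidth-optimal, but latency grows linearly in p."""
    return 2 * (p - 1) * alpha + 2 * (p - 1) / p * n_bytes * beta

def tree_allreduce_time(n_bytes, p, alpha, beta):
    """Reduce-then-broadcast over a binary tree: ~2*log2(p) steps, each
    moving the full buffer:
        T = 2 * log2(p) * (alpha + n * beta)
    Latency-optimal for small messages; wastes bandwidth on large ones."""
    return 2 * math.log2(p) * (alpha + n_bytes * beta)

# 8 ranks, 5 us latency, 100 GB/s links (illustrative numbers)
alpha, beta, p = 5e-6, 1 / 100e9, 8
small = 4 * 1024   # 4 KB gradient bucket: latency-bound, tree wins
large = 10**9      # 1 GB bucket: bandwidth-bound, ring wins
```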

2. Match parallelism to your hardware:

  • TP within nodes (NVLink), PP across nodes (IB), DP anywhere
  • Communication hierarchy determines efficiency
  • Roofline analysis identifies bottlenecks

3. Scale systematically:

  • Start with DDP/FSDP for models that fit
  • Add TP when layers are too large
  • Add PP when model spans nodes
  • Consider MoE for parameter-efficient scaling

4. Optimize relentlessly:

  • Overlap communication with computation
  • Use mixed precision and activation checkpointing
  • Balance load across pipeline stages and experts
  • Profile, measure, iterate

The field continues to evolve rapidly. New techniques like sequence parallelism, context parallelism for long sequences, and novel MoE routing strategies push the boundaries of what's possible. Understanding the fundamentals covered here provides the foundation to adapt to these advances.

🚀 Now go train some massive models!

Remember: The best parallelism strategy is the one that maximizes throughput while staying within your memory budget.