- Introduction: The Scale Problem
- Collective Operations Fundamentals
- AllReduce Implementations
- AllGather, ReduceScatter, and All-to-All
- The α-β Performance Model
- Communication Roofline Analysis
- Data Parallelism (DDP)
- Fully Sharded Data Parallel (FSDP/ZeRO)
- Tensor Parallelism
- Pipeline Parallelism
- 3D Parallelism & Expert Parallelism
- Optimization Strategies & Case Studies
1. Introduction: The Scale Problem
The deep learning community is engaged in an unprecedented scaling race. Model sizes have grown from millions to hundreds of billions of parameters in just a few years, with the largest models now exceeding a trillion parameters. This exponential growth has fundamentally changed how we train neural networks—what once fit comfortably on a single GPU now requires orchestrating thousands of accelerators across multiple nodes.
Understanding distributed training is no longer optional for ML practitioners. Whether you're fine-tuning a 7B parameter model or training a frontier model from scratch, you need to understand the communication primitives, parallelism strategies, and performance trade-offs that make large-scale training possible.
1.1 The Exponential Growth of Model Size
Let's trace the evolution of model scale to understand why distributed training became necessary:
This growth isn't just about bragging rights—larger models consistently demonstrate better performance on benchmarks and real-world tasks. The scaling laws discovered by Kaplan et al. and refined by Hoffmann et al. (Chinchilla) show that model performance improves predictably with scale, following power-law relationships:

$$L(N) \propto N^{-\alpha}$$

Where $L$ is the loss, $N$ is the number of parameters, and $\alpha \approx 0.076$ for language models. This means that to halve the loss, you need to increase model size by approximately $2^{1/0.076} \approx 9000\times$. The economic incentive to scale is clear, but it creates an equally clear engineering challenge.
1.2 The Memory Wall
The fundamental constraint in training large models is memory. Let's break down exactly where the memory goes during training. For a model with $N$ parameters using mixed-precision training with the Adam optimizer:
- Parameters (fp16): $2N$ bytes
- Gradients (fp16): $2N$ bytes
- Adam optimizer states:
  • Momentum (fp32): $4N$ bytes
  • Variance (fp32): $4N$ bytes
  • Master weights (fp32): $4N$ bytes
Total: $2N + 2N + 4N + 4N + 4N = 16N$ bytes
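This accounting is easy to script. Below is a minimal sketch of the same breakdown (the function name and dictionary keys are illustrative, not from any library):

```python
def model_state_bytes(num_params: int) -> dict:
    """Per-component training-state memory for mixed-precision Adam, in bytes."""
    state = {
        "params_fp16": 2 * num_params,
        "grads_fp16": 2 * num_params,
        "adam_momentum_fp32": 4 * num_params,
        "adam_variance_fp32": 4 * num_params,
        "master_weights_fp32": 4 * num_params,
    }
    state["total"] = sum(state.values())  # 16 bytes per parameter
    return state

# LLaMA-7B: 7e9 parameters -> 112 GB of model state
print(model_state_bytes(7_000_000_000)["total"] / 1e9, "GB")
```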
That's 16 bytes per parameter just for the model state—before we even consider activations! Let's see what this means for real models:
| Model | Parameters | Model State (16N) | A100 80GB GPUs Needed |
|---|---|---|---|
| LLaMA-7B | 7B | 112 GB | 2+ |
| LLaMA-13B | 13B | 208 GB | 3+ |
| LLaMA-65B | 65B | 1.04 TB | 13+ |
| GPT-3 | 175B | 2.8 TB | 35+ |
| PaLM | 540B | 8.6 TB | 108+ |
And we haven't even counted activation memory yet! Activations are the intermediate tensors computed during the forward pass and needed for the backward pass. For a transformer with:
- $L$ layers
- Hidden dimension $h$
- Sequence length $s$
- Batch size $b$
- Attention heads $a$
The activation memory (in bytes, with fp16 activations and no recomputation) scales approximately as:

$$\text{Activations} \approx s \cdot b \cdot h \cdot L \left(34 + 5\,\frac{a \cdot s}{h}\right)$$

For a GPT-3 scale model with $s=2048$, $b=1$, $h=12288$, $a=96$, and $L=96$ layers, the per-layer factor is $34 + 5 \cdot 96 \cdot 2048 / 12288 = 114$, giving $2048 \cdot 12288 \cdot 96 \cdot 114 \approx 275$ GB.
A single training sample for GPT-3 requires ~275 GB just for activations. This is why activation checkpointing (recomputation) is essential—it trades compute for memory by recomputing activations during the backward pass instead of storing them all.
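As a sanity check, this approximation is easy to evaluate directly. A small sketch (`activation_bytes` is an illustrative helper, assuming fp16 activations and no checkpointing):

```python
def activation_bytes(s, b, h, L, a):
    """Approximate transformer activation memory in bytes (fp16, no checkpointing)."""
    per_layer = s * b * h * (34 + 5 * a * s / h)
    return L * per_layer

# GPT-3 scale: s=2048, b=1, h=12288, a=96, L=96
gb = activation_bytes(2048, 1, 12288, 96, 96) / 1e9
print(f"{gb:.0f} GB")  # ≈ 275 GB
```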
1.3 Training Bottlenecks: Compute, Memory, and Communication
Distributed training introduces a new dimension of complexity: communication. When we spread training across multiple devices, we must constantly exchange data between them. The three fundamental bottlenecks in distributed training are:
| Bottleneck | Metric | Solution |
|---|---|---|
| Compute bound | TFLOPS utilization | Maximize arithmetic intensity, use tensor cores |
| Memory bound | GB used vs. available | Shard model, checkpoint activations |
| Communication bound | GB/s, messages/sec | Overlap, compress, minimize volume |
The interplay between these bottlenecks determines which distributed training strategy is optimal for a given model, hardware configuration, and batch size. This is why understanding performance modeling—which we'll cover in depth—is crucial for efficient distributed training.
1.4 The Distributed Training Landscape
To address these challenges, the community has developed a rich toolkit of parallelism strategies. Each addresses different bottlenecks and has different trade-offs:
| Strategy | What's Parallelized | Memory Savings | Communication Pattern | When to Use |
|---|---|---|---|---|
| Data Parallel | Samples across batches | None (replicated) | AllReduce on gradients | Model fits in GPU |
| FSDP/ZeRO | Model state | Linear in #GPUs | AllGather + ReduceScatter | Model state too large |
| Tensor Parallel | Layer operations | Linear in TP degree | AllReduce + AllGather | Layers too large |
| Pipeline Parallel | Layers across GPUs | Linear in PP degree | Point-to-point | Many layers, hide latency |
| Expert Parallel | MoE experts | Linear in #experts/GPU | All-to-All | MoE architectures |
In practice, training frontier models requires combining multiple strategies—known as 3D parallelism or even 4D/5D parallelism when including expert parallelism and sequence parallelism. Understanding how these strategies compose and when to use each is the key to efficient large-scale training.
This guide takes a first-principles approach to distributed training. We start with the collective operations that underpin all distributed training—understanding not just what they do, but how they're implemented internally (ring, tree, hierarchical). We then develop a rigorous performance analysis framework using the α-β model that lets us predict and optimize communication costs. Finally, we apply this framework to analyze each parallelism strategy, understanding exactly when each is optimal and how to combine them effectively.
1.5 A Word on Hardware
Throughout this guide, we'll reference specific hardware capabilities. Here's a quick reference for modern training hardware:
NVIDIA H100 SXM
• 989 TFLOPS TF32
• 1979 TFLOPS FP16/BF16
• 900 GB/s NVLink
NVIDIA A100 SXM
• 312 TFLOPS TF32
• 624 TFLOPS FP16/BF16
• 600 GB/s NVLink
InfiniBand NDR
• ~1-2μs latency
• RDMA support
• 8 ports per DGX node
Understanding these hardware specifications is essential for performance modeling. Notice the significant gap between intra-node bandwidth (NVLink: 600-900 GB/s) and inter-node bandwidth (InfiniBand: 50-400 GB/s). This asymmetry fundamentally shapes how we design distributed training systems.
2. Collective Operations Fundamentals
Before diving into parallelism strategies, we must understand the fundamental building blocks of distributed communication: collective operations. These are communication patterns that involve all processes in a group, enabling coordinated data exchange. Every distributed training algorithm is ultimately composed of these primitives.
The Message Passing Interface (MPI) standard defines these operations, and libraries like NCCL (NVIDIA Collective Communications Library), Gloo, and others implement them with hardware-specific optimizations. Understanding these operations—and how they're implemented—is essential for reasoning about distributed training performance.
2.1 Point-to-Point Operations: The Foundation
Before collective operations, let's briefly cover point-to-point communication—the simplest form of message passing where one process sends data directly to another:
```python
import torch
import torch.distributed as dist

# Send operation (on rank 0)
if dist.get_rank() == 0:
    tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda')
    dist.send(tensor, dst=1)

# Receive operation (on rank 1)
if dist.get_rank() == 1:
    tensor = torch.zeros(3, device='cuda')
    dist.recv(tensor, src=0)
    # tensor is now [1.0, 2.0, 3.0]
```
Point-to-point operations have straightforward performance characteristics:

$$T(n) = \alpha + n\beta$$

Where $\alpha$ is the latency (time to initiate communication) and $\beta$ is the inverse bandwidth (time per byte). We'll formalize this model in Section 5.
2.2 Broadcast
Broadcast sends data from one process (the root) to all other processes in the group. After a broadcast, every process has an identical copy of the data.
```python
import torch
import torch.distributed as dist

# Initialize tensor on all ranks
if dist.get_rank() == 0:
    tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda')
else:
    tensor = torch.zeros(3, device='cuda')

# Broadcast from rank 0 to all ranks
dist.broadcast(tensor, src=0)
# Now all ranks have tensor = [1.0, 2.0, 3.0]
```
Broadcast is used to distribute model parameters from rank 0 to all workers at initialization, and to synchronize hyperparameters or random seeds across processes.
2.3 Reduce
Reduce combines data from all processes using a reduction operation (sum, max, min, product, etc.) and stores the result on a single root process. It's the inverse of broadcast.
```python
# Each rank has local gradients
local_gradients = compute_gradients(model, batch)

# Reduce sum to rank 0
dist.reduce(local_gradients, dst=0, op=dist.ReduceOp.SUM)

# Only rank 0 has the sum of all gradients
if dist.get_rank() == 0:
    local_gradients /= dist.get_world_size()  # Average
```
2.4 Scatter
Scatter distributes different chunks of data from the root process to all processes. If the root has data of size $n \times P$, each process receives a chunk of size $n$.
```python
# Root prepares data chunks for each rank
if dist.get_rank() == 0:
    scatter_list = [torch.tensor([i * 10.0], device='cuda')
                    for i in range(dist.get_world_size())]
else:
    scatter_list = None

# Each rank receives its chunk
output = torch.zeros(1, device='cuda')
dist.scatter(output, scatter_list, src=0)
# Rank 0: [0.0], Rank 1: [10.0], Rank 2: [20.0], ...
```
2.5 Gather
Gather is the inverse of scatter—it collects data from all processes and concatenates it on the root process.
```python
# Each rank has local data
local_tensor = torch.tensor([dist.get_rank() * 10.0], device='cuda')

# Gather to rank 0
if dist.get_rank() == 0:
    gather_list = [torch.zeros(1, device='cuda')
                   for _ in range(dist.get_world_size())]
else:
    gather_list = None

dist.gather(local_tensor, gather_list, dst=0)
# Rank 0 (with 4 ranks): gather_list = [[0.0], [10.0], [20.0], [30.0]]
```
2.6 Summary of Basic Collectives
These four operations—broadcast, reduce, scatter, and gather—form the foundation. Notice the duality relationships:
| Operation | Data Flow | Root Has | Others Have | Dual Operation |
|---|---|---|---|---|
| Broadcast | Root → All | Original data | Copy of root's data | Reduce |
| Reduce | All → Root | Combined result | Original (unchanged) | Broadcast |
| Scatter | Root → All (different) | One chunk | Unique chunk each | Gather |
| Gather | All → Root | All chunks concatenated | Original (unchanged) | Scatter |
The operations we'll cover next—AllReduce, AllGather, ReduceScatter, and AllToAll—are "all" variants where the result is distributed to all processes, not just a root. These are the workhorses of distributed training because they eliminate the root bottleneck.
2.7 Visualizing the Collective Operation Family
Root-based operations create a serialization bottleneck—all data must flow through one process. In distributed training with hundreds of GPUs, this becomes catastrophic. The "All" variants (AllReduce, AllGather, ReduceScatter) distribute the work and achieve bandwidth-optimal scaling when implemented correctly. This is why NCCL focuses heavily on optimizing these operations.
3. AllReduce Implementations
AllReduce is the most important collective operation in distributed training. It combines data from all processes using a reduction operation and distributes the result back to everyone. After an AllReduce, every process has an identical copy of the reduced data.
In distributed training, AllReduce is used to synchronize gradients across all workers. Each worker computes gradients on its local batch, then AllReduce sums them (after which you divide by the world size to get the average gradient).
```python
import torch.distributed as dist

# Each rank computes local gradients
loss = model(inputs).sum()
loss.backward()

# AllReduce to sum gradients across all ranks
for param in model.parameters():
    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
    param.grad /= dist.get_world_size()  # Average

# Now all ranks have identical averaged gradients
optimizer.step()
```
The implementation of AllReduce dramatically affects training throughput. Let's examine four increasingly sophisticated implementations.
3.1 Naive AllReduce (Reduce + Broadcast)
The simplest AllReduce implementation is just a Reduce to a root followed by a Broadcast from that root. While conceptually simple, this approach has terrible scaling properties.
Using the α-β model (latency α, inverse bandwidth β):
Reduce phase: Root receives from (P-1) processes: $(P-1)\alpha + (P-1)n\beta$
Broadcast phase: Root sends to (P-1) processes: $(P-1)\alpha + (P-1)n\beta$
Total: $T_{\text{naive}} = 2(P-1)\alpha + 2(P-1)n\beta$
The $2(P-1)n\beta$ bandwidth term is catastrophic—root must handle all data volume!
3.2 Tree-Based AllReduce
Tree-based algorithms reduce the bottleneck at the root by organizing processes into a tree structure. In a binary tree reduction, each node combines data from two children, then passes the result to its parent.
Tree Reduce: $T = \log_2(P) \cdot (\alpha + n\beta)$
Tree Broadcast: $T = \log_2(P) \cdot (\alpha + n\beta)$
Total Tree AllReduce: $T_{\text{tree}} = 2\log_2(P)\alpha + 2\log_2(P)n\beta$
Better latency scaling (logarithmic), but the bandwidth term $2\log_2(P)n\beta$ is still suboptimal—we're sending the full data at each level.
3.3 Ring AllReduce
Ring AllReduce is the breakthrough algorithm that achieves bandwidth-optimal scaling. Instead of aggregating all data at a root, it organizes processes in a ring and pipelines the communication so that each link carries only $n/P$ data at a time.
Ring AllReduce is implemented as two phases:
- ReduceScatter phase: Each process ends up with 1/P of the fully reduced data
- AllGather phase: Distribute the reduced chunks so everyone has the full result
ReduceScatter phase: $(P-1)$ steps, each sending $n/P$ bytes
$T_{\text{reduce-scatter}} = (P-1)\alpha + (P-1) \cdot \frac{n}{P} \cdot \beta = (P-1)\alpha + \frac{(P-1)n}{P}\beta$
AllGather phase: $(P-1)$ steps, each sending $n/P$ bytes
$T_{\text{all-gather}} = (P-1)\alpha + \frac{(P-1)n}{P}\beta$
Total Ring AllReduce:
$$T_{\text{ring}} = 2(P-1)\alpha + \frac{2(P-1)n}{P}\beta \approx 2(P-1)\alpha + 2n\beta$$
The bandwidth term is $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ — constant regardless of P!
The original sketch below uses explicit point-to-point operations (`dist.P2POp` + `dist.batch_isend_irecv`) so that each step's send and receive happen concurrently; it assumes `tensor.numel()` is divisible by `world_size`:

```python
import torch
import torch.distributed as dist

def ring_allreduce(tensor, rank, world_size):
    """Ring AllReduce: ReduceScatter phase followed by AllGather phase."""
    chunks = list(tensor.chunk(world_size))
    right = (rank + 1) % world_size
    left = (rank - 1) % world_size

    # Phase 1: ReduceScatter. After P-1 steps, each rank owns one fully
    # reduced chunk (chunk index (rank + 1) % world_size).
    for step in range(world_size - 1):
        send_idx = (rank - step) % world_size
        recv_idx = (rank - step - 1) % world_size
        recv_buf = torch.empty_like(chunks[recv_idx])
        # Simultaneous send to the right neighbor, receive from the left
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, chunks[send_idx].contiguous(), right),
            dist.P2POp(dist.irecv, recv_buf, left),
        ])
        for req in reqs:
            req.wait()
        chunks[recv_idx] = chunks[recv_idx] + recv_buf  # Accumulate

    # Phase 2: AllGather. Same ring pattern, but just copy (no reduction).
    for step in range(world_size - 1):
        send_idx = (rank - step + 1) % world_size
        recv_idx = (rank - step) % world_size
        recv_buf = torch.empty_like(chunks[recv_idx])
        reqs = dist.batch_isend_irecv([
            dist.P2POp(dist.isend, chunks[send_idx].contiguous(), right),
            dist.P2POp(dist.irecv, recv_buf, left),
        ])
        for req in reqs:
            req.wait()
        chunks[recv_idx] = recv_buf

    return torch.cat(chunks)
```
3.4 Recursive Halving-Doubling (Rabenseifner's Algorithm)
For small messages where latency dominates, the recursive halving-doubling algorithm (also known as Rabenseifner's algorithm) achieves better latency than ring while maintaining bandwidth efficiency. It's optimal when $P$ is a power of 2.
3.5 Comparison: When to Use Each Algorithm
| Algorithm | Latency Term | Bandwidth Term | Best For |
|---|---|---|---|
| Naive (Reduce+Bcast) | $2(P-1)\alpha$ | $2(P-1)n\beta$ | Never use at scale |
| Tree | $2\log_2(P)\alpha$ | $2\log_2(P)n\beta$ | Very small messages |
| Ring | $2(P-1)\alpha$ | $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ | Large messages |
| Recursive H-D | $2\log_2(P)\alpha$ | $\frac{2(P-1)}{P}n\beta \approx 2n\beta$ | Medium messages, power-of-2 P |
NCCL automatically selects the best algorithm based on message size, number of processes, and hardware topology. For large gradient tensors (MBs to GBs) typical in deep learning, ring-based algorithms dominate. For small control messages or when P is large relative to message size, tree or recursive halving-doubling may be chosen.
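The cost formulas in the table can be compared numerically. A small sketch under the α-β model (the link parameters are illustrative, and this is not NCCL's actual selection logic):

```python
import math

def t_naive(P, n, alpha, beta):
    """Reduce to root + Broadcast from root."""
    return 2 * (P - 1) * alpha + 2 * (P - 1) * n * beta

def t_tree(P, n, alpha, beta):
    """Binary-tree reduce + broadcast: full data at each of log2(P) levels."""
    return 2 * math.log2(P) * (alpha + n * beta)

def t_ring(P, n, alpha, beta):
    """Ring AllReduce: 2(P-1) steps of n/P-sized chunks."""
    return 2 * (P - 1) * alpha + 2 * (P - 1) / P * n * beta

# 64 GPUs, 1 GB gradients, IB-like link: alpha = 2 us, 25 GB/s
P, n = 64, 1e9
alpha, beta = 2e-6, 1 / 25e9
print(f"naive: {t_naive(P, n, alpha, beta)*1e3:8.1f} ms")
print(f"tree:  {t_tree(P, n, alpha, beta)*1e3:8.1f} ms")
print(f"ring:  {t_ring(P, n, alpha, beta)*1e3:8.1f} ms")
```

For these large-message parameters, ring wins by a wide margin; shrink `n` toward the crossover point and the logarithmic latency of the tree starts to pay off.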
3.6 Hierarchical AllReduce
Modern clusters have non-uniform communication costs—intra-node (NVLink) is much faster than inter-node (InfiniBand). Hierarchical AllReduce exploits this by performing local reductions within nodes first, then a global AllReduce across nodes.
Consider 8 nodes × 8 GPUs = 64 GPUs total. A flat ring AllReduce takes 2×63 = 126 steps, many of them crossing the slow inter-node links. The hierarchical version instead does an intra-node ReduceScatter (7 fast NVLink steps), an inter-node AllReduce among the 8 node leaders (2×7 = 14 slower IB steps, on 1/8 of the data each), and an intra-node AllGather/broadcast (7 NVLink steps). The slow IB links see only 14 steps instead of ~126, dramatically reducing the impact of the slower interconnect.
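A back-of-envelope comparison under a two-level α-β model makes the benefit concrete (all link parameters are illustrative; the flat case pessimistically assumes every ring step is paced by the slow link):

```python
def flat_ring(P, n, alpha_slow, beta_slow):
    """Flat ring AllReduce, every step paced by the slow inter-node link."""
    return 2 * (P - 1) * (alpha_slow + (n / P) * beta_slow)

def hierarchical(nodes, gpus, n, alpha_fast, beta_fast, alpha_slow, beta_slow):
    """Intra-node ReduceScatter/AllGather + inter-node AllReduce on 1/gpus of data."""
    local = 2 * (gpus - 1) * (alpha_fast + (n / gpus) * beta_fast)
    global_ = 2 * (nodes - 1) * (alpha_slow + (n / gpus / nodes) * beta_slow)
    return local + global_

n = 1e9                     # 1 GB payload
fast = (1e-6, 1 / 450e9)    # NVLink-like: 1 us, 450 GB/s
slow = (2e-6, 1 / 25e9)     # IB-like: 2 us, 25 GB/s
print(f"flat:         {flat_ring(64, n, *slow) * 1e3:6.1f} ms")
print(f"hierarchical: {hierarchical(8, 8, n, *fast, *slow) * 1e3:6.1f} ms")
```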
4. AllGather, ReduceScatter, and All-to-All
Beyond AllReduce, three other "all" collective operations are essential for distributed training. Each has distinct communication patterns and use cases in different parallelism strategies.
4.1 AllGather
AllGather collects data from all processes and distributes the concatenated result to everyone. If each process starts with $n$ bytes, every process ends with $n \times P$ bytes containing all the data.
```python
import torch
import torch.distributed as dist

# Each rank has a local tensor (e.g., model shard)
local_tensor = torch.randn(1000, device='cuda')  # n = 1000 elements

# Prepare output list for gathered tensors
world_size = dist.get_world_size()
gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]

# AllGather: collect all shards
dist.all_gather(gathered, local_tensor)

# Result: list of P tensors, each 1000 elements
full_tensor = torch.cat(gathered)  # Shape: [P * 1000]

# Alternative: all_gather_into_tensor (more efficient)
output = torch.zeros(world_size * 1000, device='cuda')
dist.all_gather_into_tensor(output, local_tensor)  # Direct into contiguous buffer
```
Common uses:
- **FSDP/ZeRO-3:** Gather sharded model weights before the forward pass
- **Tensor Parallelism:** Gather partial activations after column-parallel linear layers
- **Sequence Parallelism:** Reconstruct the full sequence after parallel attention
Ring AllGather Implementation
Like Ring AllReduce, Ring AllGather achieves bandwidth-optimal performance by pipelining communication around a ring. Each step, processes send their chunk to the next neighbor and receive from the previous neighbor.
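The ring schedule is easy to verify with a pure-Python simulation of the bookkeeping (no GPUs or communication involved; the helper below is illustrative): each rank starts holding only its own chunk, and after $P-1$ forwarding steps every rank holds all $P$ chunks.

```python
def simulate_ring_allgather(world_size):
    """Track which chunk indices each rank holds during a ring AllGather."""
    have = [{rank} for rank in range(world_size)]  # rank r starts with chunk r
    for step in range(world_size - 1):
        # At step s, rank r forwards chunk (r - s) mod P to its right neighbor
        for rank in range(world_size):
            chunk = (rank - step) % world_size
            assert chunk in have[rank]  # A rank only forwards chunks it holds
            have[(rank + 1) % world_size].add(chunk)
    return have

result = simulate_ring_allgather(4)
print(all(h == {0, 1, 2, 3} for h in result))  # → True
```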
4.2 ReduceScatter
ReduceScatter is the inverse of AllGather: it reduces data from all processes and distributes different chunks of the result to each process. If each process starts with $n \times P$ bytes, every process ends with a different $n$-byte chunk of the reduced result.
```python
import torch
import torch.distributed as dist

# Each rank has gradients for ALL parameters (before sharding)
full_gradients = torch.randn(4000, device='cuda')  # n*P elements

# Output: each rank gets 1/P of the reduced gradients
world_size = dist.get_world_size()
output_size = full_gradients.size(0) // world_size
reduced_shard = torch.zeros(output_size, device='cuda')

# ReduceScatter: sum gradients across ranks, distribute shards
dist.reduce_scatter_tensor(reduced_shard, full_gradients, op=dist.ReduceOp.SUM)

# Rank 0 gets the sum of elements [0:1000] from all ranks
# Rank 1 gets the sum of elements [1000:2000] from all ranks
# etc.
```
Common uses:
- **FSDP/ZeRO-2+3:** After the backward pass, reduce gradients and scatter them to their shard owners
- **Ring AllReduce:** The first phase—produces the fully reduced chunks
- **Tensor Parallelism:** After row-parallel linear layers, reduce-scatter activations
4.3 All-to-All (AllToAll)
All-to-All is the most general collective: each process sends different data to every other process. It's essentially a distributed transpose operation. Process $i$ sends chunk $j$ to process $j$, and receives chunk $i$ from process $j$.
```python
import torch
import torch.distributed as dist

# Each rank has P chunks to send (one to each rank)
world_size = dist.get_world_size()
rank = dist.get_rank()

# Input: chunks destined for each rank
send_chunks = [torch.randn(100, device='cuda') for _ in range(world_size)]

# Output: chunks received from each rank
recv_chunks = [torch.zeros(100, device='cuda') for _ in range(world_size)]

# All-to-All: exchange chunks
dist.all_to_all(recv_chunks, send_chunks)
# recv_chunks[j] now contains what rank j sent to us (their chunk for us)

# More efficient: all_to_all_single for contiguous tensors
input_tensor = torch.randn(world_size * 100, device='cuda')  # [P * chunk_size]
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor, input_tensor)
```
Common uses:
- **Expert Parallelism (MoE):** Route tokens to experts across GPUs—each token goes to its assigned expert, and All-to-All redistributes them
- **Sequence Parallelism:** Transpose between sequence-sharded and head-sharded layouts in attention
- **FFT/Convolutions:** Distributed FFT requires All-to-All for its global transpose
4.4 Collective Operations Complexity Summary
Here's a comprehensive table summarizing the time complexity of all collective operations using optimal (bandwidth-efficient) algorithms:
| Operation | Latency Term | Bandwidth Term | Input/Output Size | Algorithm |
|---|---|---|---|---|
| Broadcast | $\log_2(P)\,\alpha$ | $n\beta$ | n (root) → n (all) | Binomial tree |
| Reduce | $\log_2(P)\,\alpha$ | $n\beta$ | n (all) → n (root) | Binomial tree |
| Scatter | $\log_2(P)\,\alpha$ | $(P-1)n\beta$ | nP (root) → n (each) | Binomial tree |
| Gather | $\log_2(P)\,\alpha$ | $(P-1)n\beta$ | n (each) → nP (root) | Binomial tree |
| AllReduce | $2\log_2(P)\,\alpha$ | $\frac{2(P-1)n}{P}\beta$ | n (each) → n (all, same) | Recursive HD / Ring |
| AllGather | $\log_2(P)\,\alpha$ | $(P-1)n\beta$ | n (each) → nP (all) | Ring / Recursive doubling |
| ReduceScatter | $\log_2(P)\,\alpha$ | $(P-1)n\beta$ | nP (each) → n (each, different) | Ring / Recursive halving |
| All-to-All | $(P-1)\,\alpha$ | $(P-1)n\beta$ | nP (each) → nP (each, shuffled) | Direct / Pairwise |
The bandwidth term represents how much data each process must move through its links in the optimal case, in terms of the sizes shown in each row. Notice these key relationships:
• AllReduce = ReduceScatter + AllGather over chunks of size $n/P$, which is where its $\frac{2(P-1)n}{P}\beta$ cost comes from
• AllGather's output is $P\times$ larger than its input, so each process must move $(P-1)n$ bytes regardless of the algorithm
• All-to-All has the worst latency term, $(P-1)\alpha$, because in the worst case every process must directly communicate with every other process
4.5 Choosing the Right Collective
Different distributed training strategies use different collective operations:
| Training Strategy | Primary Collectives | Communication Pattern |
|---|---|---|
| Data Parallel (DDP) | AllReduce | Synchronize gradients after backward |
| FSDP/ZeRO-3 | AllGather, ReduceScatter | Gather weights before forward, scatter gradients after backward |
| Tensor Parallel | AllReduce, AllGather | Reduce activations after parallel matmuls |
| Pipeline Parallel | Point-to-Point (Send/Recv) | Pass activations between pipeline stages |
| Expert Parallel (MoE) | All-to-All | Route tokens to/from experts |
| Sequence Parallel | AllGather, ReduceScatter, All-to-All | Redistribute sequence chunks for attention |
When analyzing distributed training efficiency, always consider the total communication volume per iteration. For a model with $N$ parameters:
• **DDP:** One AllReduce over all $N$ gradients—roughly $2N$ parameter-transfers per GPU with the ring algorithm
• **FSDP/ZeRO-3:** AllGather weights in the forward pass, AllGather again in the backward pass, then ReduceScatter gradients—roughly $3N$ parameter-transfers per GPU, i.e. about 1.5× DDP's volume
• **Tensor Parallel:** Proportional to activation size, not model size
Understanding these patterns is key to optimizing distributed training configurations!
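A tiny sketch of this volume accounting (function names are illustrative; volumes use the approximate per-GPU "parameter-transfer" counts above, with fp16 values):

```python
def comm_volume_per_iter(num_params, strategy, bytes_per_param=2):
    """Approximate per-GPU communication volume per iteration, in bytes."""
    if strategy == "ddp":
        # One ring AllReduce over N gradients ≈ 2N parameter-transfers
        return 2 * num_params * bytes_per_param
    if strategy == "fsdp":
        # AllGather (fwd) + AllGather (bwd) + ReduceScatter ≈ 3N transfers
        return 3 * num_params * bytes_per_param
    raise ValueError(f"unknown strategy: {strategy}")

N = 7e9  # LLaMA-7B
print(f"DDP:  {comm_volume_per_iter(N, 'ddp') / 1e9:.0f} GB per iteration")
print(f"FSDP: {comm_volume_per_iter(N, 'fsdp') / 1e9:.0f} GB per iteration")
```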
5. The α-β Performance Model
To analyze and optimize collective communication, we need a principled model. The α-β model (also called the postal model or Hockney model) decomposes communication time into two fundamental components: a fixed startup cost and a data-dependent transfer cost.
5.1 Model Fundamentals
The time to send a message of $n$ bytes between two processes is:

$$T(n) = \alpha + n\beta$$

Where:
- α (alpha) — Latency: The fixed startup cost in seconds. This includes software overhead (driver calls, kernel transitions), network protocol setup, and physical signal propagation delay. Independent of message size.
- β (beta) — Inverse Bandwidth: The time to transmit one byte, in seconds/byte. If your link has bandwidth $B$ bytes/second, then $\beta = 1/B$.
- n — Message Size: The amount of data being transferred in bytes.
If a link has bandwidth $B = 450$ GB/s (like NVLink 4.0):
$\beta = \frac{1}{B} = \frac{1}{450 \times 10^9 \text{ B/s}} = 2.22 \times 10^{-12} \text{ s/byte} = 2.22 \text{ ps/byte}$
For a 1 GB message: $T = \alpha + 1 \times 10^9 \times 2.22 \times 10^{-12} = \alpha + 0.00222\text{s} \approx \alpha + 2.2\text{ms}$
5.2 Extending to Collective Operations
For collective operations involving $P$ processes, the model extends to account for multiple communication steps and the algorithm's structure:

$$T(n, P) = k(P)\,\alpha + m(n, P)\,\beta$$

Where:
- $k(P)$: Number of communication rounds (latency multiplier)
- $m(n, P)$: Total bytes transmitted per process (bandwidth term)
5.3 The Crossover Point
A critical question: when does latency matter vs. bandwidth? The crossover point $n^*$ is where both terms contribute equally:

$$\alpha = n^*\beta \quad\Rightarrow\quad n^* = \frac{\alpha}{\beta} = \alpha B$$

For messages smaller than $n^*$, latency dominates. For larger messages, bandwidth dominates.
```python
def calculate_crossover(alpha_us, bandwidth_gbps):
    """
    Calculate the crossover point where the latency and bandwidth terms are equal.

    Args:
        alpha_us: Latency in microseconds
        bandwidth_gbps: Bandwidth in GB/s
    Returns:
        Crossover point in bytes
    """
    alpha_s = alpha_us * 1e-6           # Convert to seconds
    beta = 1 / (bandwidth_gbps * 1e9)   # Seconds per byte
    n_star = alpha_s / beta             # Crossover in bytes
    return n_star

# Example calculations
interconnects = {
    "NVLink 4.0": (1.0, 450),        # 1 μs latency, 450 GB/s
    "InfiniBand HDR": (1.5, 200),    # 1.5 μs, 200 GB/s (8 NICs per node, aggregate)
    "100GbE RoCE": (5.0, 12.5),      # 5 μs, 12.5 GB/s
    "TCP/IP Ethernet": (50, 12.5),   # 50 μs (with kernel overhead), 12.5 GB/s
}

print("Interconnect Crossover Points:")
print("-" * 45)
for name, (alpha, bw) in interconnects.items():
    n_star = calculate_crossover(alpha, bw)
    if n_star < 1024:
        print(f"{name:20}: n* = {n_star:8.0f} B")
    elif n_star < 1024**2:
        print(f"{name:20}: n* = {n_star/1024:8.1f} KB")
    else:
        print(f"{name:20}: n* = {n_star/(1024**2):8.1f} MB")

# Output:
# Interconnect Crossover Points:
# ---------------------------------------------
# NVLink 4.0          : n* =    439.5 KB
# InfiniBand HDR      : n* =    293.0 KB
# 100GbE RoCE         : n* =     61.0 KB
# TCP/IP Ethernet     : n* =    610.4 KB
```
5.4 Practical Implications for Distributed Training
The α-β model has profound implications for how we design distributed training systems.
5.5 Extended α-β-γ Model
For operations that involve computation (like Reduce operations), we can extend the model to include a computation term:

$$T(n) = \alpha + n\beta + n\gamma$$

Where γ (gamma) is the per-byte cost of the reduction arithmetic (e.g., the floating-point additions in a sum). For modern GPUs, γ is typically negligible because:
- GPUs can perform reductions at near memory bandwidth
- NCCL fuses the reduction with communication
- The communication time (β term) dominates
The computation term γ becomes significant in these cases:
• CPU-based training: CPU arithmetic is much slower than network bandwidth
• Custom reduction operations: Complex user-defined reductions
• Compression/decompression: Gradient compression trades γ for reduced β
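The last point can be made concrete: compressing gradients by a factor $c$ shrinks the β term but adds a γ-style per-byte cost for compression and decompression. A sketch under the ring AllReduce cost model (all numbers illustrative):

```python
def allreduce_time(n, P, alpha, beta, compress=1.0, gamma=0.0):
    """Ring AllReduce time with optional compression by factor `compress`.

    gamma: per-byte cost of compressing + decompressing the payload.
    """
    ring = 2 * (P - 1) * alpha + 2 * (P - 1) / P * (n / compress) * beta
    return ring + n * gamma

n, P = 2e9, 64                # 2 GB of fp16 gradients, 64 GPUs
alpha, beta = 2e-6, 1 / 25e9  # IB-like link
plain = allreduce_time(n, P, alpha, beta)
comp = allreduce_time(n, P, alpha, beta, compress=4, gamma=1 / 200e9)
print(f"plain:         {plain * 1e3:6.1f} ms")
print(f"4x compressed: {comp * 1e3:6.1f} ms")
```

Whether compression wins depends entirely on the γ-vs-β trade: on a fast NVLink fabric, the same compression cost can easily exceed the bandwidth saved.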
5.6 Measuring α and β in Practice
To optimize your distributed training, you need to know the actual α and β values for your hardware. Here's how to measure them:
```python
import time

import numpy as np
import torch
import torch.distributed as dist

def measure_allreduce_alpha_beta(sizes_bytes, num_warmup=5, num_iters=20):
    """
    Measure α and β for AllReduce by fitting T = α + n·β to timing data.

    Run with: torchrun --nproc_per_node=N measure_alpha_beta.py
    """
    times = []
    for size in sizes_bytes:
        numel = size // 4  # FP32 elements
        tensor = torch.randn(numel, device='cuda')

        # Warmup
        for _ in range(num_warmup):
            dist.all_reduce(tensor)
        torch.cuda.synchronize()

        # Timed runs
        start = time.perf_counter()
        for _ in range(num_iters):
            dist.all_reduce(tensor)
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / num_iters
        times.append(elapsed)

    # Linear regression: T = α + n·β
    # (The exact formula is T = 2·log2(P)·α + 2·(P-1)/P · n·β; a simple
    #  linear fit still recovers the effective slope and intercept.)
    sizes = np.array(sizes_bytes)
    times = np.array(times)
    coeffs = np.polyfit(sizes, times, 1)  # slope ≈ β, intercept ≈ α
    beta = coeffs[0]
    alpha = coeffs[1]
    bandwidth_gbps = 1 / beta / 1e9
    return alpha, beta, bandwidth_gbps

if __name__ == "__main__":
    dist.init_process_group()
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Test message sizes from 1KB to 512MB
    sizes = [1024 * (2 ** i) for i in range(20)]

    alpha, beta, bw = measure_allreduce_alpha_beta(sizes)
    if dist.get_rank() == 0:
        print("Measured Parameters:")
        print(f"  α (latency)  = {alpha*1e6:.2f} μs")
        print(f"  β (inv. bw)  = {beta*1e12:.2f} ps/byte")
        print(f"  Bandwidth    = {bw:.1f} GB/s")
        print(f"  Crossover n* = {alpha/beta/1024:.1f} KB")
    dist.destroy_process_group()
```
The simple α-β model doesn't capture everything. Real systems have additional overheads:
• GPU kernel launch: ~10-50 μs per collective call
• Memory allocation: First call may allocate buffers
• Synchronization: Barrier costs at collective boundaries
• Congestion: Shared network links reduce effective bandwidth
Always measure on your actual hardware configuration!
6. The Communication Roofline Model
The classic Roofline Model relates computational performance to memory bandwidth through arithmetic intensity. We can adapt this powerful framework to analyze distributed training by creating a Communication Roofline that relates training throughput to network bandwidth through the computation-to-communication ratio.
6.1 From Compute Roofline to Communication Roofline
Recall the traditional roofline model: performance is bounded by either peak compute or memory bandwidth, depending on arithmetic intensity (FLOPS/byte). For distributed training, we create an analogous model.
6.2 Defining the Computation-to-Communication Ratio
For distributed training, we define the computation-to-communication ratio (CCR) as the ratio of computation performed to bytes communicated:
For standard data-parallel training with gradient AllReduce:

$$\text{CCR}_{\text{DDP}} = \frac{6N \cdot b \cdot s}{2N \cdot \text{bytes/param}} = \frac{3\,b\,s}{\text{bytes/param}}$$

Where $N$ is the number of model parameters, $b \cdot s$ is tokens per batch, $6N$ FLOPs per token approximates a transformer's forward+backward cost, and the factor of 2 accounts for the ring AllReduce sending/receiving approximately twice the gradient size. Note that $N$ cancels: for a fixed sequence length, CCR depends only on batch size and gradient precision.
```python
def calculate_ccr_ddp(model_params, batch_size, seq_len, dtype_bytes=2):
    """
    Calculate the computation-to-communication ratio for DDP training.

    For transformers: FLOPs ≈ 6 * N * tokens (forward + backward)
    Communication: 2 * N * dtype_bytes (AllReduce gradients)
    """
    tokens = batch_size * seq_len
    flops = 6 * model_params * tokens            # Approximate transformer FLOPs
    comm_bytes = 2 * model_params * dtype_bytes  # AllReduce volume
    ccr = flops / comm_bytes
    return ccr

# Example: GPT-3 175B with different batch sizes
model_params = 175e9
seq_len = 2048

print("GPT-3 175B CCR Analysis:")
print("-" * 50)
for batch_size in [1, 4, 16, 64, 256]:
    ccr = calculate_ccr_ddp(model_params, batch_size, seq_len)
    print(f"Batch size {batch_size:4}: CCR = {ccr:8.1f} FLOPS/byte")

# Output:
# GPT-3 175B CCR Analysis:
# --------------------------------------------------
# Batch size    1: CCR =   3072.0 FLOPS/byte
# Batch size    4: CCR =  12288.0 FLOPS/byte
# Batch size   16: CCR =  49152.0 FLOPS/byte
# Batch size   64: CCR = 196608.0 FLOPS/byte
# Batch size  256: CCR = 786432.0 FLOPS/byte
```
6.3 The Communication Roofline Diagram
Now we can construct the full communication roofline. The y-axis shows achieved throughput (e.g., samples/second), and we have two ceilings:
- Communication ceiling (sloped): Throughput limited by network bandwidth
- Compute ceiling (flat): Throughput limited by GPU compute power
The ridge point CCR* tells you the minimum computation-to-communication ratio needed
to fully utilize your GPUs:
$$\text{CCR}^* = \frac{\text{Peak GPU FLOPS}}{\text{Aggregate Network Bandwidth}}$$
Example (8× H100 node):
• Peak FP16: 8 × 1979 TFLOPS = 15.8 PFLOPS = 15.8 × 10¹⁵ FLOPS
• NVLink bandwidth: 900 GB/s per GPU = 7.2 TB/s aggregate
• CCR* = 15.8 × 10¹⁵ / 7.2 × 10¹² ≈ 2,200 FLOPS/byte
You need at least CCR > 2,200 to be compute-bound on this hardware!
6.4 Multi-Level Communication Roofline
Modern GPU clusters have multiple levels of interconnect with vastly different bandwidths. This creates a staircase roofline with multiple ceilings:
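The staircase is easy to compute numerically. A minimal sketch (the bandwidth figures are illustrative assumptions, not vendor specs): each interconnect level defines its own ridge point, and a collective is bounded by the slowest level it must cross.

```python
def ridge_points(peak_flops_per_gpu, levels):
    """Compute a ridge point CCR* for each interconnect level.

    levels: list of (name, per_gpu_bandwidth_bytes_per_s) tuples,
    fastest first. A collective that crosses a level is bounded by
    that level's bandwidth, giving a staircase of ceilings.
    """
    return {name: peak_flops_per_gpu / bw for name, bw in levels}

# Illustrative H100 hierarchy (assumed per-GPU bandwidths)
ccr_stars = ridge_points(
    peak_flops_per_gpu=1979e12,  # H100 FP16
    levels=[
        ("NVLink (intra-node)", 900e9),     # 900 GB/s
        ("InfiniBand (inter-node)", 50e9),  # 400 Gb/s = 50 GB/s
        ("Ethernet (cross-rack)", 12.5e9),  # 100 Gb/s
    ],
)
for name, ccr in ccr_stars.items():
    print(f"{name:26s} CCR* = {ccr:10.0f} FLOPS/byte")
```

Each communication pattern is then checked against the ceiling of the slowest level it crosses: TP traffic against the NVLink ridge, DP gradient traffic against the InfiniBand or Ethernet ridge.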
6.5 Scaling Efficiency from the Roofline
The communication roofline directly tells us the scaling efficiency of distributed training. If you're communication-bound:
When $\text{CCR} < \text{CCR}^*$, efficiency scales linearly with batch size (since CCR ∝ batch size for fixed model). When $\text{CCR} > \text{CCR}^*$, you've hit the compute ceiling—efficiency is 100% and won't improve with larger batches.
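This regime split can be written as a one-line model. A toy sketch (α latency terms ignored; the 40,000 FLOPS/byte ridge is the inter-node figure used later in this section):

```python
def scaling_efficiency(ccr, ccr_star):
    """Fraction of peak GPU throughput achieved on the roofline.

    Below the ridge, throughput is bandwidth-bound and efficiency
    grows linearly with CCR (hence with batch size); above it,
    efficiency saturates at 100%.
    """
    return min(1.0, ccr / ccr_star)

# GPT-3-style example: CCR = 3 * tokens / dtype_bytes, inter-node ridge
for batch in [1, 4, 16, 64]:
    ccr = 3 * batch * 2048 / 2
    print(f"batch {batch:3d}: efficiency = {scaling_efficiency(ccr, 40000):.0%}")
```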
6.6 Using the Roofline for Optimization
The communication roofline provides a systematic framework for optimization decisions:
| Current Position | Diagnosis | Optimization Strategy |
|---|---|---|
| Far below comm ceiling (CCR ≪ CCR*) | Severely communication-bound | • Increase batch size • Gradient compression • Better interconnect • Reduce communication volume |
| On comm ceiling (CCR < CCR*) | Communication-bound, scaling linearly | • Increase batch size (best ROI) • Overlap comm with compute • Use hierarchy-aware algorithms |
| Near ridge (CCR ≈ CCR*) | Transitioning to compute-bound | • Small batch increase helps • Focus on compute optimizations • Best scaling sweet spot |
| On compute ceiling (CCR > CCR*) | Compute-bound (ideal!) | • GPU kernel optimizations • Mixed precision training • Operator fusion • Better batch/GPU utilization |
def analyze_roofline_position(
model_flops_per_sample: float,
model_params: float,
batch_size: int,
gpu_peak_tflops: float,
network_bw_gbps: float,
dtype_bytes: int = 2,
comm_factor: float = 2.0, # AllReduce ≈ 2x model size
):
"""
Determine where a workload sits on the communication roofline.
Returns analysis with optimization recommendations.
"""
# Calculate CCR
total_flops = model_flops_per_sample * batch_size
comm_bytes = comm_factor * model_params * dtype_bytes
ccr = total_flops / comm_bytes
# Calculate ridge point
ccr_star = (gpu_peak_tflops * 1e12) / (network_bw_gbps * 1e9)
# Theoretical throughput limits
compute_limit = (gpu_peak_tflops * 1e12) / model_flops_per_sample # samples/s
comm_limit = (network_bw_gbps * 1e9) / (comm_bytes / batch_size) # samples/s
# Effective throughput (min of both limits)
effective_throughput = min(compute_limit, comm_limit)
efficiency = effective_throughput / compute_limit * 100
# Determine regime
if ccr < ccr_star * 0.2:
regime = "SEVERELY_COMM_BOUND"
recommendation = "Increase batch size by 4x+ or use gradient compression"
elif ccr < ccr_star * 0.8:
regime = "COMM_BOUND"
recommendation = "Increase batch size or overlap communication"
elif ccr < ccr_star * 1.2:
regime = "NEAR_RIDGE"
recommendation = "Near optimal! Small batch increase may help"
else:
regime = "COMPUTE_BOUND"
recommendation = "Excellent! Focus on GPU optimizations"
return {
"ccr": ccr,
"ccr_star": ccr_star,
"regime": regime,
"efficiency": efficiency,
"recommendation": recommendation,
"critical_batch": batch_size * (ccr_star / ccr), # Batch to reach ridge
}
# Example analysis
result = analyze_roofline_position(
model_flops_per_sample=6 * 175e9 * 2048, # GPT-3 175B, seq=2048
model_params=175e9,
batch_size=32,
gpu_peak_tflops=1979, # H100 FP16
network_bw_gbps=900, # NVLink-4
)
print(f"CCR: {result['ccr']:.0f} FLOPS/byte")
print(f"Ridge CCR*: {result['ccr_star']:.0f} FLOPS/byte")
print(f"Regime: {result['regime']}")
print(f"Efficiency: {result['efficiency']:.1f}%")
print(f"Recommendation: {result['recommendation']}")
6.7 Practical Roofline Numbers
Here are typical ridge point values for common GPU configurations:
| Configuration | Peak TFLOPS (FP16) | Network BW | CCR* (Ridge) | Min Batch for 90% Eff |
|---|---|---|---|---|
| Single H100 (NVSwitch) | 1,979 | 900 GB/s | ~2,200 | ~16-32 (model dependent) |
| 8× H100 DGX (intra-node) | 15,832 | 7.2 TB/s | ~2,200 | ~128-256 |
| Multi-node H100 (IB NDR) | 15,832 × N | 400 GB/s/node | ~40,000 | ~2,000+ |
| 8× A100 (intra-node) | 2,496 | 4.8 TB/s | ~520 | ~64-128 |
| Cloud GPUs (Ethernet) | varies | 25-100 Gb/s | ~200,000+ | Very large! |
Notice the dramatic difference in CCR* between on-prem DGX clusters (CCR* ≈ 2,000) and
cloud instances with Ethernet (CCR* ≈ 200,000). This is why:
• Cloud training often requires 10-100× larger batch sizes for efficiency
• Gradient compression and async methods are essential for cloud
• Many teams use cloud for development but on-prem for production training
The communication roofline makes this tradeoff quantitatively clear!
7. Data Parallelism: The Foundation
Data Parallelism (DP) is the simplest and most widely used distributed training strategy. Each GPU holds a complete copy of the model, processes different data samples, and synchronizes gradients after the backward pass. Let's explore its mechanics, optimizations, and PyTorch's sophisticated DDP implementation.
7.1 Basic Data Parallelism
The key insight: since all GPUs apply the same averaged gradient to identical model copies, the models remain synchronized without explicit parameter broadcasting.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_ddp(rank, world_size):
"""Initialize distributed process group."""
dist.init_process_group(
backend="nccl", # NVIDIA Collective Communication Library
init_method="env://",
world_size=world_size,
rank=rank,
)
torch.cuda.set_device(rank)
def train_ddp(rank, world_size, model, dataloader, optimizer):
    setup_ddp(rank, world_size)
    criterion = torch.nn.CrossEntropyLoss()  # define the loss in each process
    # Wrap model with DDP
    model = model.to(rank)
    model = DDP(model, device_ids=[rank])
for batch in dataloader:
inputs, labels = batch
inputs, labels = inputs.to(rank), labels.to(rank)
optimizer.zero_grad()
# Forward pass (independent per GPU)
outputs = model(inputs)
loss = criterion(outputs, labels)
# Backward pass (DDP hooks trigger AllReduce)
loss.backward() # ← Gradients synchronized here!
# Optimizer step (identical on all GPUs)
optimizer.step()
dist.destroy_process_group()
# Launch with: torchrun --nproc_per_node=8 basic_ddp.py
7.2 PyTorch DDP: Under the Hood
PyTorch's DistributedDataParallel is highly optimized. It doesn't wait for the full backward pass to complete—instead, it overlaps gradient communication with computation using gradient bucketing.
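The payoff of bucketing can be illustrated with a toy α-β timeline (a simulation, not PyTorch internals): gradients become ready bucket by bucket during the backward pass, and each bucket's AllReduce overlaps the remaining computation.

```python
def backward_time_with_overlap(num_buckets, compute_time, comm_time, alpha=0.0):
    """Toy model of DDP's bucketed comm/compute overlap.

    Backward compute is split evenly across buckets; bucket i's
    AllReduce can start once its gradients are ready and runs
    concurrently with later compute. alpha models a per-bucket
    launch latency. Returns total backward + sync time.
    """
    per_bucket_compute = compute_time / num_buckets
    per_bucket_comm = comm_time / num_buckets + alpha
    finish = 0.0
    comm_free_at = 0.0  # single comm stream: buckets serialize on it
    for i in range(num_buckets):
        ready = (i + 1) * per_bucket_compute  # grads for bucket i ready
        start = max(ready, comm_free_at)
        comm_free_at = start + per_bucket_comm
        finish = comm_free_at
    return finish

no_overlap = 100.0 + 80.0  # all compute, then one monolithic AllReduce (ms)
overlapped = backward_time_with_overlap(25, 100.0, 80.0)
print(f"no overlap: {no_overlap:.1f} ms, bucketed: {overlapped:.1f} ms")
```

With 25 buckets, only the last bucket's AllReduce remains exposed, which is exactly why DDP fires AllReduce per bucket instead of once at the end. Increasing `alpha` shows the latency penalty of too many small buckets.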
7.3 DDP Configuration Options
from torch.nn.parallel import DistributedDataParallel as DDP
# Advanced DDP configuration
model = DDP(
model,
device_ids=[local_rank],
# Bucket size (bytes) - smaller = more overlap, more latency overhead
bucket_cap_mb=25, # Default: 25 MB per bucket
# Find unused parameters (for models with conditional execution)
find_unused_parameters=False, # Set True if some params don't get gradients
# Gradient as bucket view (memory optimization)
gradient_as_bucket_view=True, # Reduces memory by 1 copy
# Static graph optimization (PyTorch 1.11+)
static_graph=True, # Enable if computation graph doesn't change
)
# Tune bucket size based on your model and network
# - Smaller buckets: Better overlap, more latency overhead
# - Larger buckets: Less overhead, less overlap opportunity
# - Rule of thumb: bucket_size > α/β (crossover point)
7.4 Communication Volume Analysis
Let's analyze the communication requirements for DDP precisely. With ring AllReduce over $P$ GPUs, each GPU sends and receives
$$V_{\text{DDP}} = 2 \cdot \frac{P-1}{P} \cdot N \cdot \text{sizeof(dtype)} \approx 2N \cdot \text{sizeof(dtype)} \text{ bytes per iteration}$$
The factor of 2 comes from the ring AllReduce algorithm: each gradient element is sent once during ReduceScatter and once during AllGather (approximately).
| Model | Parameters (N) | FP32 Comm | FP16/BF16 Comm | Time @ 200 GB/s |
|---|---|---|---|---|
| ResNet-50 | 25M | 200 MB | 100 MB | ~0.5 ms |
| BERT-Large | 340M | 2.7 GB | 1.4 GB | ~7 ms |
| GPT-2 XL | 1.5B | 12 GB | 6 GB | ~30 ms |
| Llama-7B | 7B | 56 GB | 28 GB | ~140 ms |
| Llama-70B | 70B | 560 GB | 280 GB | ~1.4 s |
DDP requires each GPU to hold:
• Model parameters: N × sizeof(dtype)
• Gradients: N × sizeof(dtype)
• Optimizer states: 2N (Adam: m, v) or more
Total for FP32 Adam: ~16N bytes per GPU
For mixed precision: ~18-20N bytes (FP16 params + FP32 master weights + FP32 optimizer)
A 7B model needs ~140 GB per GPU for full DDP training, which exceeds a single 80 GB H100!
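These byte counts are easy to script. A quick sketch of the totals above (using 16N for FP32 Adam, and 18N for a mixed-precision layout that keeps FP32 gradients; the exact constant varies by implementation, hence the 18-20N range):

```python
def ddp_memory_gb(n_params, mixed_precision=True, fp32_grads=True):
    """Per-GPU training state for DDP + Adam, in GB (1 GB = 1e9 bytes)."""
    if not mixed_precision:
        # FP32 params + grads + Adam m, v: 4N each = 16N
        total = 16 * n_params
    else:
        # FP16 params (2N) + grads (2N or 4N) + FP32 master (4N) + m, v (8N)
        total = (2 + (4 if fp32_grads else 2) + 4 + 8) * n_params
    return total / 1e9

for n in [25e6, 7e9, 70e9]:
    print(f"{n/1e9:6.3f}B params: {ddp_memory_gb(n):8.1f} GB (mixed precision)")
```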
7.5 Gradient Compression
For bandwidth-limited scenarios, gradient compression can reduce communication volume at the cost of some compute overhead:
import torch
import torch.distributed as dist
class TopKCompressor:
"""Top-K gradient sparsification with error feedback."""
def __init__(self, ratio=0.01):
self.ratio = ratio
self.error_feedback = {} # Accumulated residuals
def compress(self, name, grad):
# Add error feedback from previous iteration
if name in self.error_feedback:
grad = grad + self.error_feedback[name]
# Select Top-K elements
k = max(1, int(grad.numel() * self.ratio))
values, indices = torch.topk(grad.abs().flatten(), k)
# Store residual for next iteration
mask = torch.zeros_like(grad.flatten())
mask[indices] = 1
self.error_feedback[name] = grad.flatten() * (1 - mask)
self.error_feedback[name] = self.error_feedback[name].view_as(grad)
# Return sparse representation
return indices, grad.flatten()[indices]
def decompress(self, indices, values, shape):
grad = torch.zeros(shape, device=values.device).flatten()
grad[indices] = values
return grad.view(shape)
# Usage in training loop
compressor = TopKCompressor(ratio=0.01)  # Keep top 1%
world_size = dist.get_world_size()
for name, param in model.named_parameters():
    if param.grad is not None:
        idx, vals = compressor.compress(name, param.grad)
        # AllGather compressed gradients (much smaller!)
        all_idx = [torch.zeros_like(idx) for _ in range(world_size)]
        all_vals = [torch.zeros_like(vals) for _ in range(world_size)]
        dist.all_gather(all_idx, idx)
        dist.all_gather(all_vals, vals)
        # Decompress and average (index a view so writes land in param.grad)
        param.grad.zero_()
        grad_flat = param.grad.view(-1)
        for i, v in zip(all_idx, all_vals):
            grad_flat[i] += v / world_size
7.6 DDP Best Practices
1. Use NCCL backend for GPU training (fastest for NVIDIA GPUs)
2. Pin memory in DataLoader:
DataLoader(..., pin_memory=True, num_workers=4)
3. Use DistributedSampler:
sampler = DistributedSampler(dataset, shuffle=True)
4. Tune bucket size based on model and interconnect
5. Enable gradient_as_bucket_view (saves memory)
6. Use static_graph=True if computation graph is fixed
7. Consider mixed precision: Halves communication volume!
7.7 When DDP Isn't Enough
DDP has a fundamental limitation: each GPU must hold the entire model. This becomes impossible for large models.
For models larger than ~5-10B parameters (depending on GPU memory), we need model parallelism strategies that distribute the model itself across GPUs. The next section covers FSDP/ZeRO, which elegantly extends DDP with parameter sharding.
8. FSDP and ZeRO: Sharded Data Parallelism
ZeRO (Zero Redundancy Optimizer) from DeepSpeed and FSDP (Fully Sharded Data Parallel) from PyTorch solve DDP's memory problem by sharding model state across GPUs rather than replicating everything. This trades communication for memory, enabling training of models that don't fit on a single GPU.
8.1 The Redundancy Problem
In standard DDP, every GPU holds identical copies of:
- Model parameters (N × dtype_bytes)
- Gradients (N × dtype_bytes)
- Optimizer states (2-8× N bytes depending on optimizer)
8.2 ZeRO Stages Explained
ZeRO introduces three progressive stages of sharding, each reducing memory further at the cost of more communication:
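The three stages can be compared with the ZeRO paper's accounting: FP16 parameters (2N bytes) and gradients (2N) plus K·N bytes of optimizer state, where K = 12 for mixed-precision Adam (FP32 master weights, m, and v). The table below uses a slightly larger constant that also counts FP32 gradients; this sketch follows the paper's numbers:

```python
def zero_memory_gb(n_params, world_size, stage, k=12):
    """Per-GPU memory (GB) for ZeRO stages 0-3, ZeRO-paper accounting."""
    p, n = world_size, n_params
    if stage == 0:    # plain DDP: everything replicated
        total = (2 + 2 + k) * n
    elif stage == 1:  # shard optimizer states
        total = (2 + 2) * n + k * n / p
    elif stage == 2:  # shard optimizer states + gradients
        total = 2 * n + (2 + k) * n / p
    elif stage == 3:  # shard everything (FSDP FULL_SHARD)
        total = (2 + 2 + k) * n / p
    else:
        raise ValueError(f"unknown stage: {stage}")
    return total / 1e9

# 7B parameters across 64 GPUs
for stage in range(4):
    print(f"ZeRO-{stage}: {zero_memory_gb(7e9, 64, stage):8.2f} GB/GPU")
```

The progression is stark: a 7B model drops from 112 GB of replicated state to under 2 GB per GPU at stage 3.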
8.3 FSDP/ZeRO-3 Execution Flow
ZeRO-3/FSDP requires careful orchestration of parameter gathering before computation and gradient scattering after:
8.4 Communication Cost Analysis
FSDP trades memory for communication. Per iteration, each GPU performs an AllGather of parameters in the forward pass, another AllGather in the backward pass, and a ReduceScatter of gradients:
$$V_{\text{FSDP}} \approx \underbrace{N}_{\text{AllGather (fwd)}} + \underbrace{N}_{\text{AllGather (bwd)}} + \underbrace{N}_{\text{ReduceScatter (grads)}} = 3N \text{ elements}$$
Compared to DDP's 2N, FSDP has 1.5× more communication. However, the memory savings often outweigh this cost, and communication can be overlapped with compute.
| Strategy | Memory per GPU | Communication | Best For |
|---|---|---|---|
| DDP | ~18N bytes | 2N bytes | Small models that fit in GPU memory |
| ZeRO-1 | ~6N + 12N/P bytes | 2N bytes (same) | Medium models, free memory savings |
| ZeRO-2 | ~4N + 14N/P bytes | 2N bytes (same) | Larger models needing gradient sharding |
| ZeRO-3/FSDP | ~18N/P bytes | 3N bytes (+50%) | Very large models; memory-constrained |
import functools
import torch
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# Define sharding strategy
sharding_strategy = ShardingStrategy.FULL_SHARD # ZeRO-3 equivalent
# Options: FULL_SHARD, SHARD_GRAD_OP (ZeRO-2), NO_SHARD (DDP)
# Mixed precision policy
mixed_precision = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
buffer_dtype=torch.float16,
)
# Auto-wrap policy for transformers
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock}, # Your layer class
)
# Wrap model with FSDP
model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
# Performance optimizations
use_orig_params=True, # Better compatibility with optimizers
limit_all_gathers=True, # Prevent OOM from too many concurrent gathers
)
# Training loop (same as regular PyTorch!)
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch).loss
loss.backward()
optimizer.step()
8.5 Prefetching and Overlap
Modern FSDP implementations hide communication latency by prefetching the next layer's parameters while computing the current layer:
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    BackwardPrefetch,
)
# Note: forward prefetching is a plain bool argument (forward_prefetch);
# only backward prefetching has an enum.
# Configure prefetching
model = FSDP(
model,
# Forward prefetching (gather next layer during current layer's compute)
forward_prefetch=True,
# Backward prefetching strategy
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # or BACKWARD_POST
# BACKWARD_PRE: Prefetch before backward (default, best for most cases)
# BACKWARD_POST: Prefetch after backward (use if memory-tight)
)
# For maximum overlap, also enable:
# - async_op=True in NCCL operations
# - CUDA streams for concurrent execution
8.6 Hybrid Sharding (HSDP)
For multi-node training, Hybrid Sharded Data Parallel (HSDP) combines FSDP within nodes and DDP across nodes. This leverages the fast intra-node NVLink for sharding while minimizing slow inter-node communication:
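A back-of-envelope comparison shows why. With global sharding, every parameter AllGather crosses node boundaries; HSDP keeps gathers on NVLink and sends only a sharded gradient AllReduce across nodes. A rough sketch (assumed FP16 model of N bytes, G GPUs per node, W GPUs total):

```python
def cross_node_bytes_per_gpu(model_bytes, gpus_per_node, world_size, hybrid):
    """Rough per-GPU cross-node traffic per iteration, in bytes.

    Global FSDP: ~3N of gather/scatter traffic, of which roughly the
    fraction (W - G) / W comes from remote peers.
    HSDP: gathers stay intra-node; cross-node traffic is a ring
    AllReduce of each GPU's N/G gradient shard over R = W/G replicas.
    """
    n, g, w = model_bytes, gpus_per_node, world_size
    if not hybrid:
        return 3 * n * (w - g) / w
    replicas = w // g
    return 2 * (n / g) * (replicas - 1) / replicas

n = 140e9  # 70B params in FP16
fsdp = cross_node_bytes_per_gpu(n, 8, 64, hybrid=False)
hsdp = cross_node_bytes_per_gpu(n, 8, 64, hybrid=True)
print(f"global FSDP: {fsdp/1e9:6.1f} GB cross-node per GPU")
print(f"HSDP:        {hsdp/1e9:6.1f} GB cross-node per GPU")
```

Under these assumptions HSDP cuts cross-node traffic by roughly an order of magnitude, at the cost of holding a full shard group's worth of parameters per node.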
- ZeRO-1: Model fits in memory with DDP, but optimizer states are too large. Free memory savings, no extra communication.
- ZeRO-2: Model fits, but gradients + optimizer states don't. Small models with large batch sizes. Still no extra communication!
- ZeRO-3/FSDP: Model itself doesn't fit. Essential for 10B+ parameter models. Accept 1.5× communication overhead.
- HSDP: Multi-node training where inter-node bandwidth is limited. Minimizes cross-node traffic while enabling large models.
9. Tensor Parallelism: Sharding Within Layers
While FSDP shards parameters between layers (gather before compute, scatter after), Tensor Parallelism (TP) shards parameters within individual layers. Each GPU computes a portion of every matrix multiplication, enabling parallelism at the finest granularity. This technique was pioneered by Megatron-LM for training massive transformer models.
9.1 The Core Idea: Matrix Partitioning
Consider a linear layer Y = XW where X is [B×K] and W is [K×N]. We can partition this computation across GPUs in two fundamental ways: column parallelism splits W along its output dimension (each GPU computes a slice of Y's columns from the full X, no communication needed), while row parallelism splits W along its input dimension (each GPU computes a partial sum from a slice of X, and the partial sums must be AllReduced).
9.2 Megatron-LM Self-Attention Parallelism
In transformer self-attention, we can parallelize across the attention heads. Each GPU computes a subset of heads independently, then results are combined:
9.3 Megatron-LM MLP Parallelism
The feedforward network (MLP) in transformers follows the same pattern:
9.4 Communication Analysis
Tensor parallelism has very different communication characteristics from data parallelism:
| Aspect | Data Parallel (DDP) | Tensor Parallel (TP) |
|---|---|---|
| Communication Pattern | AllReduce gradients (backward only) | AllReduce activations (forward + backward) |
| Frequency | Once per iteration | 2× per transformer layer (Attn + MLP) |
| Message Size | All gradients (~N parameters) | Activation tensor (B × S × H) |
| Bandwidth Need | Can overlap with backward | Critical path (compute blocked!) |
| Latency Sensitivity | Low (large messages) | High (many small messages) |
Since tensor parallelism has communication in the critical path of every forward and backward pass, it requires extremely high bandwidth between GPUs. This is why:
- TP within a node: Use NVLink (900 GB/s on H100, 600 GB/s on A100)
- TP across nodes: Generally avoided! IB is too slow.
- Typical TP degree: 2, 4, or 8 (matching GPUs per node)
9.5 Communication Volume per Layer
For a transformer layer with tensor parallelism degree T:
$$V_{\text{TP,layer}} = 4 \cdot \frac{2(T-1)}{T} \cdot B \cdot S \cdot H \cdot \text{sizeof(dtype)}$$
Where the factor of 4 comes from: 2 AllReduces (attention + MLP) × 2 for forward and backward. The $\frac{2(T-1)}{T}$ term is the AllReduce cost in the ring algorithm.
def tensor_parallel_comm_per_layer(
batch_size: int,
seq_len: int,
hidden_dim: int,
tp_degree: int,
dtype_bytes: int = 2 # FP16
) -> float:
"""Calculate communication volume per transformer layer with TP."""
# Activation size
activation_size = batch_size * seq_len * hidden_dim * dtype_bytes
# AllReduce cost (ring algorithm)
allreduce_factor = 2 * (tp_degree - 1) / tp_degree
# 4 AllReduces per layer (attn fwd, attn bwd, mlp fwd, mlp bwd)
comm_volume = 4 * allreduce_factor * activation_size
return comm_volume
# Example: GPT-3 175B style layer
comm_bytes = tensor_parallel_comm_per_layer(
batch_size=1,
seq_len=2048,
hidden_dim=12288,
tp_degree=8
)
print(f"Comm per layer: {comm_bytes / 1e6:.1f} MB")
# Output: ~352 MB per layer (4 AllReduces × 1.75 × 50.3 MB activation)
# With 96 layers: ~33.8 GB total TP communication per iteration!
9.6 Sequence Parallelism
A refinement to tensor parallelism is Sequence Parallelism (SP), which distributes the LayerNorm and Dropout operations across the sequence dimension. This reduces activation memory further:
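The key accounting: SP replaces each TP AllReduce with a ReduceScatter + AllGather pair, which in a ring implementation moves the same number of bytes, so communication volume is unchanged. The win is that the activations living between them (LayerNorm, Dropout) are stored sharded 1/T per GPU instead of replicated. A rough sketch of that saving (the 4-tensors-per-layer count is an illustrative assumption):

```python
def norm_dropout_activation_bytes(batch, seq, hidden, tp, sequence_parallel,
                                  dtype_bytes=2, tensors_per_layer=4):
    """Activation bytes per GPU for the LayerNorm/Dropout regions of one
    transformer layer. Without SP these tensors are replicated on every
    TP rank; with SP they are sharded along the sequence dimension."""
    full = tensors_per_layer * batch * seq * hidden * dtype_bytes
    return full // tp if sequence_parallel else full

base = norm_dropout_activation_bytes(1, 2048, 12288, tp=8, sequence_parallel=False)
sp = norm_dropout_activation_bytes(1, 2048, 12288, tp=8, sequence_parallel=True)
print(f"without SP: {base/1e6:.0f} MB, with SP: {sp/1e6:.1f} MB per GPU")
```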
9.7 Implementation: Megatron-Style TP
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
class ColumnParallelLinear(nn.Module):
"""Linear layer with column-parallel weight distribution.
Input X is replicated, weight W is split along columns (output dim).
Output Y is split along columns (each GPU has different output features).
"""
def __init__(self, in_features, out_features, tp_group, gather_output=False):
super().__init__()
self.tp_group = tp_group
self.tp_size = dist.get_world_size(tp_group)
self.gather_output = gather_output
# Each GPU gets out_features/tp_size columns
assert out_features % self.tp_size == 0
self.out_features_per_partition = out_features // self.tp_size
self.weight = nn.Parameter(
torch.empty(self.out_features_per_partition, in_features)
)
self.bias = nn.Parameter(torch.empty(self.out_features_per_partition))
def forward(self, x):
# Local matmul: no communication!
output = F.linear(x, self.weight, self.bias)
if self.gather_output:
# Optionally gather to get full output
output = _gather_along_last_dim(output, self.tp_group)
return output
class RowParallelLinear(nn.Module):
"""Linear layer with row-parallel weight distribution.
Input X is split along last dim, weight W is split along rows (input dim).
Outputs are partial sums that need AllReduce.
"""
def __init__(self, in_features, out_features, tp_group, input_is_parallel=True):
super().__init__()
self.tp_group = tp_group
self.tp_size = dist.get_world_size(tp_group)
self.input_is_parallel = input_is_parallel
# Each GPU gets in_features/tp_size rows
assert in_features % self.tp_size == 0
self.in_features_per_partition = in_features // self.tp_size
self.weight = nn.Parameter(
torch.empty(out_features, self.in_features_per_partition)
)
self.bias = nn.Parameter(torch.empty(out_features))
def forward(self, x):
if not self.input_is_parallel:
# Split input if it's replicated
x = _split_along_last_dim(x, self.tp_group)
# Local matmul (partial result)
output_partial = F.linear(x, self.weight)
        # AllReduce to sum partial results
        dist.all_reduce(output_partial, group=self.tp_group)
        # Add bias after the reduction on EVERY rank: the bias is replicated,
        # and after all_reduce each rank holds the full sum. Adding it on a
        # single rank would desynchronize the TP ranks.
        return output_partial + self.bias
class TensorParallelMLP(nn.Module):
"""Megatron-style tensor parallel MLP.
Column-parallel first layer, row-parallel second layer.
Only ONE AllReduce for the entire MLP!
"""
def __init__(self, hidden_size, ffn_hidden_size, tp_group):
super().__init__()
# Column-parallel: split FFN hidden dim
self.fc1 = ColumnParallelLinear(
hidden_size, ffn_hidden_size, tp_group, gather_output=False
)
# Row-parallel: recombine to hidden_size
self.fc2 = RowParallelLinear(
ffn_hidden_size, hidden_size, tp_group, input_is_parallel=True
)
def forward(self, x):
# x: [B, S, H] - replicated across TP ranks
x = self.fc1(x) # [B, S, FFN/TP] - split output
x = F.gelu(x) # Local activation
x = self.fc2(x) # [B, S, H] - AllReduce inside
return x
9.8 When to Use Tensor Parallelism
Use Tensor Parallelism when:
- Single layers are too large for GPU memory (very wide models)
- You have fast intra-node interconnect (NVLink, NVSwitch)
- Model has many attention heads (divisible by TP degree)
- Latency per step matters more than throughput
Avoid Tensor Parallelism when:
- Training across nodes without high-speed interconnect
- Small models where FSDP/DDP suffice
- TP degree doesn't divide attention heads evenly
Best Practices:
- TP=8 for 8-GPU nodes (H100, A100 DGX)
- Combine with Pipeline Parallelism across nodes
- Use Sequence Parallelism to reduce activation memory
10. Pipeline Parallelism: Partitioning Across Layers
Pipeline Parallelism (PP) partitions the model vertically by assigning different layers to different GPUs. Unlike tensor parallelism (which splits within layers), pipeline parallelism keeps each layer intact but distributes them across devices. This approach is ideal for spanning multiple nodes with slower inter-node connections.
10.1 The Naive Approach and the Bubble Problem
The simplest pipeline would process one micro-batch at a time through all stages. But this creates massive pipeline bubbles where most GPUs sit idle:
With P pipeline stages and 1 micro-batch, each GPU is active for only 1/P of the iteration (its one forward and one backward step, out of P of each). The rest is bubble time:
$$\text{bubble fraction} = \frac{P-1}{P}$$
10.2 GPipe: Micro-Batching to Reduce Bubbles
GPipe (Google, 2019) introduced micro-batching: split each mini-batch into M smaller micro-batches and pipeline them through the stages. This fills the pipeline and dramatically reduces bubble fraction:
GPipe reduces bubbles but has a critical problem: activations from all M micro-batches must be stored until the backward pass begins. Memory scales as O(M × activations_per_microbatch).
10.3 1F1B: Interleaving Forward and Backward
The 1F1B (One Forward One Backward) schedule from PipeDream solves the memory problem. After the warmup phase, each stage alternates between one forward and one backward, maintaining a constant number of in-flight micro-batches:
10.4 Interleaved 1F1B (Virtual Stages)
Interleaved 1F1B further reduces bubbles by assigning multiple non-consecutive "virtual stages" to each GPU. With V virtual stages per GPU:
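Because each GPU's segments are 1/V the size, the warmup/cooldown bubble shrinks by roughly a factor of V. A sketch extending the bubble formula from 10.2 (exact schedules differ slightly; V=1 recovers the standard 1F1B expression):

```python
def bubble_fraction(num_stages, num_micro_batches, virtual_stages=1):
    """Approximate pipeline bubble fraction.

    virtual_stages=1 gives the standard 1F1B formula (P-1)/(M+P-1);
    virtual_stages>1 models Megatron-style interleaving, which cuts
    the warmup/cooldown work by ~1/V.
    """
    p, m, v = num_stages, num_micro_batches, virtual_stages
    return (p - 1) / (v * m + p - 1)

# 8 stages, 32 micro-batches
for v in [1, 2, 4]:
    frac = bubble_fraction(8, 32, v)
    print(f"V={v}: bubble = {frac:.1%}, efficiency = {1 - frac:.1%}")
```

The trade-off: interleaving multiplies the number of P2P sends/receives per GPU by V, so it only pays off when the interconnect can absorb the extra messages.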
10.5 Communication Patterns in Pipeline Parallelism
Unlike data parallelism (AllReduce) or tensor parallelism (AllReduce), pipeline parallelism uses point-to-point (P2P) communication:
| Communication | Direction | Data | Size |
|---|---|---|---|
| Forward Activations | Stage i → Stage i+1 | Hidden states | B × S × H bytes |
| Backward Gradients | Stage i+1 → Stage i | Activation gradients | B × S × H bytes |
10.6 Pipeline Communication Volume
Total PP communication per micro-batch per stage boundary:
(Factor of 2: one forward send, one backward send)
def pipeline_comm_per_iteration(
batch_size: int,
seq_len: int,
hidden_dim: int,
num_micro_batches: int,
num_stages: int,
dtype_bytes: int = 2
) -> dict:
"""Calculate pipeline parallel communication."""
micro_batch_size = batch_size // num_micro_batches
# Activation size per micro-batch
act_size = micro_batch_size * seq_len * hidden_dim * dtype_bytes
# P2P comm per stage boundary
# Forward: send activations
# Backward: send gradients
p2p_per_boundary = 2 * act_size # fwd + bwd
# Total boundaries = num_stages - 1
# Total micro-batches = num_micro_batches
total_p2p = p2p_per_boundary * (num_stages - 1) * num_micro_batches
# Bubble analysis
bubble_fraction = (num_stages - 1) / (num_micro_batches + num_stages - 1)
return {
"p2p_per_microbatch_boundary": p2p_per_boundary,
"total_p2p_bytes": total_p2p,
"bubble_fraction": bubble_fraction,
"efficiency": 1 - bubble_fraction
}
# Example: 8 stages, 32 micro-batches
result = pipeline_comm_per_iteration(
batch_size=512,
seq_len=2048,
hidden_dim=8192,
num_micro_batches=32,
num_stages=8
)
print(f"Efficiency: {result['efficiency']*100:.1f}%") # ~82%
print(f"Total P2P: {result['total_p2p_bytes']/1e9:.1f} GB")
10.7 Implementation with PyTorch
import torch
import torch.distributed as dist
class PipelineStage:
"""Simple 1F1B pipeline stage implementation."""
def __init__(self, module, stage_id, num_stages, process_group):
self.module = module
self.stage_id = stage_id
self.num_stages = num_stages
self.group = process_group
# Determine neighbors
self.prev_rank = stage_id - 1 if stage_id > 0 else None
self.next_rank = stage_id + 1 if stage_id < num_stages - 1 else None
def recv_forward(self, shape, dtype):
"""Receive activations from previous stage."""
if self.prev_rank is None:
return None
tensor = torch.empty(shape, dtype=dtype, device='cuda')
dist.recv(tensor, src=self.prev_rank, group=self.group)
return tensor.requires_grad_()
def send_forward(self, tensor):
"""Send activations to next stage."""
if self.next_rank is None:
return
dist.send(tensor, dst=self.next_rank, group=self.group)
def recv_backward(self, shape, dtype):
"""Receive gradients from next stage."""
if self.next_rank is None:
return None
tensor = torch.empty(shape, dtype=dtype, device='cuda')
dist.recv(tensor, src=self.next_rank, group=self.group)
return tensor
def send_backward(self, tensor):
"""Send gradients to previous stage."""
if self.prev_rank is None:
return
dist.send(tensor, dst=self.prev_rank, group=self.group)
def forward_step(self, input_tensor):
"""Execute one forward micro-batch."""
output = self.module(input_tensor)
self.send_forward(output.detach())
return output
    def backward_step(self, input_tensor, output_tensor, output_grad):
        """Execute one backward micro-batch."""
        torch.autograd.backward(output_tensor, grad_tensors=output_grad)
        # Send the gradient w.r.t. this stage's INPUT to the previous stage.
        # (output_tensor is a non-leaf, so its .grad is never populated;
        # input_tensor was marked requires_grad_() in recv_forward.)
        if input_tensor is not None and input_tensor.grad is not None:
            self.send_backward(input_tensor.grad)
def schedule_1f1b(stage, num_micro_batches, input_tensors):
    """1F1B schedule for one pipeline stage."""
    num_warmup = stage.num_stages - stage.stage_id - 1
    num_microbatches_remaining = num_micro_batches - num_warmup
    # Store (input, output) pairs for backward
    outputs = []
    # Warmup: only forward passes
    for i in range(num_warmup):
        inp = stage.recv_forward(...) if stage.prev_rank else input_tensors[i]
        out = stage.forward_step(inp)
        outputs.append((inp, out))
    # Steady state: 1F1B
    for i in range(num_microbatches_remaining):
        # Forward
        inp = stage.recv_forward(...) if stage.prev_rank else input_tensors[num_warmup + i]
        out = stage.forward_step(inp)
        outputs.append((inp, out))
        # Backward (for the oldest in-flight micro-batch)
        old_inp, old_out = outputs.pop(0)
        grad = stage.recv_backward(...) if stage.next_rank else None
        stage.backward_step(old_inp, old_out, grad)
    # Cooldown: only backward passes
    for _ in range(num_warmup):
        old_inp, old_out = outputs.pop(0)
        grad = stage.recv_backward(...) if stage.next_rank else None
        stage.backward_step(old_inp, old_out, grad)
10.8 When to Use Pipeline Parallelism
Ideal for:
- Multi-node training where inter-node bandwidth is limited
- Very deep models (many layers to distribute)
- When combined with TP within nodes (3D parallelism)
- Large batch training where many micro-batches hide bubbles
Challenges:
- Pipeline bubbles reduce efficiency (need M >> P)
- Complex scheduling logic
- Load balancing across stages (unequal layer costs)
- Memory imbalance (first/last stages often need more)
Best Practices:
- Use micro-batches M ≥ 4×P for reasonable efficiency
- Consider interleaved schedules (V > 1) for lower bubbles
- Place PP across nodes, TP within nodes
- Balance layers to equalize compute per stage
11. 3D Parallelism and Mixture of Experts
Training the largest models (100B+ parameters) requires combining multiple parallelism strategies. 3D Parallelism combines Data, Tensor, and Pipeline parallelism to scale across thousands of GPUs. Mixture of Experts (MoE) adds another dimension by dynamically routing tokens to specialized sub-networks.
11.1 The 3D Parallelism Framework
Each parallelism strategy has different strengths and communication patterns. 3D parallelism assigns each dimension to the appropriate network topology:
11.2 Process Group Configuration
In 3D parallelism, GPUs are organized into process groups for each parallelism dimension. A GPU participates in three groups simultaneously:
import torch.distributed as dist
def initialize_3d_parallelism(
tp_size: int,
pp_size: int,
dp_size: int
):
"""Initialize process groups for 3D parallelism.
GPU layout (example: TP=2, PP=2, DP=2 = 8 GPUs):
Node 0: GPU 0,1 (TP group 0, PP stage 0)
Node 1: GPU 2,3 (TP group 1, PP stage 1)
Node 2: GPU 4,5 (TP group 2, PP stage 0) # DP replica
Node 3: GPU 6,7 (TP group 3, PP stage 1) # DP replica
"""
world_size = tp_size * pp_size * dp_size
rank = dist.get_rank()
# Calculate coordinates in 3D grid
tp_rank = rank % tp_size
pp_rank = (rank // tp_size) % pp_size
dp_rank = rank // (tp_size * pp_size)
# Create Tensor Parallel groups
# GPUs in same TP group: same pp_rank, same dp_rank
for dp in range(dp_size):
for pp in range(pp_size):
ranks = [
dp * (tp_size * pp_size) + pp * tp_size + tp
for tp in range(tp_size)
]
group = dist.new_group(ranks)
if rank in ranks:
tp_group = group
# Create Pipeline Parallel groups
# GPUs in same PP group: same tp_rank, same dp_rank
for dp in range(dp_size):
for tp in range(tp_size):
ranks = [
dp * (tp_size * pp_size) + pp * tp_size + tp
for pp in range(pp_size)
]
group = dist.new_group(ranks)
if rank in ranks:
pp_group = group
# Create Data Parallel groups
# GPUs in same DP group: same tp_rank, same pp_rank
for pp in range(pp_size):
for tp in range(tp_size):
ranks = [
dp * (tp_size * pp_size) + pp * tp_size + tp
for dp in range(dp_size)
]
group = dist.new_group(ranks)
if rank in ranks:
dp_group = group
return {
'tp_group': tp_group,
'pp_group': pp_group,
'dp_group': dp_group,
'tp_rank': tp_rank,
'pp_rank': pp_rank,
'dp_rank': dp_rank,
}
# Example: 128 GPUs with TP=8, PP=4, DP=4
groups = initialize_3d_parallelism(tp_size=8, pp_size=4, dp_size=4)
print(f"GPU {dist.get_rank()}: TP={groups['tp_rank']}, PP={groups['pp_rank']}, DP={groups['dp_rank']}")
11.3 Communication Analysis for 3D Parallelism
Let's analyze the total communication volume for one training iteration:
| Dimension | Operation | Volume per Iteration | When |
|---|---|---|---|
| TP | AllReduce activations | 4 × L × B × S × H × 2(T-1)/T | Every layer fwd+bwd |
| PP | P2P send/recv | 2 × M × (P-1) × B/M × S × H | Per micro-batch boundary |
| DP | AllReduce gradients | 2(D-1)/D × N/(P·T) | Once at end (overlapped) |
In well-optimized 3D parallelism:
- TP communication is on the critical path (cannot hide)
- PP communication is partially hidden during micro-batch processing
- DP communication can be fully overlapped with backward pass
This is why TP needs NVLink, PP can use IB, and DP can tolerate the slowest links.
11.4 Mixture of Experts (MoE)
Mixture of Experts is an orthogonal scaling technique that increases model capacity without proportionally increasing compute. Each token is routed to only a subset of "expert" sub-networks:
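The routing itself is a small piece of logic. A dependency-free sketch of top-k routing (toy scores here; a real router computes logits with a learned linear layer over each token's hidden state):

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def route_tokens(router_logits, top_k=2):
    """Assign each token to its top_k experts.

    router_logits: list of per-token score lists, one score per expert.
    Returns (assignments, load): assignments[t] is a list of
    (expert, gate_weight) pairs; load[e] counts tokens sent to expert e.
    """
    num_experts = len(router_logits[0])
    load = [0] * num_experts
    assignments = []
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(num_experts), key=lambda e: -probs[e])[:top_k]
        denom = sum(probs[e] for e in top)  # renormalize over chosen experts
        assignments.append([(e, probs[e] / denom) for e in top])
        for e in top:
            load[e] += 1
    return assignments, load

logits = [[2.0, 0.1, 0.1, 1.5], [0.1, 3.0, 0.2, 0.1], [1.0, 1.1, 0.9, 0.8]]
assignments, load = route_tokens(logits, top_k=2)
print("tokens per expert:", load)
```

The `load` vector is exactly what the load-balancing loss later in this section tries to flatten: skewed routing means some expert GPUs are overloaded while others idle.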
11.5 Expert Parallelism Communication
MoE introduces a unique communication pattern: an All-to-All to route tokens to their assigned experts, then a second All-to-All to return the results to their original ranks.
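In a real multi-GPU run these two exchanges are `torch.distributed.all_to_all_single` calls; the underlying logic, though, is just a permutation (group tokens by expert) and its inverse (restore original order). A single-process sketch of that logic, with illustrative names and an identity "expert" that merely tags tokens:

```python
def dispatch_and_combine(tokens, expert_ids, num_experts):
    """Simulate the dispatch/combine steps of expert-parallel All-to-All.

    Dispatch (first All-to-All): group tokens contiguously by expert.
    Combine (second All-to-All): invert the permutation so each result
    returns to its original position.
    """
    # Dispatch: stable sort of token indices by assigned expert
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    dispatched = [tokens[i] for i in order]
    # Each "expert" processes its contiguous slice (here: tag with expert id)
    processed = [(tok, expert_ids[order[pos]]) for pos, tok in enumerate(dispatched)]
    # Combine: invert the permutation to restore original token order
    combined = [None] * len(tokens)
    for pos, i in enumerate(order):
        combined[i] = processed[pos]
    return combined
```

The expensive part in practice is not the permutation but the fact that, unlike AllReduce, every rank sends *different* data to every other rank — which is why All-to-All is so costly across nodes.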
11.6 Expert Parallel vs Other Dimensions
| Aspect | Tensor Parallel | Expert Parallel |
|---|---|---|
| What's distributed | Single layer's parameters | Different expert FFNs |
| Operation | AllReduce (sum partials) | All-to-All (route tokens) |
| Message pattern | Everyone sends to everyone (same data) | Everyone sends different data to everyone |
| Compute | All GPUs compute every token | Each GPU computes subset of tokens |
| Load balance | Perfect (same compute) | Can be imbalanced (routing-dependent) |
11.7 Load Balancing in MoE
A critical challenge in MoE is load balancing. If most tokens route to the same experts, some GPUs are overloaded while others sit idle:
import torch
import torch.nn.functional as F
def load_balancing_loss(router_logits, num_experts, top_k=2):
"""Auxiliary loss to encourage balanced expert usage.
From Switch Transformer / GShard papers.
"""
# router_logits: [batch_size * seq_len, num_experts]
# Get routing probabilities
routing_probs = F.softmax(router_logits, dim=-1)
# Get top-k expert assignments
_, top_k_indices = torch.topk(router_logits, k=top_k, dim=-1)
# Create one-hot masks for selected experts
expert_mask = torch.zeros_like(routing_probs)
expert_mask.scatter_(-1, top_k_indices, 1.0)
# Fraction of tokens routed to each expert
tokens_per_expert = expert_mask.mean(dim=0) # [num_experts]
# Average routing probability per expert
router_prob_per_expert = routing_probs.mean(dim=0) # [num_experts]
# Load balancing loss: encourage uniform distribution
# Minimize product of (fraction routed) × (avg probability)
aux_loss = num_experts * (tokens_per_expert * router_prob_per_expert).sum()
return aux_loss
class MoELayer(torch.nn.Module):
def __init__(self, hidden_size, ffn_size, num_experts, top_k=2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# Router network
self.router = torch.nn.Linear(hidden_size, num_experts, bias=False)
# Expert FFN networks
self.experts = torch.nn.ModuleList([
torch.nn.Sequential(
torch.nn.Linear(hidden_size, ffn_size),
torch.nn.GELU(),
torch.nn.Linear(ffn_size, hidden_size)
)
for _ in range(num_experts)
])
def forward(self, x):
batch_seq, hidden = x.shape
# Compute routing scores
router_logits = self.router(x) # [B*S, E]
# Get top-k experts and weights
routing_weights, selected_experts = torch.topk(
F.softmax(router_logits, dim=-1),
self.top_k,
dim=-1
)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
# Compute expert outputs (simplified - real impl uses All-to-All)
output = torch.zeros_like(x)
for i, expert in enumerate(self.experts):
# Find tokens routed to this expert
expert_mask = (selected_experts == i).any(dim=-1)
if expert_mask.any():
expert_input = x[expert_mask]
expert_output = expert(expert_input)
# Weight by routing probability
weights = routing_weights[expert_mask]
weight_mask = (selected_experts[expert_mask] == i)
expert_weight = (weights * weight_mask).sum(dim=-1, keepdim=True)
output[expert_mask] += expert_weight * expert_output
# Compute auxiliary loss for load balancing
self.aux_loss = load_balancing_loss(router_logits, self.num_experts, self.top_k)
return output
11.8 Combining MoE with 3D Parallelism
Modern systems like Megatron-LM and DeepSpeed combine MoE with 3D parallelism, adding expert parallelism as a fourth dimension.
Advantages:
- Parameter efficiency: 8× more parameters with only 2× compute (top-2 routing)
- Specialization: Experts can learn different aspects of data
- Scalability: Add experts without increasing per-token compute
Challenges:
- All-to-All communication: Expensive on multi-node systems
- Load imbalance: Requires auxiliary losses and capacity factors
- Training instability: Experts can "collapse" if not careful
- Inference complexity: Batching across experts is tricky
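The "capacity factor" mentioned above bounds how many tokens each expert will accept; overflow tokens are dropped or passed through. A minimal sketch of the usual Switch/GShard-style formula (the function name is illustrative, not from any particular library):

```python
import math

def expert_capacity(tokens_per_batch, num_experts, capacity_factor=1.25, top_k=2):
    """Per-expert token capacity.

    With perfectly uniform routing each expert receives
    tokens_per_batch * top_k / num_experts tokens; capacity_factor > 1
    leaves headroom for routing imbalance before tokens get dropped.
    """
    return math.ceil(tokens_per_batch * top_k / num_experts * capacity_factor)
```

A larger capacity factor drops fewer tokens but wastes more memory and compute on padding, so it is a direct lever on the load-imbalance trade-off.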
12. Case Studies: How the Giants Train
Let's examine how state-of-the-art models configure their distributed training systems. These case studies illustrate how the parallelism strategies we've covered combine in practice.
12.1 GPT-3 (175B Parameters)
OpenAI's GPT-3 (2020) was trained on a cluster of NVIDIA V100 GPUs provided by Microsoft, using a combination of model parallelism (both within and across layers) and data parallelism.
12.2 LLaMA / LLaMA-2 (7B - 70B)
Meta's LLaMA models demonstrate efficient training with modern optimizations:
| Model | Parameters | TP | PP | DP | GPUs | Training Time |
|---|---|---|---|---|---|---|
| LLaMA-7B | 6.7B | 1 | 1 | Variable | ~1,000 A100 | ~21 days |
| LLaMA-13B | 13B | 2 | 1 | Variable | ~1,000 A100 | ~29 days |
| LLaMA-33B | 32.5B | 4 | 1 | Variable | ~1,000 A100 | ~53 days |
| LLaMA-65B | 65.2B | 8 | 1 | Variable | ~2,000 A100 | ~21 days |
| LLaMA-2 70B | 70B | 8 | 1 | Variable | ~2,000 A100 | ~1.7M GPU-hrs |
- No PP needed for most variants: Models fit with TP + FSDP
- Grouped-Query Attention (GQA): Reduces KV cache memory in LLaMA-2
- RMSNorm + SwiGLU: Efficient normalizations and activations
- Rotary embeddings: No absolute position embeddings to store
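The GQA bullet is easy to quantify. Using LLaMA-2 70B's published shapes (80 layers, head dimension 128, 64 query heads but only 8 KV heads), a quick arithmetic sketch of the KV-cache saving at bf16:

```python
def kv_cache_bytes(batch, seq_len, layers, kv_heads, head_dim, bytes_per_elem=2):
    """KV cache size: one K and one V tensor per layer."""
    return 2 * layers * batch * seq_len * kv_heads * head_dim * bytes_per_elem

# LLaMA-2 70B at batch 1, 4096-token context:
mha = kv_cache_bytes(1, 4096, 80, 64, 128)  # hypothetical full multi-head attention
gqa = kv_cache_bytes(1, 4096, 80, 8, 128)   # actual GQA with 8 KV heads
```

GQA cuts the cache by the ratio of query heads to KV heads — here 8×, from ~10 GiB to ~1.25 GiB per sequence — which matters for both activation memory during training and serving throughput.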
12.3 Mixtral 8x7B (MoE)
Mistral's Mixtral (8 experts per layer with top-2 routing; ~47B total parameters but only ~13B active per token) demonstrates efficient MoE training with expert parallelism.
12.4 DeepSeek-V2 (236B MoE)
DeepSeek's V2 model showcases advanced MoE techniques with innovative communication optimization:
# DeepSeek-V2 236B Configuration
model:
total_params: 236B
active_params: 21B # Per token
architecture:
layers: 60
hidden_dim: 5120
num_experts: 160
experts_per_token: 6 # Top-6 routing
shared_experts: 2 # Always-active experts
innovations:
- Multi-head Latent Attention (MLA) # Reduces KV cache
- DeepSeekMoE with shared experts
- Fine-grained expert segmentation
training:
parallelism:
expert_parallel: 8-16
tensor_parallel: 1-4
pipeline_parallel: 1
data_parallel: FSDP
hardware:
gpus: ~2000 H800
interconnect: NVLink + IB
efficiency:
tokens_trained: 8.1T
training_cost: ~$5.5M # Remarkably efficient!
12.5 Practical Decision Framework
Choosing a parallelism strategy comes down to a few questions. Does the model fit on a single GPU? Use DDP. Does it fit on a single node once sharded? Use FSDP. Are individual layers too large? Add TP within the node. Does the model still span nodes? Add PP. Need more parameters without more per-token compute? Consider MoE.
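This decision logic can be sketched as a small helper. The memory thresholds below are illustrative assumptions, not hard rules — real decisions also depend on batch size, sequence length, and interconnect:

```python
def choose_parallelism(model_mem_gb, gpu_mem_gb, node_gpus=8):
    """Rough strategy selection, mirroring the guide's advice.

    model_mem_gb: total training memory for the model (params + optimizer
    + gradients); gpu_mem_gb: memory per GPU; node_gpus: GPUs per node.
    """
    if model_mem_gb <= gpu_mem_gb:
        return ['DDP']                      # fits on one GPU: pure data parallel
    strategies = ['FSDP']                   # shard params/optimizer state
    if model_mem_gb > gpu_mem_gb * node_gpus:
        # even a fully sharded node is not enough: span nodes
        strategies += ['TP (intra-node)', 'PP (inter-node)']
    elif model_mem_gb > gpu_mem_gb * node_gpus / 2:
        # fits on a node, but layers are large: split them over NVLink
        strategies.append('TP (intra-node)')
    return strategies
```

The ordering matters: cheaper, better-overlapped strategies (DDP/FSDP) come first, and the communication-hungry ones (TP, PP) are added only when memory forces the issue.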
12.6 Summary: Communication Patterns at a Glance
| Strategy | Collective | When | Volume | Overlap? |
|---|---|---|---|---|
| DDP | AllReduce (gradients) | End of backward | 2N | ✓ Yes |
| FSDP/ZeRO-3 | AllGather + ReduceScatter | Every layer | 3N | Partial |
| Tensor Parallel | AllReduce (activations) | Every layer fwd+bwd | 4×L×BSH | ✗ Critical path |
| Pipeline Parallel | P2P Send/Recv | Stage boundaries | 2×M×BSH | Partial |
| Expert Parallel | All-to-All | MoE layers | 4×BSH | ✗ Critical path |
13. Conclusion
Distributed training has evolved from simple data parallelism to sophisticated multi-dimensional strategies that enable training models with hundreds of billions of parameters. The key insights from this guide:
1. Understand your collectives:
- AllReduce, AllGather, ReduceScatter, and All-to-All are the building blocks
- Ring and tree algorithms trade off latency vs bandwidth
- The α-β model helps predict communication costs
2. Match parallelism to your hardware:
- TP within nodes (NVLink), PP across nodes (IB), DP anywhere
- Communication hierarchy determines efficiency
- Roofline analysis identifies bottlenecks
3. Scale systematically:
- Start with DDP/FSDP for models that fit
- Add TP when layers are too large
- Add PP when model spans nodes
- Consider MoE for parameter-efficient scaling
4. Optimize relentlessly:
- Overlap communication with computation
- Use mixed precision and activation checkpointing
- Balance load across pipeline stages and experts
- Profile, measure, iterate
The field continues to evolve rapidly. New techniques like sequence parallelism, context parallelism for long sequences, and novel MoE routing strategies push the boundaries of what's possible. Understanding the fundamentals covered here provides the foundation to adapt to these advances.
Further Reading
- Megatron-LM: Training Multi-Billion Parameter Models Using Model Parallelism
- ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- Efficient Large-Scale Language Model Training on GPU Clusters
- Reducing Activation Recomputation in Large Transformer Models
- PyTorch FSDP Documentation
- DeepSpeed Library and Documentation
🚀 Now go train some massive models!
Remember: The best parallelism strategy is the one that maximizes throughput while staying within your memory budget.