## Motivation
In the previous article, we did some back-of-the-envelope arithmetic: a 70B-parameter model requires $16 \times 70 \times 10^9$ bytes $\approx 1120$ GB just for parameters, gradients, and Adam optimizer state — at least 14 A100 80GB GPUs, and that is before we even count activations. Training a model of that size on a single GPU is physically impossible. We must distribute computation across multiple devices.
But “distribute across GPUs” is far easier said than done. Different parallelism strategies make fundamentally different choices about what to split (parameters, gradients, activations, sequences), how to split it (by layer, by dimension, by data), and what communication cost to pay (all_reduce, all_gather, send/recv, all_to_all). Choose the wrong strategy and you waste compute at best, or fail to run at all at worst.
This article starts from the simplest strategy — DDP — and progressively covers FSDP, Tensor Parallelism, Pipeline Parallelism, Sequence Parallelism, Expert Parallelism, and the cutting-edge approaches of Context Parallelism and hybrid parallelism. By the end, you should be able to look at any model size, hardware topology, and training requirement, and assemble a sensible combination of parallelism strategies.
## Prerequisites
- GPU Memory Model and Distributed Communication Fundamentals (Article 1) — especially NCCL primitives and hardware topology
- Basic PyTorch DDP experience (you have written a training script launched with `torchrun`)
- Familiarity with the Transformer architecture: Self-Attention, FFN, LayerNorm
Here is the big picture. We will unpack each piece in the sections that follow:
```mermaid
flowchart TD
subgraph DP["Data Parallel"]
DDP["DDP"]
FSDP["FSDP / ZeRO"]
end
subgraph MP["Model Parallel"]
TP["TP (Tensor)"]
SP["SP (Sequence)"]
end
subgraph PP["Pipeline Parallel"]
PP1["1F1B"]
PP2["Zero Bubble"]
end
subgraph EP["Expert Parallel"]
EP1["EP (MoE)"]
EP2["DeepSeek"]
end
DP & MP & PP & EP --> HYBRID["Hybrid Parallelism (3D/5D)
TP intra-node + FSDP cross-node + PP cross-node-group"]
```
## Classic Parallelism Strategies

### DDP (Distributed Data Parallel)
DDP is the simplest and most widely used parallelism strategy. The core idea fits in one phrase: replicate the model, split the data.
How it works:
- Every GPU holds a complete copy of the model — parameters, gradients, and optimizer state all fully replicated.
- Training data is partitioned across GPUs via `DistributedSampler` — each GPU sees only its own mini-batch.
- Each GPU independently performs forward and backward passes, computing its own local gradients.
- An `all_reduce` aggregates (sums or averages) the gradients so that every GPU ends up with identical aggregated gradients.
- Each GPU executes the same optimizer update, keeping parameters in sync.
```mermaid
flowchart TD
DATA["Data Batch"] -->|DistributedSampler| B0["B₀"] & B1["B₁"] & B2["B₂"] & B3["B₃"]
B0 --> G0["GPU 0
Full model
→ grad₀"]
B1 --> G1["GPU 1
Full model
→ grad₁"]
B2 --> G2["GPU 2
Full model
→ grad₂"]
B3 --> G3["GPU 3
Full model
→ grad₃"]
G0 & G1 & G2 & G3 --> AR["all_reduce(gradients)
Only communication op"]
AR --> OPT["optimizer.step()
Identical update on all GPUs"]
```
**Memory.** DDP does not save any memory. Each GPU independently stores the full parameters ($2\Phi$ bytes in FP16), gradients ($2\Phi$ bytes), and optimizer state ($12\Phi$ bytes for Adam), totaling $16\Phi$ bytes per GPU — exactly the same as single-GPU training.

**Communication.** The only communication is an `all_reduce` on gradients. PyTorch DDP uses an important optimization called bucketed all-reduce: instead of waiting for all gradients to be computed, it groups gradients into buckets (25 MB by default). As soon as all gradients in a bucket are ready, their `all_reduce` begins — overlapping with the backward computation of earlier layers. This significantly hides communication latency.

**When to use DDP.** Whenever the model fits on a single GPU. DDP is simple, efficient, and scales near-linearly. Its limitation is equally clear: the model must fit entirely in a single GPU's memory.
Here is the core code:
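The loop itself is ordinary PyTorch — DDP adds only the wrapper and the process group. Below is a minimal sketch (illustrative sizes, not the repo's exact script; it defaults to a 1-process `gloo` group so it also runs on a CPU-only machine):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun normally sets RANK / WORLD_SIZE / MASTER_ADDR; fall back to a
# 1-process gloo group so this sketch runs without a launcher or a GPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
rank = int(os.environ.get("RANK", "0"))
world = int(os.environ.get("WORLD_SIZE", "1"))
dist.init_process_group("gloo", rank=rank, world_size=world)

model = torch.nn.Linear(32, 32)
ddp_model = DDP(model)  # registers the bucketed-all_reduce gradient hooks
opt = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

losses = []
for _ in range(3):
    x = torch.randn(8, 32)          # each rank would see its own shard
    loss = ddp_model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()                  # gradients all_reduced bucket by bucket
    opt.step()                       # identical update on every rank
    losses.append(loss.item())

dist.destroy_process_group()
```

Under `torchrun --nproc_per_node=N` the same code runs with `N` ranks and real gradient synchronization.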
Full code: `code/02-parallel-strategies/ddp_example.py`
### FSDP / FSDP2 (Fully Sharded Data Parallel)
The fatal limitation of DDP is that every GPU stores a complete copy of parameters, gradients, and optimizer state. When the model is too large for a single GPU, DDP is out of options.
FSDP (Fully Sharded Data Parallel) draws its core idea from the DeepSpeed ZeRO (Zero Redundancy Optimizer) series of papers: since every GPU holds redundant copies of parameters and optimizer state, why not shard them across GPUs and reconstruct them temporarily when needed?
#### ZeRO Stages 1, 2, and 3
ZeRO organizes memory optimization into three stages, each sharding progressively more:
| ZeRO Stage | What is Sharded | Per-GPU Memory | FSDP Equivalent |
|---|---|---|---|
| Stage 1 | Optimizer state | $4\Phi + \frac{12\Phi}{N}$ | — |
| Stage 2 | Optimizer state + gradients | $2\Phi + \frac{14\Phi}{N}$ | SHARD_GRAD_OP |
| Stage 3 | Optimizer state + gradients + parameters | $\frac{16\Phi}{N}$ | FULL_SHARD |
Here $\Phi$ is the parameter count and $N$ is the number of GPUs. The key observation is that Stage 3 achieves near-linear memory reduction — 8 GPUs cut per-GPU memory to roughly 1/8th (plus activation overhead).
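Plugging numbers into the table makes the trade-off concrete. A quick check of the per-GPU formulas, using the $2\Phi/2\Phi/12\Phi$ split (FP16 params and grads, FP32 Adam state) from the memory model above:

```python
def zero_per_gpu_gb(phi, n, stage):
    """Per-GPU memory in GB: FP16 params (2*phi) and grads (2*phi) plus
    FP32 Adam state (12*phi) bytes, sharded according to the ZeRO stage."""
    params, grads, opt = 2 * phi, 2 * phi, 12 * phi
    if stage == 0:            # DDP baseline: everything replicated
        total = params + grads + opt
    elif stage == 1:          # shard optimizer state only
        total = params + grads + opt / n
    elif stage == 2:          # shard optimizer state + gradients
        total = params + (grads + opt) / n
    else:                     # stage 3 / FULL_SHARD: shard everything
        total = (params + grads + opt) / n
    return total / 1e9

phi = 70e9  # 70B parameters
for s in range(4):
    print(f"Stage {s}: {zero_per_gpu_gb(phi, 8, s):.1f} GB/GPU")
```

For 8 GPUs this reproduces the table: Stage 1 gives $4\Phi + 12\Phi/8 = 385$ GB, Stage 3 gives $16\Phi/8 = 140$ GB per GPU.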
#### FSDP Forward and Backward Passes
Here is how FSDP (FULL_SHARD) works in practice:
```mermaid
flowchart LR
subgraph FWD["FSDP Forward Pass — Layer i"]
direction LR
FS1["Shard
1/N"] -->|"all_gather"| FP1["Full Params
(temporary)"]
FP1 -->|"compute"| FO1["Output"]
FP1 -.->|"discard (N-1)/N"| X1[" "]
end
subgraph BWD["FSDP Backward Pass — Layer i"]
direction LR
BS1["Shard
1/N"] -->|"all_gather"| BP1["Full W
(temporary)"]
BP1 -->|"backward"| BG1["Grad
(full)"]
BG1 -->|"reduce_scatter"| BGS["Grad Shard
1/N"]
end
FWD ~~~ BWD
style X1 fill:none,stroke:none
```
Communication cost comparison:
| Strategy | Forward Communication | Backward Communication | Total Volume |
|---|---|---|---|
| DDP | None | 1x all_reduce (gradients) | $2\Phi$ |
| FSDP (FULL_SHARD) | all_gather (parameters) | all_gather (parameters) + reduce_scatter (gradients) | $3\Phi$ |
FSDP communicates roughly 1.5x as much as DDP — that is the cost of trading communication for memory. But in the large-model regime, this trade-off is absolutely worth it: without FSDP, you simply cannot run the model at all.
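The gather → compute → discard pattern can be sketched with raw collectives (a toy 1-process `gloo` example, not real FSDP internals; real FSDP overlaps these collectives with computation, and `reduce_scatter` is emulated here because the `gloo` backend does not provide it):

```python
import os
import torch
import torch.distributed as dist

# 1-process gloo group so the sketch runs on CPU; under torchrun the same
# code would run with world > 1 and real cross-rank communication.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
world, rank = dist.get_world_size(), dist.get_rank()

full = torch.randn(world * 4, 8)      # one layer's full weight
shard = full.chunk(world)[rank]       # each rank persistently keeps 1/N

# Forward: all_gather the shards to rebuild the full weight temporarily.
gathered = [torch.empty_like(shard) for _ in range(world)]
dist.all_gather(gathered, shard)
rebuilt = torch.cat(gathered)         # compute with this, then free it

# Backward: the full gradient is reduce_scattered back into 1/N shards.
# gloo has no reduce_scatter, so emulate it: all_reduce + local slice.
grad = torch.ones_like(full)
dist.all_reduce(grad)
grad_shard = grad.chunk(world)[rank]

dist.destroy_process_group()
```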
#### FSDP2: The Composable API in PyTorch 2.2+

PyTorch 2.2 and later introduced FSDP2 (`torch.distributed._composable.fsdp`), which brings several improvements over the original FSDP:
- Per-parameter sharding: instead of requiring entire `nn.Module` subtrees as FSDP units, FSDP2 can shard individual parameters.
- Composability: FSDP2 composes freely with Tensor Parallel, Pipeline Parallel, and other strategies, without the awkward nested wrapping that FSDP1 required.
- Flexible sharding granularity: different layers can use different sharding strategies.
The core concepts are identical to FSDP1 — the all_gather/reduce_scatter communication pattern is unchanged. What is new is a more modern API that integrates better with the PyTorch 2.x compiler stack.
A snippet comparing the sharding strategies side by side is in the companion code:

Full code: `code/02-parallel-strategies/fsdp_example.py`
### TP (Tensor Parallel) — Megatron Column/Row Partitioning
FSDP shards parameters, then reassembles them — communication and computation are sequential. Tensor Parallel takes a more radical approach: it slices each layer’s weight matrix along a dimension and distributes the pieces to different GPUs, each of which computes part of the result. This directly reduces per-GPU compute and memory.
The core ideas come from Megatron-LM, which defines two fundamental parallel linear layer types.
#### ColumnParallelLinear: Splitting Along the Output Dimension
For a linear layer $Y = XW + b$, we partition the weight $W \in \mathbb{R}^{d \times h}$ by columns:
```
Full weight W: (d_model × dim_ffn)

Split by columns across 2 GPUs:
GPU 0: W_0 = W[:, :dim_ffn//2]  →  Y_0 = X @ W_0
GPU 1: W_1 = W[:, dim_ffn//2:]  →  Y_1 = X @ W_1

Each GPU computes independently — no communication needed!
(because X is identical on all GPUs)
```
Key advantage: the forward pass requires zero communication. Each GPU produces a chunk of the output, which can be fed directly into elementwise operations (such as GeLU) without any synchronization.
#### RowParallelLinear: Splitting Along the Input Dimension
We partition the weight $W \in \mathbb{R}^{h \times d}$ by rows:
```
Split by rows across 2 GPUs:
GPU 0: W_0 = W[:dim_ffn//2, :]  →  Y_0 = X_0 @ W_0  (partial sum)
GPU 1: W_1 = W[dim_ffn//2:, :]  →  Y_1 = X_1 @ W_1  (partial sum)

Final result: Y = Y_0 + Y_1   ← requires all_reduce!
```
Key constraint: each GPU only computes a partial sum of the output. An all_reduce is required to produce the complete result.
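Both partitioning schemes are easy to verify numerically, with plain tensors standing in for the per-GPU shards:

```python
import torch

torch.manual_seed(0)
d, h = 8, 16
X = torch.randn(4, d)
W = torch.randn(d, h)

# Column parallel: split W along the output dim; each "GPU" computes its
# own output columns independently — concatenation, no reduction needed.
W0, W1 = W[:, : h // 2], W[:, h // 2 :]
Y_col = torch.cat([X @ W0, X @ W1], dim=-1)

# Row parallel: split W along the input dim (and X to match); each "GPU"
# produces a partial sum — the "+" below is what all_reduce performs.
X0, X1 = X[:, : d // 2], X[:, d // 2 :]
V0, V1 = W[: d // 2, :], W[d // 2 :, :]
Y_row = X0 @ V0 + X1 @ V1

assert torch.allclose(Y_col, X @ W, atol=1e-5)
assert torch.allclose(Y_row, X @ W, atol=1e-5)
```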
#### Megatron FFN: Column + Row = Only 1 All-Reduce
The elegance of Megatron-LM lies in pairing Column and Row parallel layers:
```mermaid
flowchart TD
X["Input X
(identical on each GPU)"]
X --> COL["ColumnParallelLinear
(W1 column-split, no comm)"]
COL --> GELU["GeLU
(elementwise, no comm)"]
GELU --> ROW["RowParallelLinear
(W2 row-split)"]
ROW -->|"all_reduce
merge partial sums"| Y["Output Y
(identical on each GPU)"]
style COL fill:#d4edda
style GELU fill:#d4edda
style ROW fill:#fff3cd
```

Only 1 `all_reduce` for the entire FFN block.
#### Tensor Parallel for Attention
Multi-Head Attention maps to TP equally elegantly:
- Q, K, V projections: use ColumnParallel, distributing attention heads across GPUs — each GPU handles $\frac{n_\text{heads}}{TP}$ heads.
- Output projection: use RowParallel, merging head outputs across GPUs.
This means the entire Attention block also needs only 1 all_reduce. A full Transformer layer therefore requires 2 all_reduces total (1 for FFN + 1 for Attention).
#### When TP Works Well (and When It Does Not)
TP communicates frequently — 2 all_reduces per layer — so it demands very high bandwidth. In practice:
- TP degree is typically 2, 4, or 8, deployed within a single node (NVLink at 600 GB/s).
- Running TP across nodes is almost never viable (InfiniBand at 25-50 GB/s is too slow for per-layer all_reduce).
- The hidden dimension and head count must be divisible by the TP degree.
The full implementation of `ColumnParallelLinear` and `RowParallelLinear` is in the companion file `code/02-parallel-strategies/tensor_parallel.py`.
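A condensed sketch of the two Megatron-style layers (illustrative, not Megatron-LM's actual code; it runs on a 1-process `gloo` group, where the `all_reduce` is a no-op):

```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Each rank owns out_features / world columns of W; no forward comm."""
    def __init__(self, in_f, out_f):
        super().__init__()
        world = dist.get_world_size()
        assert out_f % world == 0
        self.weight = nn.Parameter(torch.randn(out_f // world, in_f) * 0.02)

    def forward(self, x):
        # x is replicated on every rank; output is this rank's column block.
        return F.linear(x, self.weight)

class RowParallelLinear(nn.Module):
    """Each rank owns in_features / world rows; output needs an all_reduce."""
    def __init__(self, in_f, out_f):
        super().__init__()
        world = dist.get_world_size()
        assert in_f % world == 0
        self.weight = nn.Parameter(torch.randn(out_f, in_f // world) * 0.02)

    def forward(self, x_shard):
        partial = F.linear(x_shard, self.weight)  # partial sum
        dist.all_reduce(partial)                  # merge across ranks
        return partial

# 1-process gloo group so the sketch runs on CPU without torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29502")
dist.init_process_group("gloo", rank=0, world_size=1)

# The Megatron FFN pairing: Column → GeLU → Row, one all_reduce total.
col = ColumnParallelLinear(16, 32)
row = RowParallelLinear(32, 16)
y = row(F.gelu(col(torch.randn(4, 16))))
dist.destroy_process_group()
```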
### PP (Pipeline Parallel) — 1F1B and Zero Bubble
Where TP splits individual layers by dimension, Pipeline Parallel takes the opposite approach: it splits the model by layers, assigning different groups of layers to different GPUs (called stages).
```mermaid
flowchart LR
S0["Stage 0 (GPU 0)\nLayer 0-10"] -->|"send/recv\nactivations"| S1["Stage 1 (GPU 1)\nLayer 11-21"]
S1 -->|"send/recv\nactivations"| S2["Stage 2 (GPU 2)\nLayer 22-31"]
```

Communication: `send/recv` (point-to-point, adjacent stages only).
**Communication.** PP only needs `send/recv` between adjacent stages — transferring intermediate activations (forward) and gradients (backward). The communication volume is far smaller than TP's `all_reduce`, making PP well-suited for cross-node deployment.
But PP has a critical problem: pipeline bubbles.
#### Naive PP: Massive Bubbles
In the simplest scheme, the entire batch passes through stages sequentially. While one stage computes, all others sit idle:
```
Time →
GPU 0: [==Fwd==]                            [==Bwd==]
GPU 1:           [==Fwd==]        [==Bwd==]
GPU 2:                    [==Fwd==][==Bwd==]
                       ↑
           massive idle time (bubble)
```
With $P$ stages, the bubble fraction is approximately $\frac{P-1}{P}$ — 4 stages means 75% of the time is wasted.
#### GPipe: Micro-Batching
GPipe’s solution is to split each mini-batch into $M$ micro-batches, letting them flow through stages like an assembly line:
```
Time → (M = 4 micro-batches)
GPU 0: [F0][F1][F2][F3]                            [B3][B2][B1][B0]
GPU 1:     [F0][F1][F2][F3]                    [B3][B2][B1][B0]
GPU 2:         [F0][F1][F2][F3]            [B3][B2][B1][B0]
GPU 3:             [F0][F1][F2][F3][B3][B2][B1][B0]

Bubble fraction: (P-1) / (M + P - 1)
When M >> P, bubble → 0
```
The bubble fraction drops from $\frac{P-1}{P}$ to $\frac{P-1}{M+P-1}$. However, GPipe completes all forward micro-batches before starting any backward passes, which means it must store activations for all $M$ micro-batches simultaneously — a significant memory cost.
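The bubble formulas are easy to check numerically (a tiny helper, names ours):

```python
def bubble_fraction(p, m=None):
    """Pipeline bubble fraction: naive schedule if m is None,
    otherwise the GPipe/1F1B micro-batched schedule."""
    if m is None:
        return (p - 1) / p          # naive: (P-1)/P
    return (p - 1) / (m + p - 1)    # micro-batched: (P-1)/(M+P-1)

assert bubble_fraction(4) == 0.75         # naive, 4 stages: 75% idle
assert bubble_fraction(4, 4) == 3 / 7     # GPipe with M = P = 4
assert bubble_fraction(4, 32) == 3 / 35   # M >> P drives the bubble → 0
```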
#### 1F1B Schedule
1F1B (One Forward, One Backward) interleaves forward and backward passes. After completing a micro-batch’s forward pass, it starts the backward pass as soon as possible, allowing activation memory to be freed earlier:
```
Time → (P = 4, M = 8)
GPU 0: [F0][F1][F2][F3][B0][F4][B1][F5][B2][F6][B3][F7][B4][B5][B6][B7]
GPU 1:     [F0][F1][F2][B0][F3][B1][F4][B2][F5][B3][F6][B4][F7][B5][B6][B7]
GPU 2:         [F0][F1][B0][F2][B1][F3][B2][F4][B3][F5][B4][F6][B5][F7]...
GPU 3:             [F0][B0][F1][B1][F2][B2][F3][B3][F4][B4][F5][B5]...
```
During the steady state (after the warmup phase), each GPU simultaneously holds activations for only a bounded number of micro-batches, dramatically reducing memory compared to GPipe. The bubble ratio remains $\frac{P-1}{M+P-1}$, but memory usage improves substantially.
#### Zero Bubble PP
Proposed in 2024, Zero Bubble PP further reduces pipeline bubbles. The key insight is to decompose the backward pass into two parts: computing input gradients (B) and computing weight gradients (W). B must be communicated to the previous stage, but W does not require communication and can be scheduled to fill the bubble:
```
Time →
GPU 0: [F][F][F][B][F][B][W][B][W][B][W]   ← W fills what used to be bubble
GPU 1:     [F][F][B][F][B][W][B][W][B][W]
```

Zero Bubble PP can theoretically reduce the bubble fraction to near zero, at the cost of increased scheduling complexity.
**PP summary:**
| Schedule | Bubble Fraction | Activation Memory | Implementation Complexity |
|---|---|---|---|
| Naive | $(P-1)/P$ | Low | Low |
| GPipe | $(P-1)/(M+P-1)$ | High (all micro-batches) | Medium |
| 1F1B | $(P-1)/(M+P-1)$ | Medium (bounded micro-batches) | Medium |
| Zero Bubble | $\approx 0$ | Medium | High |
### SP (Sequence Parallel) — LayerNorm/Dropout on the Sequence Dimension
Tensor Parallel splits the weight matrices of linear layers by dimension, but Transformers also contain many elementwise operations that have no weight matrices: LayerNorm, Dropout, and residual connections. Under TP, these operations need to operate on the full hidden dimension, which means every GPU must hold the complete activation tensor — creating redundant activation memory.
Sequence Parallel (SP) addresses this by splitting the activations of these non-TP operations along the sequence dimension instead:
```mermaid
flowchart TD
LN1["LayerNorm — SP: seq/N"]
AG1["all_gather: full sequence"]
ATT["Attention — TP: split by head"]
RS1["reduce_scatter: split to seq dim"]
DR1["Dropout + Residual — SP: seq/N"]
LN2["LayerNorm — SP"]
AG2["all_gather: full sequence"]
FFN["FFN — TP: Column + Row"]
RS2["reduce_scatter: back to seq dim"]
DR2["Dropout + Residual — SP"]
LN1 --> AG1 --> ATT --> RS1 --> DR1 --> LN2 --> AG2 --> FFN --> RS2 --> DR2
style LN1 fill:#d4edda
style DR1 fill:#d4edda
style LN2 fill:#d4edda
style DR2 fill:#d4edda
style ATT fill:#fff3cd
style FFN fill:#fff3cd
style AG1 fill:#cce5ff
style AG2 fill:#cce5ff
style RS1 fill:#cce5ff
style RS2 fill:#cce5ff
```
The key insight: TP’s all_reduce is decomposed into reduce_scatter + all_gather, placed at the exit and entrance of the TP regions respectively. The communication volume is unchanged (recall that all_reduce = reduce_scatter + all_gather), but activation memory in the SP regions drops to $1/N$.
SP’s value is most pronounced with large batches and long sequences, where activations dominate memory usage. In these cases, SP directly divides the activation cost by the TP degree.
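The decomposition identity is easy to verify with plain tensors standing in for per-rank buffers:

```python
import torch

torch.manual_seed(0)
world = 4
chunk = 2                                  # shard length = 8 / world
grads = [torch.randn(8) for _ in range(world)]  # one tensor per "rank"

# all_reduce: every rank would end up with the full elementwise sum.
full_sum = torch.stack(grads).sum(dim=0)

# reduce_scatter: rank r keeps only the summed shard r of length 8/world.
shards = [full_sum[r * chunk : (r + 1) * chunk] for r in range(world)]

# all_gather: concatenating the shards reconstructs the all_reduce result,
# so all_reduce = reduce_scatter + all_gather, with the same total volume.
reassembled = torch.cat(shards)

assert torch.allclose(reassembled, full_sum)
```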
## MoE Parallelism

### EP (Expert Parallel)
Mixture-of-Experts (MoE) introduces a set of “expert” subnetworks into the model, where each token activates only $k$ of them (typically $k=1$ or $2$). This allows MoE to dramatically increase parameter count while keeping compute roughly constant — but it also creates unique parallelism challenges.
The structure of a MoE layer:
```mermaid
flowchart TD
INPUT["Input tokens"]
ROUTER["Router
Decides which expert each token goes to"]
E0["Expert 0
FFN"] & E1["Expert 1
FFN"] & E2["Expert 2
FFN"] & E3["Expert 3
FFN ..."]
COMBINE["Combine outputs"]
INPUT --> ROUTER
ROUTER --> E0 & E1 & E2 & E3
E0 & E1 & E2 & E3 --> COMBINE
```
Expert Parallel (EP) assigns different experts to different GPUs. With 64 experts and 8 GPUs, each GPU is responsible for 8 experts.
The core communication operation is all_to_all, executed twice:
- Dispatch: after the router determines each token’s destination expert, all_to_all rearranges tokens from a “partitioned by data” layout to a “grouped by expert” layout — each GPU receives all tokens destined for the experts it hosts.
- Combine: after each expert processes its tokens, all_to_all sends the results back to the originating GPUs.
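At its core, the dispatch/combine pair is a permutation and its inverse. A single-process simulation (the routing table and sizes are invented for illustration):

```python
import torch

torch.manual_seed(0)
tokens = torch.randn(8, 4)     # 8 tokens, hidden dim 4
# Router output: destination expert per token; with 2 "GPUs", GPU 0 hosts
# experts 0-1 and GPU 1 hosts experts 2-3.
expert_of = torch.tensor([0, 3, 1, 2, 0, 1, 3, 2])

# Dispatch: regroup tokens by destination expert, so each GPU receives
# exactly the tokens routed to the experts it hosts.
order = torch.argsort(expert_of, stable=True)
dispatched = tokens[order]

# Each "GPU" applies its experts (identity here, to keep the sketch short).
processed = dispatched

# Combine: the inverse permutation restores the original data layout.
inverse = torch.argsort(order)
combined = processed[inverse]

assert torch.equal(combined, tokens)
```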
```mermaid
flowchart TD
subgraph BEFORE["Before: data-sharded"]
direction LR
G0B["GPU 0
tokens → various Experts"]
G1B["GPU 1
tokens → various Experts"]
end
BEFORE -->|"all_to_all
(dispatch)"| AFTER
subgraph AFTER["After: expert-grouped"]
direction LR
G0A["GPU 0 (E0,E1)
all tokens → E0,E1"]
G1A["GPU 1 (E2,E3)
all tokens → E2,E3"]
end
AFTER -->|"expert compute"| COMPUTE["Each GPU runs its experts"]
COMPUTE -->|"all_to_all
(combine)"| RESULT["Restore original data sharding"]
```
The communication cost of EP depends on the token routing distribution. If tokens are uniformly spread across all experts, all_to_all traffic is maximized. If tokens cluster on a few experts, communication is lower but load imbalance becomes a problem.
### DeepSeek MoE: All-to-All Dispatch and Token Dropping
DeepSeek introduced several important refinements to the MoE architecture.
#### 1. Fine-Grained Experts
Traditional MoE designs (e.g., Switch Transformer) use a small number of large experts. DeepSeek uses a large number of small experts — for example, 160 experts with each token activating 6, rather than 16 experts with each token activating 2. More, smaller experts provide exponentially more combinatorial flexibility:
$$\binom{160}{6} \gg \binom{16}{2}$$
The number of possible expert combinations each token can select from grows enormously, yielding greater representational power.
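Python's `math.comb` confirms the gap:

```python
from math import comb

# 160 experts with top-6 routing vs. 16 experts with top-2 routing.
assert comb(16, 2) == 120        # only 120 possible expert pairs
assert comb(160, 6) > 10**10     # over 2 × 10^10 possible combinations

print(comb(160, 6) // comb(16, 2))  # ratio of routing combinations
```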
#### 2. Shared Experts + Routed Experts
DeepSeek MoE introduces “shared experts” — experts that process every token — alongside the routed experts selected by the gating network:
$$\text{Output}(x) = \text{SharedExpert}(x) + \sum_{i \in \text{TopK}(x)} \text{RoutedExpert}_i(x)$$
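A toy forward pass of the shared + routed combination (module names and sizes are ours, and the per-token loop favors clarity over speed — real implementations batch tokens per expert):

```python
import torch
import torch.nn as nn

class ToySharedRoutedMoE(nn.Module):
    """Sketch of a shared-expert + routed-experts layer (illustrative)."""
    def __init__(self, dim=16, n_routed=8, k=2):
        super().__init__()
        self.shared = nn.Linear(dim, dim)      # processes every token
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_routed))
        self.router = nn.Linear(dim, n_routed)  # gating network
        self.k = k

    def forward(self, x):                       # x: (tokens, dim)
        weights, idx = self.router(x).softmax(-1).topk(self.k, dim=-1)
        rows = []
        for t in range(x.size(0)):              # per-token routing loop
            acc = self.shared(x[t])             # shared expert sees all tokens
            for w, i in zip(weights[t], idx[t]):
                acc = acc + w * self.experts[int(i)](x[t])
            rows.append(acc)
        return torch.stack(rows)                # Output = Shared(x) + Σ top-k

moe = ToySharedRoutedMoE()
out = moe(torch.randn(4, 16))
```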
#### 3. Token Dropping

When an expert receives more tokens than its capacity allows, the overflow tokens skip that expert (they still pass through the residual path and the shared experts). This bounds per-expert compute and keeps the `all_to_all` traffic balanced.

### EP + FSDP Combined

In practice, EP covers only the expert weights; everything else still needs data parallelism. The standard combination shards the non-expert layers (Attention, LayerNorm, Embedding) with FSDP across all GPUs, while the experts are partitioned by EP:

```mermaid
flowchart TD
subgraph NON_EXPERT["Non-expert layers (Attention, LN, Embedding)"]
FSDP_SHARD["FSDP across all GPUs
shard params/grads/opt_state"]
end
subgraph MOE_LAYERS["MoE layers"]
direction LR
EP_0["Expert 0-7 → GPU 0"]
EP_1["Expert 8-15 → GPU 1"]
EP_N["... (EP)"]
end
COMM["Communication:
FSDP: all_gather + reduce_scatter
EP: all_to_all (dispatch/combine)"]
NON_EXPERT --- MOE_LAYERS --- COMM
```
## Cutting-Edge Approaches

### CP (Context Parallel) — Ring Attention and Stripe Attention

#### Ring Attention

Context Parallel handles ultra-long sequences by splitting the sequence itself across GPUs. In Ring Attention, each GPU holds one chunk of Q, and the KV blocks circulate around a ring so that every Q chunk eventually sees every KV block:

```mermaid
flowchart LR
subgraph RING["Ring Attention — Sequence length S, 4 GPUs"]
G0["GPU 0
Q[0:S/4]"] -->|"KV"| G1["GPU 1
Q[S/4:S/2]"]
G1 -->|"KV"| G2["GPU 2
Q[S/2:3S/4]"]
G2 -->|"KV"| G3["GPU 3
Q[3S/4:S]"]
G3 -->|"KV"| G0
end
```

The schedule, step by step:

```
Step 1: each GPU computes partial attention with its local KV block
Step 2: send KV to the right neighbor, recv from the left, continue
Step 3: keep circulating...
Step 4: once KV completes a full circle, attention is done
```
A critical optimization: the `send/recv` of KV blocks can **overlap** with attention computation — while a GPU is computing attention with the current KV block, it is already transmitting the next one.
**Memory.** Each GPU stores Q and KV for only $S/N$ of the sequence. Memory drops from $O(S^2)$ to $O(S^2/N)$ (or more precisely, from $O(S)$ to $O(S/N)$ under FlashAttention).
#### Stripe Attention
A problem with Ring Attention is **load imbalance from the causal mask**: GPUs handling the beginning of the sequence compute less attention (because the causal mask blocks later positions), while the GPU holding the end of the sequence does the most work.
Stripe Attention solves this by **interleaving** sequence positions across GPUs:
```
Ring Attention (contiguous assignment):
GPU 0: tokens [0, 1, 2, 3]     → least compute (causal)
GPU 1: tokens [4, 5, 6, 7]
GPU 2: tokens [8, 9, 10, 11]
GPU 3: tokens [12, 13, 14, 15] → most compute

Stripe Attention (interleaved assignment):
GPU 0: tokens [0, 4, 8, 12]    → balanced compute
GPU 1: tokens [1, 5, 9, 13]
GPU 2: tokens [2, 6, 10, 14]
GPU 3: tokens [3, 7, 11, 15]
```
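The imbalance is easy to quantify: under a causal mask, the token at position $p$ attends to $p+1$ keys. Counting that cost per GPU for both assignments (S = 16, N = 4):

```python
def workload(assignment):
    """Causal-attention cost per GPU: position p attends to p+1 keys."""
    return [sum(p + 1 for p in toks) for toks in assignment]

S, N = 16, 4
contiguous = [list(range(g * S // N, (g + 1) * S // N)) for g in range(N)]
interleaved = [list(range(g, S, N)) for g in range(N)]

print(workload(contiguous))   # ring: [10, 26, 42, 58] — last GPU does ~6x GPU 0
print(workload(interleaved))  # stripe: [28, 32, 36, 40] — nearly balanced
```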
#### Ulysses SP

DeepSpeed Ulysses takes a different route: instead of circulating KV blocks, it uses `all_to_all` to transpose between the sequence and head dimensions, so each GPU computes attention for a subset of heads over the full sequence:

```mermaid
flowchart TD
IN["Input: each GPU has seq/N, full heads"]
QKV["QKV projection (local)"]
A2A1["all_to_all
(seq/N, heads) → (seq, heads/N)"]
ATT["Attention
each GPU: heads/N over full sequence"]
A2A2["all_to_all
(seq, heads/N) → (seq/N, heads)"]
OUT["Output projection (local)"]
IN --> QKV --> A2A1 --> ATT --> A2A2 --> OUT
style A2A1 fill:#cce5ff
style A2A2 fill:#cce5ff
```
### Hybrid Parallelism (3D/5D)

#### 3D Parallelism

At real cluster scale, no single strategy suffices; the strategies are combined according to the hardware topology:

```mermaid
flowchart TD
subgraph N0["Node 0 (8 GPUs) — NVLink 600 GB/s"]
direction LR
subgraph TP0A["TP=4"]
G0["0"] ~~~ G1["1"] ~~~ G2["2"] ~~~ G3["3"]
end
subgraph TP0B["TP=4"]
G4["4"] ~~~ G5["5"] ~~~ G6["6"] ~~~ G7["7"]
end
TP0L["Stage 0 — DP/FSDP across"]
end
subgraph N1["Node 1 (8 GPUs)"]
direction LR
TP1L["Stage 1 — DP/FSDP across"]
end
N0 ==>|"PP: send/recv
(InfiniBand)"| N1
style N0 fill:#f0f0f0
style N1 fill:#f0f0f0
```

A concrete 64-GPU layout:

```
3D Parallelism: 64 GPUs (8 nodes × 8 GPUs)

Node 0 (8 GPUs)               Node 1 (8 GPUs)
┌─TP=4──┐  ┌─TP=4──┐          ┌─TP=4──┐  ┌─TP=4──┐
│0 1 2 3│  │4 5 6 7│          │0 1 2 3│  │4 5 6 7│
│Stage 0│  │Stage 0│          │Stage 1│  │Stage 1│
└───────┘  └───────┘          └───────┘  └───────┘
└── DP/FSDP across ─┘         └── DP/FSDP across ─┘

Node 2-3: Stage 2             Node 4-7: replicas for DP

TP:   intra-node (NVLink 600 GB/s)    ← high bandwidth needed
PP:   across node groups (InfiniBand) ← small send/recv
FSDP: across nodes (InfiniBand)       ← all_gather/reduce_scatter
```
The rationale is topology-aware placement:
- **TP** has the heaviest communication (all_reduce every layer), so it goes **intra-node** on NVLink.
- **PP** has the lightest communication (send/recv of activations between adjacent stages), so it can go **cross-node**.
- **FSDP/DP** falls in between, and is placed **cross-node** with overlapped communication.
#### 5D Parallelism
Adding SP (Sequence Parallel) and EP (Expert Parallel) gives the so-called 5D parallelism:
$$\text{Total GPUs} = TP \times SP \times PP \times DP \times EP$$
Each strategy splits along a different dimension:
| Strategy | Split Dimension | Communication Op | Volume | Best Interconnect |
|----------|----------------|-------------------|--------|-------------------|
| TP | Hidden dimension | all_reduce | High | Intra-node (NVLink) |
| SP | Sequence dimension | reduce_scatter / all_gather | Medium | Intra-node |
| PP | Layers | send/recv | Low | Intra- or cross-node |
| DP/FSDP | Data | all_reduce / all_gather + reduce_scatter | Medium | Cross-node |
| EP | Experts | all_to_all | Routing-dependent | Cross-node |
#### How to Choose a Parallelism Strategy
The right combination depends on three factors: **model size, hardware topology, and sequence length**.
Here is a decision tree for reference:
```
Does the model fit on a single GPU?
├── Yes → DDP (simplest, fastest)
└── No
    ├── Adam state doesn’t fit, but parameters do?
    │   └── FSDP (SHARD_GRAD_OP)
    ├── Even parameters don’t fit?
    │   └── FSDP (FULL_SHARD)
    │       ├── Still OOM?
    │       │   └── Add TP (typically 2/4/8 within node)
    │       ├── Still not enough?
    │       │   └── Add PP (cross-node)
    │       └── Long sequences causing OOM?
    │           └── Add Context Parallel
    └── MoE model?
        └── EP for experts + FSDP for non-expert params
```
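The decision tree can be encoded as a toy helper (purely illustrative; real choices also depend on interconnect bandwidth, batch size, and activation memory):

```python
def pick_strategies(fits_single_gpu, params_fit_single_gpu=False,
                    still_oom=False, long_seq=False, is_moe=False):
    """Toy encoding of the decision tree above (not exhaustive)."""
    if fits_single_gpu:
        return ["DDP"]
    if is_moe:
        return ["EP (experts)", "FSDP (non-expert params)"]
    if params_fit_single_gpu:              # only the Adam state overflows
        return ["FSDP(SHARD_GRAD_OP)"]
    plan = ["FSDP(FULL_SHARD)"]
    if still_oom:                          # sharded params still too large
        plan += ["TP (intra-node, degree 2/4/8)", "PP (cross-node)"]
    if long_seq:                           # activations dominated by seq len
        plan += ["CP (Context Parallel)"]
    return plan

assert pick_strategies(True) == ["DDP"]
assert pick_strategies(False, is_moe=True)[0].startswith("EP")
```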
Some rules of thumb:
- **7B models**: 2-8 GPU DDP or FSDP is sufficient.
- **13B-70B models**: FSDP + TP (intra-node TP=2 or 4).
- **70B+ models**: FSDP + TP + PP — full 3D parallelism.
- **MoE models (e.g., Mixtral, DeepSeek)**: EP + FSDP + TP.
- **Very long sequences (128K+)**: Context Parallel + TP + FSDP.
---
## Companion Code
All code for this article is in [`code/02-parallel-strategies/`](https://github.com/mzf666/llm-infra/tree/main/code/02-parallel-strategies/):
- **`ddp_example.py`** — A complete DDP training loop with `DistributedSampler`, bucketed gradient synchronization, throughput measurement, and parameter consistency verification. Run with: `torchrun --nproc_per_node=2 ddp_example.py`
- **`fsdp_example.py`** — Compares NO_SHARD (= DDP), SHARD_GRAD_OP (= ZeRO-2), and FULL_SHARD (= ZeRO-3), showing real per-GPU memory differences. Run with: `torchrun --nproc_per_node=2 fsdp_example.py`
- **`tensor_parallel.py`** — A from-scratch implementation of Megatron-style `ColumnParallelLinear` and `RowParallelLinear`, combined into a `TensorParallelFFN`. Includes correctness verification and memory analysis. Run with: `torchrun --nproc_per_node=2 tensor_parallel.py`
All scripts support CPU mode (automatic `gloo` backend fallback), so you can study the logic without a GPU. However, memory measurements and performance numbers are only meaningful on CUDA.
---
## Summary and What Comes Next
This article systematically covered all the major parallelism strategies used in large model training. Here is a recap of the key ideas.
**Classic parallelism strategies:**
- **DDP**: replicate the model, split the data, all_reduce the gradients — the simplest approach, but saves no memory.
- **FSDP/ZeRO**: shard parameters + gradients + optimizer state, paying all_gather/reduce_scatter communication to save memory — the backbone of large model training.
- **TP**: split weight matrices by dimension (Column + Row), with 2 all_reduces per layer — requires NVLink bandwidth, best for intra-node.
- **PP**: partition by layers, using send/recv — low communication but pipeline bubbles, addressed by 1F1B and Zero Bubble schedules.
- **SP**: split activations along the sequence dimension for non-TP ops (LayerNorm, Dropout) — complementary to TP, reduces activation memory.
**MoE parallelism:**
- **EP**: distribute experts across GPUs, using all_to_all for token dispatch — load balancing is the central challenge.
- **DeepSeek MoE**: fine-grained experts + shared experts + token dropping for better utilization.
**Cutting-edge approaches:**
- **Context Parallel** (Ring / Stripe Attention): handles ultra-long sequences by distributing attention across GPUs.
- **Ulysses SP**: uses all_to_all to transpose between sequence and head dimensions.
- **Hybrid Parallelism (3D/5D)**: combines strategies based on hardware topology — TP intra-node, PP cross-node, FSDP everywhere else.
The fundamental insight that connects every strategy in this article: each one is making a different **trade-off between memory and communication**. Which combination you choose depends on how large your model is, how fast your inter-GPU links are, and how long your sequences are. There is no silver bullet — only the best engineering compromise for your specific setup.
In the next article, **"LLM Inference System Architecture,"** we shift our perspective from training to serving. Once the model is trained, how do we serve requests efficiently? We will dive deep into PagedAttention, RadixAttention (SGLang's core innovation), Continuous Batching, and other inference optimizations — exploring a fundamentally different, memory-bound challenge.
---
## References
1. **ZeRO: Memory Optimizations Toward Training Trillion Parameter Models** — Rajbhandari et al., 2020 — the theoretical foundation for FSDP, defining the Stage 1/2/3 memory sharding scheme.
2. **Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism** — Shoeybi et al., 2020 — the Column/Row partitioning approach for Tensor Parallelism.
3. **Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM** — Narayanan et al., 2021 — 3D parallelism (TP + PP + DP) system design.
4. **GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism** — Huang et al., 2019 — the micro-batching approach to Pipeline Parallelism.
5. **Zero Bubble Pipeline Parallelism** — Qi et al., 2024 — eliminating pipeline bubbles by separating B and W computations.
6. **Ring Attention with Blockwise Transformers for Near-Infinite Context** — Liu et al., 2023 — the Ring Attention approach for long sequences.
7. **DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model** — DeepSeek AI, 2024 — fine-grained MoE with shared experts.
8. **DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale** — Rajbhandari et al., 2022 — EP and FSDP combined training.
9. **Reducing Activation Recomputation in Large Transformer Models** — Korthikanti et al., 2023 — Sequence Parallel for reducing activation memory.
10. **PyTorch FSDP Documentation** — [pytorch.org/docs/stable/fsdp](https://pytorch.org/docs/stable/fsdp.html) — official FSDP/FSDP2 documentation.
11. **DeepSpeed ZeRO Tutorial** — [deepspeed.ai/tutorials/zero](https://www.deepspeed.ai/tutorials/zero/) — practical guide to ZeRO Stage 1/2/3.