Motivation

上一篇文章中,我们算过一笔账:一个 70B 参数的模型,仅 Adam 优化器相关的显存就需要 $16 \times 70 \times 10^9 \approx 1120$ GB——至少 14 张 A100 80GB 才放得下,而且还没算激活值。单卡训练大模型从物理上就不可能,我们必须把计算分布到多张 GPU 上。

但"分布到多卡"不是一句话那么简单。不同的并行策略在切什么(参数、梯度、激活值、序列)、怎么切(按层、按维度、按数据)、通信开销(all_reduce、all_gather、send/recv、all_to_all)之间做出了截然不同的取舍。选错策略,轻则浪费算力,重则根本跑不起来。

本文将从最基础的 DDP 出发,逐步覆盖 FSDP、TP、PP、SP、EP 以及前沿的 Context Parallel 和混合并行方案。读完之后,你应该能根据模型规模、硬件拓扑和训练需求,为自己的项目选出合理的并行策略组合。

前置知识

  • GPU 显存模型与分布式通信基础(第 1 篇)——特别是 NCCL 通信原语和硬件拓扑
  • PyTorch DDP 的基本使用经验(写过 torchrun 启动的训练脚本)
  • 了解 Transformer 的基本结构:Self-Attention、FFN、LayerNorm

先给出全景图,后面逐一展开:

flowchart TD
  subgraph DP["Data Parallel"]
    DDP["DDP"]
    FSDP["FSDP / ZeRO"]
  end
  subgraph MP["Model Parallel"]
    TP["TP (Tensor)"]
    SP["SP (Sequence)"]
  end
  subgraph PP["Pipeline Parallel"]
    PP1["1F1B"]
    PP2["Zero Bubble"]
  end
  subgraph EP["Expert Parallel"]
    EP1["EP (MoE)"]
    EP2["DeepSeek"]
  end

  DP & MP & PP & EP --> HYBRID["混合并行 (3D/5D)\nTP 节点内 + FSDP 跨节点 + PP 跨节点组"]

经典并行策略

DDP(Distributed Data Parallel)

DDP 是最简单、最常用的并行策略,核心思想只有三个字:复制模型,分数据

工作原理

  1. 每张 GPU 持有模型的完整副本(参数、梯度、优化器状态都完整存在)
  2. 训练数据通过 DistributedSampler 均分到各卡——每张卡只看自己的 mini-batch
  3. 各卡独立完成前向和反向传播,计算各自的梯度
  4. 通过 all_reduce 对梯度求和(或平均),确保所有卡拥有相同的聚合梯度
  5. 各卡执行相同的优化器更新,模型参数保持一致
flowchart TD
  DATA["Data Batch"] -->|DistributedSampler 均分| B0["B₀"] & B1["B₁"] & B2["B₂"] & B3["B₃"]
  B0 --> G0["GPU 0\n完整模型\n→ grad₀"]
  B1 --> G1["GPU 1\n完整模型\n→ grad₁"]
  B2 --> G2["GPU 2\n完整模型\n→ grad₂"]
  B3 --> G3["GPU 3\n完整模型\n→ grad₃"]
  G0 & G1 & G2 & G3 --> AR["all_reduce(gradients)\n唯一的通信操作"]
  AR --> OPT["optimizer.step()\n各卡执行相同更新"]

显存:DDP 不节省任何显存。每张 GPU 独立存储完整的参数($2\Phi$ bytes FP16)、梯度($2\Phi$ bytes)和优化器状态($12\Phi$ bytes for Adam),总计 $16\Phi$ bytes per GPU——跟单卡训练一样。

通信:唯一的通信操作是对梯度做 all_reduce。PyTorch DDP 使用了一个重要的优化——Bucketed All-Reduce:它不会等所有梯度都算完才通信,而是将梯度分成若干个 bucket(默认 25 MB),当一个 bucket 内所有梯度就绪后立即开始 all_reduce,与后续层的反向传播重叠执行。这大幅隐藏了通信延迟。

什么时候用 DDP:模型能放进单张 GPU 的时候。DDP 的优势是简单、高效、几乎线性扩展。它的局限也很明显——模型必须整个塞进一张卡。

以下是 DDP 的核心代码片段:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# DDP wrapping — 注册梯度同步 hook
model = DDP(model, device_ids=[local_rank])

# 训练循环与单卡完全一样
for input_ids, target_ids in dataloader:
    logits = model(input_ids)
    loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
    optimizer.zero_grad()
    loss.backward()       # ← DDP hook 在这里自动触发 all_reduce
    optimizer.step()      # ← 各卡梯度相同,更新后参数一致

完整代码见 code/02-parallel-strategies/ddp_example.py


FSDP / FSDP2(Fully Sharded Data Parallel)

DDP 的致命问题是:每张卡都存了一份完整的参数 + 梯度 + 优化器状态。当模型大到单卡放不下时,DDP 就无能为力了。

FSDP(Fully Sharded Data Parallel)的核心思想来自 DeepSpeed 的 ZeRO(Zero Redundancy Optimizer)系列论文:既然每张卡都存了冗余的参数和优化器状态,为什么不把它们切分(shard) 到各卡上,需要时再临时拼回来?

ZeRO Stage 1/2/3

ZeRO 将显存优化分为三个阶段,逐步增加切分范围:

ZeRO Stage切分内容每卡显存FSDP 对应策略
Stage 1优化器状态$4\Phi + \frac{12\Phi}{N}$
Stage 2优化器状态 + 梯度$2\Phi + \frac{14\Phi}{N}$SHARD_GRAD_OP
Stage 3优化器状态 + 梯度 + 参数$\frac{16\Phi}{N}$FULL_SHARD

其中 $\Phi$ 是参数数量,$N$ 是 GPU 数量。可以看到,Stage 3 实现了近乎线性的显存缩减——8 张卡就能把显存需求降到单卡的 1/8(加上激活值的开销)。

FSDP 的前向/反向过程

FSDP (FULL_SHARD) 的工作流程如下:

flowchart LR
  subgraph FWD["FSDP Forward Pass — Layer i"]
    direction LR
    FS1["Shard\n1/N"] -->|"all_gather"| FP1["Full Params\n(临时拼出)"]
    FP1 -->|"compute"| FO1["Output"]
    FP1 -.->|"丢弃 (N-1)/N"| X1[" "]
  end
  subgraph BWD["FSDP Backward Pass — Layer i"]
    direction LR
    BS1["Shard\n1/N"] -->|"all_gather"| BP1["Full W\n(临时)"]
    BP1 -->|"backward"| BG1["Grad\n(full)"]
    BG1 -->|"reduce_scatter"| BGS["Grad Shard\n1/N"]
  end
  FWD ~~~ BWD

  style X1 fill:none,stroke:none

FSDP 分片与通信模式

通信开销对比

策略前向通信反向通信总通信量
DDP1x all_reduce (梯度)$2\Phi$
FSDP (FULL_SHARD)all_gather (参数)all_gather (参数) + reduce_scatter (梯度)$3\Phi$

FSDP 的通信量约为 DDP 的 1.5 倍——这就是用通信换显存的代价。但在大模型场景下,这个 trade-off 非常值得:没有 FSDP,根本跑不起来。

FSDP2:PyTorch 2.2+ 的 Composable API

PyTorch >= 2.2 引入了 FSDP2(torch.distributed._composable.fsdp),相比 FSDP1 的主要改进:

  • Per-parameter sharding:不再要求整个 module 作为 FSDP 单元,可以对单个参数做 sharding
  • Composable:可以与 TP、PP 等其他并行策略自由组合,不需要嵌套包装
  • 更灵活的 sharding 粒度:不同的层可以使用不同的 sharding 策略

核心概念与 FSDP1 完全一致(all_gather/reduce_scatter 的通信模式不变),API 更现代、与 PyTorch 2.x 的编译器栈更兼容。

以下是 FSDP 不同策略的显存对比代码片段:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# 三种策略对比
strategies = [
    (ShardingStrategy.NO_SHARD, "NO_SHARD (= DDP)"),
    (ShardingStrategy.SHARD_GRAD_OP, "SHARD_GRAD_OP (= ZeRO-2)"),
    (ShardingStrategy.FULL_SHARD, "FULL_SHARD (= ZeRO-3)"),
]

for strategy, name in strategies:
    model = FSDP(
        TransformerLM(),
        sharding_strategy=strategy,
        device_id=local_rank,
    )
    # ... 训练并测量显存

完整代码见 code/02-parallel-strategies/fsdp_example.py


TP(Tensor Parallel)— Megatron 列/行切分

FSDP 将参数"切碎"再"拼回",通信和计算是串行的。Tensor Parallel 更进一步:直接把每一层的权重矩阵按维度切开,分给不同 GPU 各算一部分,从根本上减少单卡的计算量和显存。

TP 的核心思想来自 Megatron-LM,定义了两种基本的并行线性层:

ColumnParallelLinear:按输出维度切分

对于线性层 $Y = XW + b$,将权重 $W \in \mathbb{R}^{d \times h}$ 按列切分

flowchart TD
  W["完整权重 W: (d_model × dim_ffn)"]
  W -->|"按列切分到 2 张 GPU"| G0 & G1
  G0["GPU 0: W₀ = W[:, :dim_ffn//2]\nY₀ = X @ W₀"]
  G1["GPU 1: W₁ = W[:, dim_ffn//2:]\nY₁ = X @ W₁"]
  NOTE["每张 GPU 独立计算,无需通信!\n(因为输入 X 在所有 GPU 上是相同的)"]
  style NOTE fill:#d4edda,stroke:#155724

关键优势:前向传播不需要任何通信。每张 GPU 得到输出的一个分片(chunk),可以直接送入后续的逐元素操作(如 GeLU)。

RowParallelLinear:按输入维度切分

将权重 $W \in \mathbb{R}^{h \times d}$ 按行切分

flowchart TD
  subgraph SPLIT["按行切分到 2 张 GPU"]
    G0["GPU 0: W₀ = W[:dim_ffn//2, :]\nY₀ = X₀ @ W₀ (部分和)"]
    G1["GPU 1: W₁ = W[dim_ffn//2:, :]\nY₁ = X₁ @ W₁ (部分和)"]
  end
  SPLIT -->|"all_reduce"| RESULT["Y = Y₀ + Y₁"]
  style RESULT fill:#fff3cd,stroke:#856404

关键约束:每张 GPU 只计算了输出的一个部分和(partial sum),必须通过 all_reduce 才能得到完整输出。

Megatron FFN:Column + Row = 只需 1 次 all_reduce

Megatron-LM 的天才设计在于将 Column 和 Row 配对使用:

flowchart TD
  X["Input X\n(每张 GPU 上相同)"]
  X --> COL["ColumnParallelLinear\n(W1 按列切, 无通信)"]
  COL --> GELU["GeLU\n(逐元素, 无通信)"]
  GELU --> ROW["RowParallelLinear\n(W2 按行切)"]
  ROW -->|"all_reduce\n合并部分和"| Y["Output Y\n(每张 GPU 上相同)"]

  style COL fill:#d4edda
  style GELU fill:#d4edda
  style ROW fill:#fff3cd

整个 FFN 块只需 1 次 all_reduce

Tensor Parallel 的列切分与行切分

Attention 的 TP

对于 Multi-Head Attention,TP 的做法同样优雅:

  • Q、K、V 投影:使用 ColumnParallel,将 attention head 分给不同 GPU——每张 GPU 负责 $\frac{n_heads}{TP}$ 个 head
  • Output 投影:使用 RowParallel,将各 GPU 的 head 输出合并

这样整个 Attention 块也只需要1 次 all_reduce。一个 Transformer 层总共 2 次 all_reduce(FFN 1 次 + Attention 1 次)。

TP 的适用条件

TP 通信频繁(每层 2 次 all_reduce),对带宽要求极高。因此:

  • TP degree 通常为 2、4 或 8,部署在同一节点内(NVLink 600 GB/s)
  • 跨节点做 TP 几乎不可行(InfiniBand 25-50 GB/s,太慢)
  • hidden dimension 和 head 数必须能被 TP degree 整除

以下是 ColumnParallelLinear 的核心实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        self.out_features_per_rank = out_features // world_size
        # 每张 GPU 只存 1/world_size 的权重
        self.weight = nn.Parameter(
            torch.empty(self.out_features_per_rank, in_features)
        )
        self.bias = nn.Parameter(torch.empty(self.out_features_per_rank))

    def forward(self, x):
        # 无通信!各 GPU 独立计算
        return F.linear(x, self.weight, self.bias)

完整代码见 code/02-parallel-strategies/tensor_parallel.py


PP(Pipeline Parallel)— 1F1B 与 Zero Bubble

TP 把每一层切开,PP 的思路完全不同:按层划分,把模型的不同层分配到不同 GPU 上(称为 stage)。

flowchart LR
  S0["Stage 0 (GPU 0)\nLayer 0-10"] -->|"send/recv\nactivations"| S1["Stage 1 (GPU 1)\nLayer 11-21"]
  S1 -->|"send/recv\nactivations"| S2["Stage 2 (GPU 2)\nLayer 22-31"]

通信方式: send/recv(点对点,仅相邻 stage 之间)。

通信:PP 只需要相邻 stage 之间的 send/recv——传输的是中间激活值(前向)和梯度(反向),通信量远小于 TP 的 all_reduce。这使得 PP 非常适合跨节点部署。

但 PP 有一个致命问题:Pipeline Bubble(流水线气泡)

Naive PP:巨大的 Bubble

最朴素的方式:整个 batch 依次通过各 stage,一个 stage 在计算时,其他 stage 全部闲置。

flowchart TD
  subgraph NaivePP["Naive Pipeline Parallel — Time →"]
    direction LR
    subgraph GPU0["GPU 0"]
      direction LR
      F0["Forward"] ~~~ B0["Backward"]
    end
    subgraph GPU1["GPU 1"]
      direction LR
      IDLE1a["idle"] ~~~ F1["Forward"] ~~~ IDLE1b["idle"] ~~~ B1["Backward"]
    end
    subgraph GPU2["GPU 2"]
      direction LR
      IDLE2a["idle"] ~~~ F2["Forward"] ~~~ IDLE2b["idle"] ~~~ B2["Backward"]
    end
  end
  NOTE["Bubble 占比 ≈ (P-1)/P\n4 个 stage → 75% 时间浪费!"]
  style IDLE1a fill:#f8d7da,stroke:#dc3545
  style IDLE1b fill:#f8d7da,stroke:#dc3545
  style IDLE2a fill:#f8d7da,stroke:#dc3545
  style IDLE2b fill:#f8d7da,stroke:#dc3545
  style F0 fill:#d4edda,stroke:#28a745
  style F1 fill:#d4edda,stroke:#28a745
  style F2 fill:#d4edda,stroke:#28a745
  style B0 fill:#cce5ff,stroke:#007bff
  style B1 fill:#cce5ff,stroke:#007bff
  style B2 fill:#cce5ff,stroke:#007bff
  style NOTE fill:#fff3cd,stroke:#856404

假设有 $P$ 个 stage,bubble 占比约为 $\frac{P-1}{P}$——4 个 stage 意味着 75% 的时间在浪费!

GPipe:Micro-batching

GPipe 的解决方案:将一个 mini-batch 切成 $M$ 个 micro-batch,让多个 micro-batch 像流水线一样在各 stage 间流动。

flowchart TD
  subgraph GPipe["GPipe — M=4 Micro-batches, Time →"]
    direction LR
    subgraph G0["GPU 0"]
      direction LR
      G0F["F0 F1 F2 F3"] ~~~ G0GAP["· · ·"] ~~~ G0B["B3 B2 B1 B0"]
    end
    subgraph G1["GPU 1"]
      direction LR
      G1F["F0 F1 F2 F3"] ~~~ G1B["B3 B2 B1 B0"]
    end
    subgraph G2["GPU 2"]
      direction LR
      G2F["F0 F1 F2 F3"] ~~~ G2B["B3 B2 B1 B0"]
    end
    subgraph G3["GPU 3"]
      direction LR
      G3F["F0 F1 F2 F3"] --- G3B["B3 B2 B1 B0"]
    end
  end
  NOTE["Bubble: (P-1)/(M+P-1) — 当 M >> P 时, bubble → 0"]
  style G0F fill:#d4edda,stroke:#28a745
  style G1F fill:#d4edda,stroke:#28a745
  style G2F fill:#d4edda,stroke:#28a745
  style G3F fill:#d4edda,stroke:#28a745
  style G0B fill:#cce5ff,stroke:#007bff
  style G1B fill:#cce5ff,stroke:#007bff
  style G2B fill:#cce5ff,stroke:#007bff
  style G3B fill:#cce5ff,stroke:#007bff
  style G0GAP fill:#f8d7da,stroke:#dc3545
  style NOTE fill:#fff3cd,stroke:#856404

Bubble 占比从 $\frac{P-1}{P}$ 降到 $\frac{P-1}{M+P-1}$。但 GPipe 的问题是:所有 micro-batch 的前向都做完后才开始反向,需要同时保存所有 micro-batch 的激活值,显存开销巨大

1F1B Schedule

1F1B(1 Forward 1 Backward)交错执行前向和反向,每做完一个 micro-batch 的前向后就尽快做反向,从而释放激活值显存:

flowchart TD
  subgraph OneF1B["1F1B Schedule — P=4, M=8, Time →"]
    direction TB
    G0["GPU 0: F0 F1 F2 F3 | B0 F4 B1 F5 B2 F6 B3 F7 | B4 B5 B6 B7"]
    G1["GPU 1: · F0 F1 F2 | B0 F3 B1 F4 B2 F5 B3 F6 | B4 F7 B5 B6 B7"]
    G2["GPU 2: · · F0 F1 | B0 F2 B1 F3 B2 F4 B3 F5 | B4 F6 B5 F7 ..."]
    G3["GPU 3: · · · F0 | B0 F1 B1 F2 B2 F3 B3 F4 | B4 F5 B5 ..."]
  end
  NOTE["稳态阶段: 1 Forward + 1 Backward 交错\n显存占用远小于 GPipe"]
  style NOTE fill:#d4edda,stroke:#155724

1F1B 在稳态阶段(warmup 结束后),每个 GPU 同时只保留有限数量的 micro-batch 的激活值,显存占用远小于 GPipe。Bubble 比例不变(仍为 $\frac{P-1}{M+P-1}$),但显存显著改善。

Zero Bubble PP

2024 年提出的 Zero Bubble PP 进一步减少气泡。核心思想是:将反向传播分解为两部分——计算输入梯度(B)计算权重梯度(W)。B 需要传递给前一个 stage,但 W 不需要通信,可以填充到 bubble 中:

flowchart TD
  subgraph ZB["Zero Bubble PP — Time →"]
    direction TB
    G0["GPU 0: F · F · F · B · F · B · W · B · W · B · W"]
    G1["GPU 1: · F · F · B · F · B · W · B · W · B · W"]
  end
  NOTE["W = 权重梯度计算 (不需要通信)\nW 填充了原本的 bubble → 气泡率 ≈ 0"]
  style NOTE fill:#d4edda,stroke:#155724

Zero Bubble PP 在理论上可以将 bubble 率降到接近零,代价是实现复杂度增加和更精细的调度。

Pipeline Parallel 调度策略对比

PP 小结

调度策略Bubble 占比激活值显存实现复杂度
Naive$(P-1)/P$
GPipe$(P-1)/(M+P-1)$高(所有 micro-batch)
1F1B$(P-1)/(M+P-1)$中(有限 micro-batch)
Zero Bubble$\approx 0$

SP(Sequence Parallel)— LayerNorm/Dropout 序列维度切分

TP 将线性层的权重按维度切分,但 Transformer 中还有不少操作是逐元素的,不涉及权重矩阵——比如 LayerNorm、Dropout、残差连接。这些操作在 TP 下的问题是:它们需要在完整的隐藏维度上执行,因此每张 GPU 都要持有完整的激活值,造成激活值显存的冗余

Sequence Parallel(SP)解决这个问题的方式是:对于这些非 TP 操作,改为在序列维度上切分:

flowchart TD
  LN1["LayerNorm — SP: seq/N"]
  AG1["all_gather: 拼出完整序列"]
  ATT["Attention — TP: 按 head 切分"]
  RS1["reduce_scatter: 切分到序列维度"]
  DR1["Dropout + Residual — SP: seq/N"]
  LN2["LayerNorm — SP"]
  AG2["all_gather: 拼出完整序列"]
  FFN["FFN — TP: Column + Row"]
  RS2["reduce_scatter: 切回序列维度"]
  DR2["Dropout + Residual — SP"]

  LN1 --> AG1 --> ATT --> RS1 --> DR1 --> LN2 --> AG2 --> FFN --> RS2 --> DR2

  style LN1 fill:#d4edda
  style DR1 fill:#d4edda
  style LN2 fill:#d4edda
  style DR2 fill:#d4edda
  style ATT fill:#fff3cd
  style FFN fill:#fff3cd
  style AG1 fill:#cce5ff
  style AG2 fill:#cce5ff
  style RS1 fill:#cce5ff
  style RS2 fill:#cce5ff

关键洞察:TP 中的 all_reduce 被拆分为 reduce_scatter + all_gather,分别放在 TP 区域的出口和入口。这样通信量不变(all_reduce = reduce_scatter + all_gather),但 SP 区域的激活值显存降到了 $1/N$

SP 的价值在大 batch、长序列时尤为显著——此时激活值是显存的主要来源,SP 直接将这部分开销除以 TP degree。


MoE 并行

EP(Expert Parallel)

Mixture-of-Experts(MoE)在模型中引入了一组"专家"子网络,每个 token 只激活其中的 $k$ 个(通常 $k=1$ 或 $2$)。这使得 MoE 可以在显著增大参数量的同时保持计算量基本不变——但也带来了独特的并行挑战。

MoE 层的结构:

flowchart TD
  INPUT["Input tokens"]
  ROUTER["Router
决定每个 token 发给哪个 expert"]
  E0["Expert 0
FFN"] & E1["Expert 1
FFN"] & E2["Expert 2
FFN"] & E3["Expert 3
FFN ..."]
  COMBINE["Combine outputs"]

  INPUT --> ROUTER
  ROUTER --> E0 & E1 & E2 & E3
  E0 & E1 & E2 & E3 --> COMBINE

Expert Parallel(EP) 将不同的 expert 分配到不同 GPU 上。如果有 64 个 expert 和 8 张 GPU,每张 GPU 负责 8 个 expert。

核心通信操作是 all_to_all,执行两次:

  1. Dispatch(分发):Router 决定每个 token 要去哪个 expert 后,通过 all_to_all 将 token 从"按数据分片"的分布重新排列为"按 expert 分组"的分布——每张 GPU 收到所有发给它持有的 expert 的 token
  2. Combine(回收):各 expert 计算完成后,通过 all_to_all 将结果送回原来的 GPU
flowchart TD
  subgraph BEFORE["Before: 按数据分片"]
    direction LR
    G0B["GPU 0
tokens → 各 Expert"]
    G1B["GPU 1
tokens → 各 Expert"]
  end
  BEFORE -->|"all_to_all
(dispatch)"| AFTER
  subgraph AFTER["After: 按 expert 分组"]
    direction LR
    G0A["GPU 0 (E0,E1)
收到所有发往 E0,E1 的 token"]
    G1A["GPU 1 (E2,E3)
收到所有发往 E2,E3 的 token"]
  end
  AFTER -->|"expert 计算"| COMPUTE["各 GPU 运行自己的 expert"]
  COMPUTE -->|"all_to_all
(combine)"| RESULT["恢复原始数据分片"]

EP 的通信开销取决于 token 的路由分布——如果 token 均匀分散到各 expert,all_to_all 的通信量最大;如果 token 集中在少数 expert,通信量较小但会导致负载不均。


DeepSeek MoE:All-to-All Dispatch 与 Token Dropping

DeepSeek 在 MoE 架构上做了几个重要改进:

1. Fine-grained Experts(细粒度专家)

传统 MoE(如 Switch Transformer)使用少量大 expert,DeepSeek 使用大量小 expert——例如 160 个 expert 每个 token 激活 6 个,而非 16 个 expert 每个 token 激活 2 个。更多更小的 expert 提供了更灵活的组合能力:

$$\binom{160}{6} \gg \binom{16}{2}$$

token 可以组合出的 expert 组合数指数级增加,表达能力更强。

2. Shared Experts + Routed Experts

DeepSeek MoE 引入了"共享专家"——每个 token 都会经过的 expert,加上通过 router 选择的 expert:

Output = SharedExpert(x) + Σ Router_topk(RoutedExpert_i(x))
flowchart TD
  subgraph NON_EXPERT["Non-expert layers (Attention, LN, Embedding)"]
    FSDP_SHARD["FSDP across all GPUs
shard params/grads/opt_state"]
  end
  subgraph MOE_LAYERS["MoE layers"]
    direction LR
    EP_0["Expert 0-7 → GPU 0"]
    EP_1["Expert 8-15 → GPU 1"]
    EP_N["... (EP)"]
  end
  COMM["通信:
FSDP: all_gather + reduce_scatter
EP: all_to_all (dispatch/combine)"]
  NON_EXPERT --- MOE_LAYERS --- COMM

PyTorch 2.x 的 FSDP2 和 DTensor 框架为这种混合并行提供了原生支持。


前沿方案

Context Parallel:Ring Attention / Stripe Attention

随着 LLM 处理越来越长的上下文(128K、1M tokens),序列长度成为了新的显存瓶颈。Self-Attention 的显存和计算复杂度为 $O(S^2)$($S$ 为序列长度),一条 128K 的序列在 FlashAttention 下仍然需要巨量的显存来存储 KV 。

Context Parallel(CP)的解决方案:将长序列切分到多张 GPU 上,每张 GPU 只处理序列的一个片段,通过通信交换 KV 来完成完整的 attention 计算。

Ring Attention

Ring Attention 的核心思想与 Ring All-Reduce 类似:将 GPU 排成一个环,每张 GPU 持有一段序列的 Q,同时 KV block 在环上循环传递:

flowchart LR
  subgraph RING["Ring Attention — 序列长度 S, 4 GPUs"]
    G0["GPU 0
Q[0:S/4]"] -->|"KV"| G1["GPU 1
Q[S/4:S/2]"]
    G1 -->|"KV"| G2["GPU 2
Q[S/2:3S/4]"]
    G2 -->|"KV"| G3["GPU 3
Q[3S/4:S]"]
    G3 -->|"KV"| G0
  end

KV 在环上循环传递,每步 send/recv 与 attention 计算重叠执行。

关键优化:KV 的 send/recv 与 attention 计算可以重叠——当 GPU 在用当前 KV block 计算 attention 时,已经在传输下一个 KV block 了。

显存:每张 GPU 只存 $S/N$ 长度的 Q 和对应的 KV,显存从 $O(S^2)$ 降到 $O(S^2/N)$(更准确地说,FlashAttention 下从 $O(S)$ 降到 $O(S/N)$)。

Stripe Attention

Ring Attention 的一个问题是 causal mask 导致的负载不均衡:排在序列前面的 GPU 需要计算更少的 attention(因为 causal mask 屏蔽了后续位置),导致最后一个 GPU 的计算量最大。

Stripe Attention 通过交错分配序列位置来解决:

flowchart TD
  subgraph RING["Ring Attention (连续分配)"]
    direction TB
    RG0["GPU 0: tokens 0,1,2,3 → 计算量最少 (causal)"]
    RG1["GPU 1: tokens 4,5,6,7"]
    RG2["GPU 2: tokens 8,9,10,11"]
    RG3["GPU 3: tokens 12,13,14,15 → 计算量最多"]
  end
  subgraph STRIPE["Stripe Attention (交错分配)"]
    direction TB
    SG0["GPU 0: tokens 0,4,8,12 → 计算量均衡"]
    SG1["GPU 1: tokens 1,5,9,13 → 计算量均衡"]
    SG2["GPU 2: tokens 2,6,10,14 → 计算量均衡"]
    SG3["GPU 3: tokens 3,7,11,15 → 计算量均衡"]
  end
  style RG0 fill:#d4edda,stroke:#155724
  style RG3 fill:#f8d7da,stroke:#dc3545
  style SG0 fill:#d4edda,stroke:#155724
  style SG1 fill:#d4edda,stroke:#155724
  style SG2 fill:#d4edda,stroke:#155724
  style SG3 fill:#d4edda,stroke:#155724
flowchart TD
  IN["输入: 每 GPU 持有 seq/N, 完整 heads"]
  QKV["QKV 投影 (本地计算)"]
  A2A1["all_to_all
(seq/N, heads) → (seq, heads/N)"]
  ATT["Attention
每 GPU: heads/N 的完整序列"]
  A2A2["all_to_all
(seq, heads/N) → (seq/N, heads)"]
  OUT["Output 投影 (本地计算)"]

  IN --> QKV --> A2A1 --> ATT --> A2A2 --> OUT

  style A2A1 fill:#cce5ff
  style A2A2 fill:#cce5ff

Ulysses vs Ring Attention

特性Ring AttentionUlysses SP
通信模式send/recv (P2P, 多步)all_to_all (两次)
通信量$O(S \cdot d)$$O(S \cdot d)$
通信-计算重叠容易(ring 结构天然重叠)较难
对 head 数的要求head 数必须被 SP degree 整除
适用场景极长序列,head 数不够分head 数足够时更带宽高效

两种方案在不同配置下各有优劣,实际系统中有时会结合使用。


混合并行(Hybrid Parallelism)

现实中的大模型训练几乎不会只用单一并行策略——而是将多种策略组合,根据硬件拓扑分层部署。这就是所谓的 3D 并行甚至 5D 并行

经典 3D 并行

最基础的混合方案是 TP + PP + DP(或 FSDP):

flowchart TD
  subgraph N0["Node 0 (8 GPUs) — NVLink 600 GB/s"]
    direction LR
    subgraph TP0A["TP=4"]
      G0["0"] ~~~ G1["1"] ~~~ G2["2"] ~~~ G3["3"]
    end
    subgraph TP0B["TP=4"]
      G4["4"] ~~~ G5["5"] ~~~ G6["6"] ~~~ G7["7"]
    end
    TP0L["Stage 0 — DP/FSDP across"]
  end
  subgraph N1["Node 1 (8 GPUs)"]
    direction LR
    TP1L["Stage 1 — DP/FSDP across"]
  end
  N0 ==>|"PP: send/recv
(InfiniBand)"| N1

  style N0 fill:#f0f0f0
  style N1 fill:#f0f0f0

TP: 节点内 (NVLink) | PP: 跨节点组 (InfiniBand) | FSDP: 跨节点 (InfiniBand)

5D 并行

加上 SP(Sequence Parallel)和 EP(Expert Parallel),就形成了所谓的 5D 并行:

$$\text{Total GPUs} = TP \times SP \times PP \times DP \times EP$$

每种并行在不同维度上切分:

并行策略切分维度通信操作通信量适合的互联层级
TP隐藏维度all_reduce节点内 (NVLink)
SP序列维度reduce_scatter / all_gather节点内
PPsend/recv节点内或跨节点
DP/FSDP数据all_reduce / all_gather + reduce_scatter跨节点
EPExpertall_to_all取决于路由跨节点

如何选择并行策略组合?

实际选择取决于三个因素:模型规模、硬件拓扑、序列长度

以下是一个决策参考:

flowchart TD
  Q1{"模型能放进单卡?"}
  Q1 -->|"Yes"| DDP["DDP(最简单、最快)"]
  Q1 -->|"No"| Q2{"Adam 状态放不下\n但参数放得下?"}
  Q2 -->|"Yes"| FSDP_GO["FSDP (SHARD_GRAD_OP)"]
  Q2 -->|"No"| Q3{"参数都放不下?"}
  Q3 -->|"Yes"| FSDP_FS["FSDP (FULL_SHARD)"]
  Q3 -->|"MoE 模型"| MOE["EP for experts\n+ FSDP for non-expert params"]
  FSDP_FS --> Q4{"仍然 OOM?"}
  Q4 -->|"Yes"| TP["加 TP (通常 2/4/8 within node)"]
  TP --> Q5{"还是不够?"}
  Q5 -->|"Yes"| PP["加 PP (跨节点)"]
  FSDP_FS --> Q6{"序列太长导致 OOM?"}
  Q6 -->|"Yes"| CP["加 Context Parallel"]

  style DDP fill:#d4edda,stroke:#155724
  style FSDP_GO fill:#d4edda,stroke:#155724
  style FSDP_FS fill:#cce5ff,stroke:#004085
  style MOE fill:#fff3cd,stroke:#856404
  style TP fill:#cce5ff,stroke:#004085
  style PP fill:#cce5ff,stroke:#004085
  style CP fill:#cce5ff,stroke:#004085

一些经验法则:

  • 7B 模型:2-8 GPU DDP 或 FSDP 就够了
  • 13B-70B 模型:FSDP + TP(节点内 TP=2 或 4)
  • 70B+ 模型:FSDP + TP + PP,完整的 3D 并行
  • MoE 模型(如 Mixtral、DeepSeek):EP + FSDP + TP
  • 超长序列(128K+):Context Parallel + TP + FSDP

配套代码

本文配套代码位于 code/02-parallel-strategies/

  • ddp_example.py — DDP 完整训练循环,包含 DistributedSampler、bucketed gradient sync、吞吐量测量和参数一致性验证。运行方式:torchrun --nproc_per_node=2 ddp_example.py

  • fsdp_example.py — 对比 NO_SHARD(=DDP)、SHARD_GRAD_OP(=ZeRO-2)、FULL_SHARD(=ZeRO-3)三种策略的实际显存差异。运行后可以直观看到 FULL_SHARD 的显存节省效果。运行方式:torchrun --nproc_per_node=2 fsdp_example.py

  • tensor_parallel.py — 从零实现 Megatron 风格的 ColumnParallelLinearRowParallelLinear,并组合成 TensorParallelFFN。包含正确性验证和显存分析。运行方式:torchrun --nproc_per_node=2 tensor_parallel.py

所有代码支持 CPU 模式(自动使用 gloo backend),方便在没有 GPU 的环境下学习逻辑。但显存测量和性能数据仅在 CUDA 环境下有意义。


总结与下一步

本文系统梳理了大模型训练中的所有主流并行策略。让我们回顾核心要点:

经典并行策略

  • DDP:复制模型、分数据、all_reduce 梯度——最简单,但不省显存
  • FSDP/ZeRO:切分参数+梯度+优化器状态,用 all_gather/reduce_scatter 通信换显存——大模型训练的基石
  • TP:将权重矩阵按维度切开(Column+Row),每层 2 次 all_reduce——需要 NVLink 高带宽,适合节点内
  • PP:按层划分,send/recv 通信——通信量小但有 pipeline bubble,1F1B 和 Zero Bubble 在努力消除
  • SP:序列维度切分 LayerNorm/Dropout,与 TP 互补——降低激活值显存

MoE 并行

  • EP:将 expert 分布到不同 GPU,all_to_all 做 token dispatch——负载均衡是关键挑战
  • DeepSeek MoE:细粒度 expert + 共享 expert + token dropping

前沿方案

  • Context Parallel(Ring/Stripe Attention):处理超长序列
  • Ulysses SP:all_to_all 做序列-head 维度转换
  • 混合并行(3D/5D):根据硬件拓扑分层组合

核心洞察:每种并行策略本质上都是在显存通信之间做 trade-off。选择哪种组合,取决于你的模型有多大、卡间带宽有多快、序列有多长。没有银弹,只有工程上的最优权衡。

下一篇文章**《LLM 推理系统架构》**将把视角从训练转向推理——当模型训好之后,如何高效地服务请求?我们将深入 PagedAttention、RadixAttention(SGLang 的核心创新)、Continuous Batching 等推理优化技术,看看推理系统如何解决一个全新的 Memory-bound 挑战。


参考资料

  1. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models — Rajbhandari et al., 2020 — FSDP 的理论基础,定义了 Stage 1/2/3 的显存切分策略
  2. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism — Shoeybi et al., 2020 — Tensor Parallel 的 Column/Row 切分方案
  3. Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM — Narayanan et al., 2021 — 3D 并行(TP+PP+DP)的系统设计
  4. GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism — Huang et al., 2019 — Pipeline Parallel 的 micro-batching 方案
  5. Zero Bubble Pipeline Parallelism — Qi et al., 2024 — 通过分离 B 和 W 计算消除 pipeline bubble
  6. Ring Attention with Blockwise Transformers for Near-Infinite Context — Liu et al., 2023 — 长序列的 Ring Attention 方案
  7. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — DeepSeek AI, 2024 — 细粒度 MoE + 共享 expert 架构
  8. DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale — Rajbhandari et al., 2022 — EP 与 FSDP 联合训练
  9. Reducing Activation Recomputation in Large Transformer Models — Korthikanti et al., 2023 — Sequence Parallel 减少激活值显存
  10. PyTorch FSDP Documentationpytorch.org/docs/stable/fsdp — FSDP/FSDP2 官方文档
  11. DeepSpeed ZeRO Tutorialdeepspeed.ai/tutorials/zero — ZeRO Stage 1/2/3 实践指南