[{"content":"Motivation Every optimization technique in the LLM infrastructure stack ultimately comes down to two things: memory and communication. Whether you are sharding model weights across GPUs with FSDP, overlapping computation with gradient synchronization in DDP, or designing a KV cache eviction policy for an inference server, you are wrestling with the same two constraints.\nMemory determines what fits on a single device. Communication determines how fast multiple devices can cooperate. If you do not understand these two fundamentals at a hardware level, every distributed training paper and every inference system will feel like a black box.\nThis article builds that foundation. Part A dissects the GPU memory hierarchy — from registers to HBM — and explains why memory bandwidth, not compute, is usually the bottleneck for LLM workloads. Part B covers the complete set of NCCL collective communication primitives, the algorithms that implement them, and the hardware interconnects that constrain them. By the end, you will have the mental models needed to reason about every parallelism strategy and inference optimization we cover in later articles.\nPrerequisites Basic PyTorch experience (training loops, nn.Module, CUDA tensors) Familiarity with the Transformer architecture (attention, FFN, residual connections) Part A: GPU Memory Model GPU vs CPU: Different Design Philosophies Before diving into GPU memory, it helps to understand why GPUs are built the way they are.\nA modern CPU like an Intel Xeon or AMD EPYC has a small number of powerful cores (typically 32-128). Each core has deep pipelines, sophisticated branch predictors, large private caches (often 1-2 MB of L2 per core), and out-of-order execution engines. The design philosophy is latency optimization: make a single thread of execution as fast as possible.\nA GPU takes the opposite approach. An NVIDIA A100 has 108 Streaming Multiprocessors (SMs), each capable of running thousands of threads concurrently. 
Individual threads are simple — no branch prediction, no out-of-order execution, modest cache per SM. The design philosophy is throughput optimization: keep thousands of threads in flight to hide memory latency through massive parallelism.\nflowchart LR subgraph CPU[\u0026#34;CPU: Latency-Oriented\u0026#34;] direction TB C0[\u0026#34;Few powerful cores\u0026#34;] ~~~ C1[\u0026#34;Large caches\u0026#34;] C2[\u0026#34;Branch prediction\u0026#34;] ~~~ C3[\u0026#34;Out-of-order exec\u0026#34;] end subgraph GPU[\u0026#34;GPU: Throughput-Oriented\u0026#34;] direction TB G0[\u0026#34;Many simple cores\u0026#34;] ~~~ G1[\u0026#34;Small caches per SM\u0026#34;] G2[\u0026#34;No branch prediction\u0026#34;] ~~~ G3[\u0026#34;In-order execution\u0026#34;] end This distinction matters because it explains the GPU memory hierarchy. A CPU can afford to stall on a cache miss — the out-of-order engine finds other work to do. A GPU cannot afford to stall 108 SMs simultaneously; instead, it relies on having enough threads ready to execute that the hardware can switch to a different group of threads while the first group waits for data from memory.\nStreaming Multiprocessors, Warps, and the Execution Model The fundamental execution unit of an NVIDIA GPU is the Streaming Multiprocessor (SM). Understanding SM architecture is essential for reasoning about memory access patterns and occupancy.\nSM Architecture. Each SM contains:\nMultiple CUDA cores (FP32/FP64 units) — 64 FP32 cores per SM on A100 Tensor Cores for matrix multiply-accumulate operations (crucial for Transformer training) A register file (typically 256 KB per SM on A100) Shared memory / L1 cache (configurable, up to 164 KB combined on A100) Warp schedulers that manage thread execution Warps. Threads on a GPU do not execute individually. They execute in groups of 32 called warps. All 32 threads in a warp execute the same instruction at the same time (SIMT — Single Instruction, Multiple Threads). 
If threads within a warp take different branches of an if statement, both branches execute serially (called warp divergence), wasting cycles. This is why GPU code avoids divergent control flow.\nThread Hierarchy. CUDA organizes threads into a three-level hierarchy:\nGrid (the entire kernel launch) └── Block (a.k.a. thread block or CTA — assigned to one SM) └── Warp (32 threads — the actual execution unit) └── Thread (individual thread within a warp) A thread block is assigned to a single SM and cannot migrate. Multiple blocks can run on the same SM if resources (registers, shared memory) allow. The number of active warps per SM relative to the maximum is called occupancy — higher occupancy generally means better latency hiding, though the relationship is not always linear.\nWhy this matters for LLMs. Transformer operations like matrix multiplications and attention computations map well to the GPU\u0026rsquo;s SIMT model because they are highly data-parallel. However, operations like layer normalization, softmax, and token-by-token autoregressive decoding have less parallelism, and understanding the warp/SM model helps explain why these operations become bottlenecks.\nMemory Hierarchy: From Registers to HBM The GPU memory hierarchy is the single most important concept for understanding LLM performance. 
Each level trades off capacity for speed:\nflowchart TD REG[\u0026#34;\u0026lt;b\u0026gt;Registers\u0026lt;/b\u0026gt;\\n256 KB | ~19 TB/s | 0 cycles\\nScope: per-thread\u0026#34;] SMEM[\u0026#34;\u0026lt;b\u0026gt;Shared Memory / L1 Cache\u0026lt;/b\u0026gt;\\nUp to 164 KB (A100) | ~19 TB/s\\nScope: per-block\u0026#34;] L2[\u0026#34;\u0026lt;b\u0026gt;L2 Cache\u0026lt;/b\u0026gt;\\n40 MB (A100) | ~5 TB/s\\nScope: entire GPU\u0026#34;] HBM[\u0026#34;\u0026lt;b\u0026gt;HBM (Global Memory)\u0026lt;/b\u0026gt;\\n80 GB (A100) | ~2 TB/s\\nScope: entire GPU\u0026#34;] REG --\u0026gt;|\u0026#34;~4x bandwidth drop\u0026#34;| SMEM --\u0026gt;|\u0026#34;~4x\u0026#34;| L2 --\u0026gt;|\u0026#34;~2.5x\u0026#34;| HBM style REG fill:#2d6a4f,color:#fff style SMEM fill:#40916c,color:#fff style L2 fill:#74c69d,color:#000 style HBM fill:#b7e4c7,color:#000 Let us walk through each level with concrete numbers for the NVIDIA A100 (80 GB variant):\nRegisters are the fastest storage, private to each thread. The A100 has 256 KB of register file per SM (65,536 32-bit registers). Registers have effectively zero latency — they are read in the same cycle as the instruction that uses them. However, if a kernel uses too many registers per thread, fewer warps can be active on the SM, reducing occupancy. This tension is called register pressure.\nShared Memory / L1 Cache is a fast on-chip SRAM shared among all threads in a block. On the A100, each SM has up to 164 KB of combined shared memory and L1 cache, with a configurable split. Shared memory is explicitly managed by the programmer (in CUDA C) and provides ~19 TB/s bandwidth per SM. It is the key to writing fast GPU kernels — FlashAttention, for instance, keeps attention tiles in shared memory to avoid repeated round trips to HBM.\nL2 Cache is shared across all SMs. The A100 has 40 MB of L2. It automatically caches HBM accesses with ~5 TB/s bandwidth. 
While useful, 40 MB is tiny compared to model sizes (a 7B parameter model in FP16 is ~14 GB), so L2 hit rates for LLM workloads are often low.\nHBM (High Bandwidth Memory) is the main GPU memory — what people mean when they say \u0026ldquo;GPU memory\u0026rdquo; or \u0026ldquo;VRAM.\u0026rdquo; The A100 has 80 GB of HBM2e with ~2 TB/s bandwidth. This is where model parameters, activations, gradients, optimizer states, and KV caches live. Despite the name \u0026ldquo;High Bandwidth Memory,\u0026rdquo; HBM is by far the slowest level of the hierarchy, and its bandwidth is the primary bottleneck for LLM workloads.\nKey Insight. There is roughly a 10x bandwidth gap from the top of the hierarchy to the bottom: registers and shared memory provide ~19 TB/s; L2 provides ~5 TB/s; HBM provides ~2 TB/s. Effective GPU programming is fundamentally about maximizing data reuse at higher levels of the hierarchy to minimize HBM accesses. This is exactly what techniques like FlashAttention and kernel fusion accomplish.\nWhy HBM Bandwidth Is THE Bottleneck for LLMs To understand whether a workload is limited by memory or compute, we use the concept of arithmetic intensity — the ratio of compute operations to bytes transferred from memory.\nArithmetic Intensity = FLOPs / Bytes Accessed The NVIDIA A100 has:\n~312 TFLOPS of BF16 Tensor Core compute ~2 TB/s of HBM bandwidth The ratio of these two numbers gives the ridge point of the Roofline model:\nRidge Point = 312 TFLOPS / 2 TB/s = 156 FLOPs/byte If your operation performs fewer than 156 FLOPs per byte of data it reads from HBM, it is memory-bound — the compute units sit idle waiting for data. 
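The ridge-point arithmetic above can be reproduced in a few lines. A minimal sketch in plain Python — the GEMM shapes (m, k, n) are illustrative choices, not taken from any particular model, and the byte count assumes each operand is read from or written to HBM exactly once:

```python
# Roofline classification against the A100 ridge point.
PEAK_FLOPS = 312e12  # BF16 Tensor Core peak, FLOPs/s
HBM_BW = 2e12        # HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / HBM_BW  # 156 FLOPs/byte

def gemm_intensity(m, k, n, bytes_per_elem=2):
    """Arithmetic intensity of C[m,n] = A[m,k] @ B[k,n], operands in BF16."""
    flops = 2 * m * k * n  # one multiply + one add per inner-product term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved

def bound(intensity):
    return "compute-bound" if intensity > RIDGE else "memory-bound"

# Training-shaped GEMM: many tokens against one weight matrix.
train = gemm_intensity(m=8192, k=4096, n=4096)   # ~1638 FLOPs/byte
# Decode-shaped GEMV: a single token against the same weights.
decode = gemm_intensity(m=1, k=4096, n=4096)     # ~1 FLOP/byte
```

The single-token case lands near 1 FLOP/byte — two orders of magnitude below the ridge point — which is the quantitative version of the decoding bottleneck discussed in this section.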
If it performs more than 156 FLOPs per byte, it is compute-bound — the memory system can keep up.\nWhere do LLM operations fall?\nOperation Arithmetic Intensity Bound Large GEMM (training) 100-1000+ FLOPs/byte Compute-bound Small GEMM (inference, batch=1) ~1-10 FLOPs/byte Memory-bound Softmax ~5 FLOPs/byte Memory-bound Layer Normalization ~5 FLOPs/byte Memory-bound Attention (naive) ~10 FLOPs/byte Memory-bound Elementwise ops (GELU, residual add) 1 FLOPs/byte Memory-bound During training with large batch sizes, the big matrix multiplications (linear layers, attention projections) can be compute-bound because the matrices are large enough to amortize memory access costs. But many other operations (normalization, softmax, elementwise) remain memory-bound.\nDuring inference, especially autoregressive decoding with batch size 1, almost everything is memory-bound. Each decoding step generates one token, which means the weight matrices must be read from HBM for a single vector-matrix multiply — catastrophically low arithmetic intensity. This is why inference optimization focuses so heavily on:\nBatching — processing more tokens per HBM read (continuous batching, dynamic batching) Quantization — reducing bytes per parameter (INT8, INT4, FP8) KV cache management — avoiding redundant computation/memory for prefix tokens Kernel fusion — combining multiple memory-bound operations into one kernel to reduce HBM round trips Hands-On: GPU Memory Profiling Theory becomes concrete when you measure it. The companion script memory_profiling.py profiles GPU memory usage through a complete training step of a Transformer model. Here is what it reveals.\nThe Setup. A 6-layer TransformerEncoder with d_model=512, 8 attention heads, and dim_feedforward=2048. 
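The parameter count for this configuration can be derived by hand from the layer shapes. A sketch in plain Python, assuming the standard nn.TransformerEncoderLayer composition (fused QKV in-projection, attention out-projection, two FFN linears, two LayerNorms, all with biases):

```python
# Hand-count of nn.TransformerEncoderLayer parameters for the setup above.
D_MODEL, DIM_FFN, N_LAYERS = 512, 2048, 6

attn = (3 * D_MODEL * D_MODEL + 3 * D_MODEL)   # fused QKV in-projection
attn += D_MODEL * D_MODEL + D_MODEL            # attention out-projection
ffn = (D_MODEL * DIM_FFN + DIM_FFN) + (DIM_FFN * D_MODEL + D_MODEL)
norms = 2 * 2 * D_MODEL                        # two LayerNorms (weight + bias)

params = N_LAYERS * (attn + ffn + norms)  # 18,914,304 -> ~18.9M
fp32_mb = params * 4 / 1e6                # ~75.7 MB at 4 bytes per parameter
```

The count lands on ~18.9M parameters and ~76 MB in FP32, agreeing with the profiling numbers quoted next to within rounding.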
This gives ~18.9M parameters, roughly 75.6 MB at FP32.\n1 2 3 4 5 6 7 8 9 10 11 12 # code/01-gpu-memory-distributed/memory_profiling.py (excerpt) D_MODEL = 512 N_HEADS = 8 N_LAYERS = 6 DIM_FFN = 2048 encoder_layer = nn.TransformerEncoderLayer( d_model=D_MODEL, nhead=N_HEADS, dim_feedforward=DIM_FFN, batch_first=True, ) model = nn.TransformerEncoder(encoder_layer, num_layers=N_LAYERS).to(device) Memory at each stage. The script tracks torch.cuda.memory_allocated() at each step:\nStage Allocated (MB) Delta (MB) ───────────────────────────────────────────────────────────────── 1. Baseline (empty GPU) 0.00 +0.00 2. Model loaded to GPU 75.60 +75.60 ← Parameters 3. Input tensors created 83.60 +8.00 ← Input data 4. After forward pass 155.00 +71.40 ← Activations 5. After backward pass 151.20 -3.80 ← Gradients created, some activations freed 6. After optimizer.step() 302.40 +151.20 ← Adam state (2x params) 7. After zero_grad(set_to_none=True) 226.80 -75.60 ← Gradients freed (Exact numbers depend on your hardware and PyTorch version. Run the script to see your own results.)\nThe 4x Rule with Adam. The most important takeaway is the memory multiplier for training with the Adam optimizer:\nComponent Size Multiplier ─────────────────────────────────────────────────── Parameters P bytes 1x Gradients P bytes 1x Adam momentum (m) P bytes 1x Adam variance (v) P bytes 1x ─────────────────────────────────────────────────── Total 4P bytes 4x For a 7B parameter model in FP32, that is 7B * 4 bytes * 4 = 112 GB just for parameters and optimizer state — already exceeding a single A100\u0026rsquo;s 80 GB. This is the fundamental reason why distributed training (and techniques like FSDP that shard optimizer state) exists.\nBut wait — we haven\u0026rsquo;t even counted activations yet. Activation memory scales with batch_size * seq_len * hidden_dim * num_layers and can easily exceed parameter memory during training. 
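The 4x rule is easy to wire into a quick estimator. A sketch in plain Python (FP32 everywhere, counting only parameters, gradients, and the two Adam states — activations excluded):

```python
# The 4x rule: FP32 parameters + gradients + Adam momentum + Adam variance.
def adam_fp32_training_gb(n_params):
    tensors = 4      # params, grads, m, v
    bytes_each = 4   # FP32
    return n_params * tensors * bytes_each / 1e9

gb_7b = adam_fp32_training_gb(7e9)   # 112.0 GB for a 7B model
fits_on_one_a100 = gb_7b <= 80       # False -- hence distributed training
```

Even before activations, the 7B FP32 case already overflows a single 80 GB A100, which is the arithmetic behind the claim above.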
Techniques like gradient checkpointing (recomputing activations during backward instead of storing them) trade compute for memory here.\nMixed precision changes the arithmetic. With BF16 parameters and FP32 optimizer state (the standard mixed-precision recipe):\nComponent Bytes per Param ────────────────────────────────────── BF16 parameters 2 BF16 gradients 2 FP32 master weights 4 (for numerical stability) FP32 momentum 4 FP32 variance 4 ────────────────────────────────────── Total 16 bytes per parameter For a 70B model: 70B * 16 bytes = 1.12 TB. This is why large-scale training requires hundreds of GPUs.\nTry it yourself. Run python memory_profiling.py on any CUDA-capable GPU. Modify N_LAYERS, BATCH_SIZE, and SEQ_LEN to see how each factor affects memory usage. The full code is at code/01-gpu-memory-distributed/memory_profiling.py.\nPart B: Distributed Communication Why Multi-GPU? The arithmetic from Part A makes the case clearly: a 70B parameter model needs ~140 GB in FP16 just for the weights — already exceeding the 80 GB capacity of an A100. Add optimizer state, gradients, activations, and KV caches, and you need multiple GPUs even for inference, let alone training.\nBut distributing computation across GPUs introduces a new challenge: communication. Every parallelism strategy (DDP, FSDP, Tensor Parallel, Pipeline Parallel, Expert Parallel) is defined by what data it communicates, when, and how. All of these strategies are built on a small set of collective communication primitives provided by NVIDIA\u0026rsquo;s NCCL library.\nNCCL Primitives: Derived from Training Scenarios The best way to understand communication primitives is not to memorize definitions, but to start from real training scenarios and see what data movement each one naturally requires. 
NCCL (NVIDIA Collective Communications Library, pronounced \u0026ldquo;nickel\u0026rdquo;) provides 8 communication primitives, each corresponding to a concrete engineering need.\nThe companion script nccl_allreduce.py demonstrates all eight primitives with concrete before/after values. Run it with:\n1 torchrun --nproc_per_node=2 nccl_allreduce.py Below we assume 4 GPUs (Rank 0-3, $P = 4$) and data size $N$, and derive the primitives from 5 scenarios.\nScenario 1: DDP — Each GPU Has Different Data, How to Sync Gradients? DDP (DistributedDataParallel) is the simplest data-parallel strategy: every GPU holds a full copy of the model, processes a different mini-batch, and then synchronizes gradients so parameter updates stay identical.\nThe requirement is clear: each GPU has computed local gradients $g_i$, and we need every GPU to end up with $\\bar{g} = \\frac{1}{P}\\sum_i g_i$. This is exactly the definition of All-Reduce.\nAll-Reduce What it does. Reduce (typically SUM) across all GPUs, result delivered to every GPU.\nBefore: GPU 0: [a0, a1, a2, a3] GPU 1: [b0, b1, b2, b3] GPU 2: [c0, c1, c2, c3] GPU 3: [d0, d1, d2, d3] After: Every GPU: [Σ0, Σ1, Σ2, Σ3] where Σi = ai + bi + ci + di 1 2 3 # PyTorch DDP uses this under the hood: dist.all_reduce(gradient_tensor, op=dist.ReduceOp.SUM) gradient_tensor /= world_size # average Communication volume. Each GPU sends and receives ~$2N \\cdot \\frac{P-1}{P}$ bytes (Ring algorithm). For large $P$ this approaches $2N$ — virtually independent of GPU count. This is the elegance of the Ring algorithm.\nScenario 2: FSDP — Parameters Are Sharded, How to Compute Forward/Backward? DDP\u0026rsquo;s problem is that every GPU stores the full model — a 70B model simply does not fit. FSDP (Fully Sharded Data Parallel) shards parameters, gradients, and optimizer states across GPUs, each storing only $1/P$.\nBut how do you compute with sharded parameters?\nForward pass: Computing a layer requires its full parameters. 
Each GPU only has one shard, so they must temporarily reassemble all shards → All-Gather Backward pass: Each GPU computes full gradients, then needs to reduce and re-shard so each GPU only keeps its own slice → Reduce-Scatter All-Gather What it does. Each GPU contributes its shard; the concatenated full result is delivered to every GPU.\nBefore: GPU 0: [a0] GPU 1: [b1] GPU 2: [c2] GPU 3: [d3] After: Every GPU: [a0, b1, c2, d3] Communication volume. Each GPU receives $N \\cdot \\frac{P-1}{P}$ data.\nReduce-Scatter What it does. Reduce (sum) across all GPUs, then scatter different chunks to different GPUs — GPU i gets the i-th chunk.\nBefore: GPU 0: [a0, a1, a2, a3] GPU 1: [b0, b1, b2, b3] GPU 2: [c0, c1, c2, c3] GPU 3: [d0, d1, d2, d3] After: GPU 0: [Σ0] GPU 1: [Σ1] GPU 2: [Σ2] GPU 3: [Σ3] Communication volume. Each GPU sends $N \\cdot \\frac{P-1}{P}$ data.\nThe FSDP Communication Loop Forward: All-Gather → reassemble full params → compute → discard full params Backward: All-Gather → reassemble full params → compute gradients → Reduce-Scatter → keep only gradient shard Update: Each GPU updates its parameter shard with its gradient shard Key insight. All-Reduce is conceptually equivalent to Reduce-Scatter + All-Gather. DDP uses All-Reduce because every GPU needs the full gradient; FSDP uses Reduce-Scatter because each GPU only needs its own shard. NCCL internally often decomposes All-Reduce into these two steps.\nScenario 3: Pipeline Parallel — Model Split by Layers, How to Pass Activations? Pipeline Parallelism (PP) partitions the model into stages by layer, each stage on a different GPU. During forward, stage 0 must send its output activations to stage 1; during backward, stage 1 must send gradients back to stage 0.\nThis does not require collective communication — it is just direct data transfer between two GPUs → Send / Recv.\nSend / Recv (Point-to-Point) What it does. 
Direct communication between exactly two GPUs.\nForward: GPU 0 ──send(activations)──→ GPU 1 Backward: GPU 0 ←──recv(gradients)──── GPU 1 Communication volume. $O(N)$, involving only two GPUs.\nSend/Recv is the only non-collective primitive. Pipeline scheduling algorithms (1F1B, Zero Bubble, etc.) are essentially careful orchestrations of these Send/Recv operations, timing them so different stages stay busy and pipeline bubbles are minimized.\nScenario 4: MoE Expert Parallel — Tokens Routed to Different Experts, How to Move Them? In MoE (Mixture-of-Experts) models, a gating network routes each token to one or more experts. Under Expert Parallelism (EP), different experts live on different GPUs.\nThe problem: tokens on any GPU may need to go to experts on any other GPU. This is not one-to-many or many-to-one — it is every GPU sending different data to every other GPU → All-to-All.\nAll-to-All What it does. Every GPU sends a different chunk to every other GPU and receives a different chunk from each.\nBefore: GPU 0: [a→0, a→1, a→2, a→3] GPU 1: [b→0, b→1, b→2, b→3] GPU 2: [c→0, c→1, c→2, c→3] GPU 3: [d→0, d→1, d→2, d→3] After: GPU 0: [a→0, b→0, c→0, d→0] GPU 1: [a→1, b→1, c→1, d→1] GPU 2: [a→2, b→2, c→2, d→2] GPU 3: [a→3, b→3, c→3, d→3] An MoE layer\u0026rsquo;s communication pattern: All-to-All (dispatch: tokens → experts) → expert computation → All-to-All (combine: expert outputs → original GPUs) — two All-to-All operations, one before and one after.\nCommunication volume. Each GPU sends and receives $N \\cdot \\frac{P-1}{P}$ data. All-to-All is the heaviest collective because it requires full-mesh communication, making MoE training particularly sensitive to network topology and bandwidth.\nScenario 5: Initialization and Data Flow — \u0026ldquo;Glue\u0026rdquo; Primitives The four scenarios above cover core training communication needs. A few more primitives handle initialization and data management:\nBroadcast What it does. 
One GPU (the \u0026ldquo;root\u0026rdquo;) sends its data to all other GPUs.\nBefore: GPU 0: [A, A, A, A] GPU 1: [., ., ., .] GPU 2: [., ., ., .] GPU 3: [., ., ., .] After: Every GPU: [A, A, A, A] Typical use. Before training begins, Rank 0 initializes model parameters and Broadcasts them so all GPUs start identical. DDP internally calls Broadcast during setup.\nCommunication volume. Each GPU sends or receives $O(N)$ data. Global total depends on implementation: naive (root sends one by one) costs $O(N \\cdot P)$; in practice, NCCL uses a tree-based broadcast where multiple GPUs relay in parallel, completing in $\\log P$ steps for a global total of $O(N \\log P)$.\nPerspective Volume Per-GPU (send or receive) $O(N)$ Global (naive) $O(N \\cdot P)$ Global (tree broadcast) $O(N \\log P)$ Scatter / Gather Scatter distributes different data chunks from one GPU to all GPUs. Gather collects data from all GPUs onto one (the inverse of Scatter).\nScatter (from GPU 0): GPU 0: [d0, d1, d2, d3] → GPU 0: [d0] GPU 1: [d1] GPU 2: [d2] GPU 3: [d3] Gather (to GPU 0): GPU 0: [d0] GPU 1: [d1] GPU 2: [d2] GPU 3: [d3] → GPU 0: [d0, d1, d2, d3] Typical use. 
During data loading, Rank 0 reads a large batch and Scatters it across GPUs; during evaluation, Gather collects predictions onto Rank 0 for aggregation.\nThe Full Picture: From Scenario to Primitive Training Scenario Communication Need Primitive Per-GPU Volume DDP gradient sync Sum gradients, result to all GPUs all_reduce $2N \\cdot \\frac{P-1}{P}$ FSDP forward (reassemble params) Each GPU contributes shard, all get full result all_gather $N \\cdot \\frac{P-1}{P}$ FSDP backward (shard gradients) Reduce gradients, each GPU keeps its shard reduce_scatter $N \\cdot \\frac{P-1}{P}$ Pipeline Parallel Pass activations/gradients between adjacent stages send / recv $N$ Expert Parallel (MoE) Full-permutation token routing all_to_all $N \\cdot \\frac{P-1}{P}$ Parameter initialization Copy one GPU\u0026rsquo;s params to all broadcast $N$ (global $N \\log P$) Data distribution / result collection One-to-many or many-to-one scatter / gather $N \\cdot \\frac{P-1}{P}$ Collective Communication Algorithms The primitives above describe what communication needs to happen. The algorithms determine how the data physically moves through the network. The choice of algorithm affects bandwidth utilization, latency, and scalability.\nRing All-Reduce The Ring algorithm is the most widely used algorithm for All-Reduce and is the default in NCCL for large messages.\nSetup. Arrange N GPUs in a logical ring: GPU 0 → GPU 1 → \u0026hellip; → GPU N-1 → GPU 0.\nPhase 1: Reduce-Scatter. Each GPU splits its data into N chunks. Over N-1 steps, chunks are passed around the ring and reduced (summed) along the way. After this phase, each GPU holds the fully reduced version of one chunk.\nPhase 2: All-Gather. The reduced chunks are passed around the ring again for N-1 steps. 
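The two phases can be checked mechanically with a plain-Python simulation — a sketch in which each "chunk" is a single number rather than a tensor slice, and a send is modeled as a list update:

```python
# Plain-Python simulation of Ring All-Reduce with P ranks.
# data[rank][chunk] holds one number per chunk; real NCCL moves tensor slices.
def ring_all_reduce(data):
    p = len(data)
    buf = [list(row) for row in data]  # working copy, one row per rank

    # Phase 1: Reduce-Scatter. In step s, rank r sends chunk (r - s) mod p
    # to its ring neighbor (r + 1) mod p, which accumulates it.
    for s in range(p - 1):
        outgoing = [buf[r][(r - s) % p] for r in range(p)]  # snapshot sends
        for r in range(p):
            buf[(r + 1) % p][(r - s) % p] += outgoing[r]
    # Now rank r holds the fully reduced chunk (r + 1) mod p.

    # Phase 2: All-Gather. Rotate the reduced chunks around the same ring.
    for s in range(p - 1):
        outgoing = [buf[r][(r + 1 - s) % p] for r in range(p)]
        for r in range(p):
            buf[(r + 1) % p][(r + 1 - s) % p] = outgoing[r]
    return buf  # every rank ends with the per-chunk sums

ranks = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
reduced = ring_all_reduce(ranks)  # every row -> [28, 32, 36, 40]
```

Counting the loop iterations gives 2 * (p - 1) = 6 communication steps for 4 ranks, matching the 2 * (N-1) step count in the latency analysis.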
After this phase, every GPU has the complete reduced result.\nRing All-Reduce with 4 GPUs (simplified): Step 0: GPU0:[a0,a1,a2,a3] GPU1:[b0,b1,b2,b3] GPU2:[c0,c1,c2,c3] GPU3:[d0,d1,d2,d3] Each GPU splits data into 4 chunks Phase 1 (Reduce-Scatter) — 3 steps: Step 1: GPU0 sends a0→GPU1, GPU1 sends b1→GPU2, GPU2 sends c2→GPU3, GPU3 sends d3→GPU0 Recipients sum the received chunk with their own Step 2: Continue rotating and summing... Step 3: Continue rotating and summing... Result: GPU0 has sum[*1], GPU1 has sum[*2], GPU2 has sum[*3], GPU3 has sum[*0] Phase 2 (All-Gather) — 3 steps: Rotate the reduced chunks around the ring. Result: Every GPU has [sum0, sum1, sum2, sum3] Bandwidth analysis. Each GPU sends and receives 2 * (N-1)/N * data_size bytes total across both phases. As N grows, this approaches 2 * data_size. The critical insight is that the total per-GPU communication volume is essentially independent of the number of GPUs — it scales with data size, not GPU count. This makes Ring All-Reduce highly scalable for large messages.\nLatency. The ring requires 2 * (N-1) sequential steps, so latency grows linearly with N. For very large GPU counts, this becomes a problem.\nTree All-Reduce Tree All-Reduce organizes GPUs in a binary tree. Data is reduced up the tree (children → parent) and then broadcast down (parent → children).\ngraph TD R0[\u0026#34;GPU 0 (root)\u0026#34;] R1[\u0026#34;GPU 1\u0026#34;] R2[\u0026#34;GPU 2\u0026#34;] R3[\u0026#34;GPU 3\u0026#34;] R4[\u0026#34;GPU 4\u0026#34;] R5[\u0026#34;GPU 5\u0026#34;] R6[\u0026#34;GPU 6\u0026#34;] R0 --- R1 \u0026amp; R2 R1 --- R3 \u0026amp; R4 R2 --- R5 \u0026amp; R6 style R0 fill:#d4a574,color:#000 Latency. Tree All-Reduce completes in O(log N) steps — much better than Ring\u0026rsquo;s O(N) for large GPU counts.\nBandwidth. However, bandwidth utilization is worse: the root node becomes a bottleneck because it must receive data from all children and send the result back down. 
Only the leaves can fully utilize their bandwidth.\nWhen to use it. Tree All-Reduce is preferred for small messages where latency dominates, or for very large GPU counts. NCCL automatically chooses between Ring and Tree based on message size and topology.\nRecursive Halving-Doubling This algorithm combines the best of Ring (bandwidth) and Tree (latency). It works in two phases:\nRecursive Halving (Reduce-Scatter). Pairs of GPUs exchange half their data and reduce. In each step, both the partner distance and the exchanged chunk are halved, achieving both O(log N) latency and near-optimal bandwidth.\nRecursive Doubling (All-Gather). The reverse process: partner distances and exchanged chunks double each step until every GPU has the full result.\nComplexity. Both phases take log N steps. Total communication volume per GPU is 2 * (N-1)/N * data_size — the same as Ring. But it achieves this in O(log N) latency instead of O(N).\nTrade-off. Recursive Halving-Doubling requires more complex routing and works best when N is a power of 2. NCCL uses variants of this algorithm internally.\nAlgorithm Selection in Practice NCCL automatically selects the best algorithm based on:\nMessage size: Small messages → Tree (latency-sensitive); Large messages → Ring (bandwidth-sensitive) GPU count: More GPUs favor lower-latency algorithms Topology: NVLink vs PCIe vs cross-node affects optimal chunk sizes You rarely need to choose algorithms manually, but understanding them helps you reason about performance. When profiling distributed training, the key question is: \u0026ldquo;Is my communication bandwidth-limited or latency-limited?\u0026rdquo; The answer determines which algorithm (and which interconnect) matters.\nCommunication Topology: The Hardware Layer The choice of communication algorithm is constrained by the physical interconnects between GPUs. 
The interconnect determines the raw bandwidth available for each link, and the topology determines how many hops a message must traverse.\nNVLink NVLink is NVIDIA\u0026rsquo;s high-speed GPU-to-GPU interconnect. On the A100, each NVLink connection provides 25 GB/s per direction (50 GB/s bidirectional), and each GPU has 12 NVLink connections, totaling 600 GB/s bidirectional bandwidth.\nNVLink is a direct point-to-point connection between two GPUs. Within a single server (e.g., a DGX A100 with 8 GPUs), all GPUs are interconnected via NVLink, enabling All-Reduce at nearly the full 600 GB/s aggregate bandwidth.\nNVSwitch In a DGX A100, the 8 GPUs are connected through NVSwitch — a switch fabric that provides full bisection bandwidth. Any GPU can communicate with any other GPU at the full NVLink rate without contention. This means that collective operations within a single node are extremely fast.\nflowchart LR subgraph DGX[\u0026#34;DGX A100 — Any GPU pair: 600 GB/s bidirectional\u0026#34;] direction LR subgraph LEFT[\u0026#34; \u0026#34;] direction TB G0[\u0026#34;GPU 0\u0026#34;] ~~~ G1[\u0026#34;GPU 1\u0026#34;] ~~~ G2[\u0026#34;GPU 2\u0026#34;] ~~~ G3[\u0026#34;GPU 3\u0026#34;] end NVS[\u0026#34;NVSwitch\\n× 6\u0026#34;] subgraph RIGHT[\u0026#34; \u0026#34;] direction TB G4[\u0026#34;GPU 4\u0026#34;] ~~~ G5[\u0026#34;GPU 5\u0026#34;] ~~~ G6[\u0026#34;GPU 6\u0026#34;] ~~~ G7[\u0026#34;GPU 7\u0026#34;] end LEFT \u0026lt;--\u0026gt;|\u0026#34;NVLink\u0026#34;| NVS \u0026lt;--\u0026gt;|\u0026#34;NVLink\u0026#34;| RIGHT end PCIe PCIe (Peripheral Component Interconnect Express) is the standard bus connecting GPUs to the CPU and to each other in consumer and some server setups. PCIe 4.0 x16 provides ~32 GB/s bidirectional — roughly 20x slower than NVLink. 
PCIe 5.0 doubles this to ~64 GB/s, but still far below NVLink speeds.\nIn systems without NVLink (e.g., consumer GPUs, some cloud instances), GPU-to-GPU communication must go through PCIe, often via the CPU (PCIe → CPU → PCIe), further increasing latency. This is why high-end training clusters always use NVLink.\nRDMA and InfiniBand (Cross-Node Communication) Within a single server, NVLink handles GPU communication. But large training runs span hundreds or thousands of GPUs across many servers. Cross-node communication uses the network:\nInfiniBand (IB) is the dominant high-performance network for GPU clusters. A single HDR InfiniBand link provides 200 Gbps (~25 GB/s), and servers typically have 4-8 IB links, giving 100-200 GB/s per node. RDMA (Remote Direct Memory Access) allows GPUs to read/write memory on remote GPUs without involving the CPU. NVIDIA\u0026rsquo;s GPUDirect RDMA enables direct NIC-to-GPU transfers, bypassing CPU memory entirely. This minimizes latency and maximizes bandwidth for cross-node communication. The bandwidth hierarchy:\nNVLink (intra-node): 600 GB/s bidirectional (A100) InfiniBand (inter-node): 100-200 GB/s per node PCIe (fallback): 32-64 GB/s NVLink is 3-6x faster than InfiniBand InfiniBand is 2-5x faster than PCIe This hierarchy is why distributed training systems are designed with topology awareness:\nOperations within a node (e.g., Tensor Parallel) use NVLink Operations across nodes (e.g., Data Parallel, Pipeline Parallel) use InfiniBand The most communication-intensive strategies (TP) are always placed intra-node The H100 and B200 generations further increase NVLink bandwidth (900 GB/s and 1.8 TB/s respectively) and introduce NVLink Network for multi-node NVLink, blurring the intra/inter-node boundary.\nSummary and What\u0026rsquo;s Next This article established the two foundational pillars of LLM infrastructure:\nMemory. 
The GPU memory hierarchy (Registers → Shared Memory → L2 → HBM) creates a bandwidth pyramid where HBM, despite its name, is the bottleneck at ~2 TB/s. LLM workloads, especially inference, are overwhelmingly memory-bound. The 4x memory multiplier with Adam (parameters + gradients + 2 optimizer states) means a 70B model needs over 1 TB for training — far exceeding any single GPU.\nCommunication. Eight NCCL primitives (All-Reduce, All-Gather, Reduce-Scatter, Broadcast, Scatter, Gather, Send/Recv, All-to-All) form the vocabulary of distributed training. Each primitive maps to specific training strategies: All-Reduce for DDP, All-Gather and Reduce-Scatter for FSDP, Send/Recv for Pipeline Parallel, All-to-All for Expert Parallel. Ring, Tree, and Recursive Halving-Doubling algorithms implement these primitives, with NCCL selecting the best algorithm automatically based on message size and topology.\nIn the next article, we will build on this foundation to explore the full landscape of distributed parallelism strategies: DDP, FSDP/FSDP2, Tensor Parallelism, Pipeline Parallelism, Sequence Parallelism, Expert Parallelism, and Context Parallelism. Every strategy is a specific answer to the question: \u0026ldquo;How do we split memory and coordinate communication across GPUs?\u0026rdquo; With the primitives from this article, you will be able to understand each strategy at the protocol level, not just the high-level concept.\nReferences NVIDIA A100 Tensor Core GPU Architecture Whitepaper. NVIDIA, 2020. Jia, Z., Maggioni, M., Staiger, B., \u0026amp; Scarpazza, D. P. \u0026ldquo;Dissecting the NVIDIA Volta GPU Architecture via Microbenchmarking.\u0026rdquo; arXiv:1804.06826, 2018. Williams, S., Waterman, A., \u0026amp; Patterson, D. \u0026ldquo;Roofline: An Insightful Visual Performance Model for Multicore Architectures.\u0026rdquo; Communications of the ACM, 52(4):65-76, 2009. NCCL Documentation. NVIDIA. 
https://docs.nvidia.com/deeplearning/nccl/ Thakur, R., Rabenseifner, R., \u0026amp; Gropp, W. \u0026ldquo;Optimization of Collective Communication Operations in MPICH.\u0026rdquo; International Journal of High Performance Computing Applications, 19(1):49-66, 2005. Patarasuk, P. \u0026amp; Yuan, X. \u0026ldquo;Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations.\u0026rdquo; Journal of Parallel and Distributed Computing, 69(2):117-124, 2009. Rajbhandari, S., Rasley, J., Ruwase, O., \u0026amp; He, Y. \u0026ldquo;ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.\u0026rdquo; SC20, 2020. Dao, T., Fu, D. Y., Ermon, S., Rudra, A., \u0026amp; Ré, C. \u0026ldquo;FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.\u0026rdquo; NeurIPS, 2022. NVIDIA DGX A100 System Architecture Whitepaper. NVIDIA, 2020. Li, S., Zhao, Y., Varma, R., et al. \u0026ldquo;PyTorch Distributed: Experiences on Accelerating Data Parallel Training.\u0026rdquo; VLDB, 2020. ","permalink":"https://mzf666.github.io/llm-infra/en/posts/01-gpu-memory-distributed/","summary":"From GPU memory hierarchy to NCCL communication primitives — the two pillars of LLM Infra optimization.","title":"GPU Memory Model and Distributed Communication Fundamentals"},{"content":"Motivation In the previous article, we did some back-of-the-envelope arithmetic: a 70B parameter model requires $16 \\times 70 \\times 10^9 \\approx 1120$ GB just for Adam optimizer-related memory — at least 14 A100 80GB GPUs, and that is before we even count activations. Training a large model on a single GPU is physically impossible. We must distribute computation across multiple devices.\nBut \u0026ldquo;distribute across GPUs\u0026rdquo; is far easier said than done. 
Different parallelism strategies make fundamentally different choices about what to split (parameters, gradients, activations, sequences), how to split it (by layer, by dimension, by data), and what communication cost to pay (all_reduce, all_gather, send/recv, all_to_all). Choose the wrong strategy and you waste compute at best, or fail to run at all at worst.\nThis article starts from the simplest strategy — DDP — and progressively covers FSDP, Tensor Parallelism, Pipeline Parallelism, Sequence Parallelism, Expert Parallelism, and the cutting-edge approaches of Context Parallelism and hybrid parallelism. By the end, you should be able to look at any model size, hardware topology, and training requirement, and assemble a sensible combination of parallelism strategies.\nPrerequisites GPU Memory Model and Distributed Communication Fundamentals (Article 1) — especially NCCL primitives and hardware topology Basic PyTorch DDP experience (you have written a training script launched with torchrun) Familiarity with the Transformer architecture: Self-Attention, FFN, LayerNorm Here is the big picture. We will unpack each piece in the sections that follow:\nflowchart TD subgraph DP[\u0026#34;Data Parallel\u0026#34;] DDP[\u0026#34;DDP\u0026#34;] FSDP[\u0026#34;FSDP / ZeRO\u0026#34;] end subgraph MP[\u0026#34;Model Parallel\u0026#34;] TP[\u0026#34;TP (Tensor)\u0026#34;] SP[\u0026#34;SP (Sequence)\u0026#34;] end subgraph PP[\u0026#34;Pipeline Parallel\u0026#34;] PP1[\u0026#34;1F1B\u0026#34;] PP2[\u0026#34;Zero Bubble\u0026#34;] end subgraph EP[\u0026#34;Expert Parallel\u0026#34;] EP1[\u0026#34;EP (MoE)\u0026#34;] EP2[\u0026#34;DeepSeek\u0026#34;] end DP \u0026amp; MP \u0026amp; PP \u0026amp; EP --\u0026gt; HYBRID[\u0026#34;Hybrid Parallelism (3D/5D) TP intra-node + FSDP cross-node + PP cross-node-group\u0026#34;] Classic Parallelism Strategies DDP (Distributed Data Parallel) DDP is the simplest and most widely used parallelism strategy. 
The core idea fits in four words: replicate model, split data.\nHow it works:\n1. Every GPU holds a complete copy of the model — parameters, gradients, and optimizer state all fully replicated.\n2. Training data is partitioned across GPUs via DistributedSampler — each GPU sees only its own mini-batch.\n3. Each GPU independently performs forward and backward passes, computing its own local gradients.\n4. An all_reduce aggregates (sums or averages) the gradients so that every GPU ends up with identical aggregated gradients.\n5. Each GPU executes the same optimizer update, keeping parameters in sync.\nflowchart TD DATA[\u0026#34;Data Batch\u0026#34;] --\u0026gt;|DistributedSampler| B0[\u0026#34;B₀\u0026#34;] \u0026amp; B1[\u0026#34;B₁\u0026#34;] \u0026amp; B2[\u0026#34;B₂\u0026#34;] \u0026amp; B3[\u0026#34;B₃\u0026#34;] B0 --\u0026gt; G0[\u0026#34;GPU 0 Full model → grad₀\u0026#34;] B1 --\u0026gt; G1[\u0026#34;GPU 1 Full model → grad₁\u0026#34;] B2 --\u0026gt; G2[\u0026#34;GPU 2 Full model → grad₂\u0026#34;] B3 --\u0026gt; G3[\u0026#34;GPU 3 Full model → grad₃\u0026#34;] G0 \u0026amp; G1 \u0026amp; G2 \u0026amp; G3 --\u0026gt; AR[\u0026#34;all_reduce(gradients) Only communication op\u0026#34;] AR --\u0026gt; OPT[\u0026#34;optimizer.step() Identical update on all GPUs\u0026#34;] Memory. DDP does not save any memory. Each GPU independently stores the full parameters ($2\\Phi$ bytes in FP16), gradients ($2\\Phi$ bytes), and optimizer state ($12\\Phi$ bytes for Adam), totaling $16\\Phi$ bytes per GPU — exactly the same as single-GPU training.\nCommunication. The only communication is an all_reduce on gradients. PyTorch DDP uses an important optimization called Bucketed All-Reduce: instead of waiting for all gradients to be computed, it groups gradients into buckets (default 25 MB). As soon as all gradients in a bucket are ready, their all_reduce begins — overlapping with the backward pass of earlier layers. This significantly hides communication latency.\nWhen to use DDP.
Whenever the model fits on a single GPU. DDP is simple, efficient, and scales near-linearly. Its limitation is equally clear: the model must fit entirely in a single GPU\u0026rsquo;s memory.\nHere is the core code:\n# DDP wrapping — registers gradient synchronization hooks\nmodel = DDP(model, device_ids=[local_rank])\n\n# The training loop is identical to single-GPU training\nfor input_ids, target_ids in dataloader:\n    logits = model(input_ids)\n    loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))\n    optimizer.zero_grad()\n    loss.backward()  # ← DDP hook automatically triggers all_reduce here\n    optimizer.step()  # ← Gradients are identical → parameters stay in sync\nFull code: code/02-parallel-strategies/ddp_example.py\nFSDP / FSDP2 (Fully Sharded Data Parallel) The fatal limitation of DDP is that every GPU stores a complete copy of parameters, gradients, and optimizer state. When the model is too large for a single GPU, DDP is out of options.\nFSDP (Fully Sharded Data Parallel) draws its core idea from the DeepSpeed ZeRO (Zero Redundancy Optimizer) series of papers: since every GPU holds redundant copies of parameters and optimizer state, why not shard them across GPUs and reconstruct them temporarily when needed?\nZeRO Stages 1, 2, and 3 ZeRO organizes memory optimization into three stages, each sharding progressively more:\n| ZeRO Stage | What is Sharded | Per-GPU Memory | FSDP Equivalent |\n|---|---|---|---|\n| Stage 1 | Optimizer state | $4\\Phi + \\frac{12\\Phi}{N}$ | — |\n| Stage 2 | Optimizer state + gradients | $2\\Phi + \\frac{14\\Phi}{N}$ | SHARD_GRAD_OP |\n| Stage 3 | Optimizer state + gradients + parameters | $\\frac{16\\Phi}{N}$ | FULL_SHARD |\nHere $\\Phi$ is the parameter count and $N$ is the number of GPUs.
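The per-GPU formulas in the table above are easy to sanity-check numerically. A quick sketch (the 70B / 8-GPU numbers are illustrative; activations and memory fragmentation are ignored):

```python
def zero_memory_gb(params_billion: float, n_gpus: int, stage: int) -> float:
    """Per-GPU training-state memory (GB) for the three ZeRO stages.

    Assumes FP16 parameters and gradients (2 bytes each) plus FP32 Adam
    state (12 bytes per parameter); activations are not counted.
    """
    phi = params_billion  # 1B params × 1 byte ≈ 1 GB, so GB scales with phi
    if stage == 1:  # shard only the optimizer state
        return 4 * phi + 12 * phi / n_gpus
    if stage == 2:  # also shard gradients (SHARD_GRAD_OP)
        return 2 * phi + 14 * phi / n_gpus
    if stage == 3:  # also shard parameters (FULL_SHARD)
        return 16 * phi / n_gpus
    raise ValueError(f"unknown ZeRO stage: {stage}")

# A 70B model on 8 GPUs: 385.0, 262.5, and 140.0 GB per GPU
for stage in (1, 2, 3):
    print(stage, zero_memory_gb(70, 8, stage))
```

At Stage 3 the per-GPU state shrinks linearly with the GPU count, which is why FULL_SHARD is the workhorse for the largest models.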
The key observation is that Stage 3 achieves near-linear memory reduction — 8 GPUs cut per-GPU memory to roughly 1/8th (plus activation overhead).\nFSDP Forward and Backward Passes Here is how FSDP (FULL_SHARD) works in practice:\nflowchart LR subgraph FWD[\u0026#34;FSDP Forward Pass — Layer i\u0026#34;] direction LR FS1[\u0026#34;Shard 1/N\u0026#34;] --\u0026gt;|\u0026#34;all_gather\u0026#34;| FP1[\u0026#34;Full Params (temporary)\u0026#34;] FP1 --\u0026gt;|\u0026#34;compute\u0026#34;| FO1[\u0026#34;Output\u0026#34;] FP1 -.-\u0026gt;|\u0026#34;discard (N-1)/N\u0026#34;| X1[\u0026#34; \u0026#34;] end subgraph BWD[\u0026#34;FSDP Backward Pass — Layer i\u0026#34;] direction LR BS1[\u0026#34;Shard 1/N\u0026#34;] --\u0026gt;|\u0026#34;all_gather\u0026#34;| BP1[\u0026#34;Full W (temporary)\u0026#34;] BP1 --\u0026gt;|\u0026#34;backward\u0026#34;| BG1[\u0026#34;Grad (full)\u0026#34;] BG1 --\u0026gt;|\u0026#34;reduce_scatter\u0026#34;| BGS[\u0026#34;Grad Shard 1/N\u0026#34;] end FWD ~~~ BWD style X1 fill:none,stroke:none Communication cost comparison:\nStrategy Forward Communication Backward Communication Total Volume DDP None 1x all_reduce (gradients) $2\\Phi$ FSDP (FULL_SHARD) all_gather (parameters) all_gather (parameters) + reduce_scatter (gradients) $3\\Phi$ FSDP communicates roughly 1.5x as much as DDP — that is the cost of trading communication for memory. But in the large-model regime, this trade-off is absolutely worth it: without FSDP, you simply cannot run the model at all.\nFSDP2: The Composable API in PyTorch 2.2+ PyTorch 2.2 and later introduced FSDP2 (torch.distributed._composable.fsdp), which brings several improvements over the original FSDP:\nPer-parameter sharding: instead of requiring entire nn.Module subtrees as FSDP units, FSDP2 can shard individual parameters. Composability: FSDP2 composes freely with Tensor Parallel, Pipeline Parallel, and other strategies without the awkward nested wrapping that FSDP1 required. 
Flexible sharding granularity: different layers can use different sharding strategies. The core concepts are identical to FSDP1 — the all_gather/reduce_scatter communication pattern is unchanged. What is new is a more modern API that integrates better with the PyTorch 2.x compiler stack.\nHere is a code snippet showing how to compare FSDP sharding strategies:\nfrom torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy\n\n# Compare three strategies\nstrategies = [\n    (ShardingStrategy.NO_SHARD, \u0026#34;NO_SHARD (= DDP)\u0026#34;),\n    (ShardingStrategy.SHARD_GRAD_OP, \u0026#34;SHARD_GRAD_OP (= ZeRO-2)\u0026#34;),\n    (ShardingStrategy.FULL_SHARD, \u0026#34;FULL_SHARD (= ZeRO-3)\u0026#34;),\n]\nfor strategy, name in strategies:\n    model = FSDP(\n        TransformerLM(),\n        sharding_strategy=strategy,\n        device_id=local_rank,\n    )\n    # ... train and measure memory\nFull code: code/02-parallel-strategies/fsdp_example.py\nTP (Tensor Parallel) — Megatron Column/Row Partitioning FSDP shards parameters, then reassembles them — communication and computation are sequential. Tensor Parallel takes a more radical approach: it slices each layer\u0026rsquo;s weight matrix along a dimension and distributes the pieces to different GPUs, each of which computes part of the result. This directly reduces per-GPU compute and memory.\nThe core ideas come from Megatron-LM, which defines two fundamental parallel linear layer types.\nColumnParallelLinear: Splitting Along the Output Dimension For a linear layer $Y = XW + b$, we partition the weight $W \\in \\mathbb{R}^{d \\times h}$ by columns:\nFull weight W: (d_model × dim_ffn)\nSplit by columns across 2 GPUs:\n  GPU 0: W_0 = W[:, :dim_ffn//2] → Y_0 = X @ W_0\n  GPU 1: W_1 = W[:, dim_ffn//2:] → Y_1 = X @ W_1\nEach GPU computes independently — no communication needed! (because X is identical on all GPUs)\nKey advantage: the forward pass requires zero communication.
Each GPU produces a chunk of the output, which can be fed directly into elementwise operations (such as GeLU) without any synchronization.\nRowParallelLinear: Splitting Along the Input Dimension We partition the weight $W \\in \\mathbb{R}^{h \\times d}$ by rows:\nSplit by rows across 2 GPUs:\n  GPU 0: W_0 = W[:dim_ffn//2, :] → Y_0 = X_0 @ W_0 (partial sum)\n  GPU 1: W_1 = W[dim_ffn//2:, :] → Y_1 = X_1 @ W_1 (partial sum)\nFinal result: Y = Y_0 + Y_1 ← requires all_reduce!\nKey constraint: each GPU only computes a partial sum of the output. An all_reduce is required to produce the complete result.\nMegatron FFN: Column + Row = Only 1 All-Reduce The elegance of Megatron-LM lies in pairing Column and Row parallel layers:\nflowchart TD X[\u0026#34;Input X (identical on each GPU)\u0026#34;] X --\u0026gt; COL[\u0026#34;ColumnParallelLinear (W1 column-split, no comm)\u0026#34;] COL --\u0026gt; GELU[\u0026#34;GeLU (elementwise, no comm)\u0026#34;] GELU --\u0026gt; ROW[\u0026#34;RowParallelLinear (W2 row-split)\u0026#34;] ROW --\u0026gt;|\u0026#34;all_reduce merge partial sums\u0026#34;| Y[\u0026#34;Output Y (identical on each GPU)\u0026#34;] style COL fill:#d4edda style GELU fill:#d4edda style ROW fill:#fff3cd Only 1 all_reduce for the entire FFN block.\nTensor Parallel for Attention Multi-Head Attention maps to TP equally elegantly:\nQ, K, V projections: use ColumnParallel, distributing attention heads across GPUs — each GPU handles $\\frac{n_{\\text{heads}}}{TP}$ heads. Output projection: use RowParallel, merging head outputs across GPUs. This means the entire Attention block also needs only 1 all_reduce. A full Transformer layer therefore requires 2 all_reduces total (1 for FFN + 1 for Attention).\nWhen TP Works Well (and When It Does Not) TP communicates frequently — 2 all_reduces per layer — so it demands very high bandwidth. In practice:\nTP degree is typically 2, 4, or 8, deployed within a single node (NVLink at 600 GB/s).
Running TP across nodes is almost never viable (InfiniBand at 25-50 GB/s is too slow for per-layer all_reduce). The hidden dimension and head count must be divisible by the TP degree. Here is the core implementation of ColumnParallelLinear:\nclass ColumnParallelLinear(nn.Module):\n    def __init__(self, in_features, out_features, world_size, rank):\n        super().__init__()\n        self.out_features_per_rank = out_features // world_size\n        # Each GPU stores only 1/world_size of the weight\n        self.weight = nn.Parameter(\n            torch.empty(self.out_features_per_rank, in_features)\n        )\n        self.bias = nn.Parameter(torch.empty(self.out_features_per_rank))\n\n    def forward(self, x):\n        # No communication! Each GPU computes independently\n        return F.linear(x, self.weight, self.bias)\nFull code: code/02-parallel-strategies/tensor_parallel.py\nPP (Pipeline Parallel) — 1F1B and Zero Bubble Where TP splits individual layers by dimension, Pipeline Parallel takes the opposite approach: it splits the model by layers, assigning different groups of layers to different GPUs (called stages).\nflowchart LR S0[\u0026#34;Stage 0 (GPU 0)\\nLayer 0-10\u0026#34;] --\u0026gt;|\u0026#34;send/recv\\nactivations\u0026#34;| S1[\u0026#34;Stage 1 (GPU 1)\\nLayer 11-21\u0026#34;] S1 --\u0026gt;|\u0026#34;send/recv\\nactivations\u0026#34;| S2[\u0026#34;Stage 2 (GPU 2)\\nLayer 22-31\u0026#34;] Communication: send/recv (point-to-point, adjacent stages only).\nCommunication. PP only needs send/recv between adjacent stages — transferring intermediate activations (forward) and gradients (backward). The communication volume is far smaller than TP\u0026rsquo;s all_reduce, making PP well-suited for cross-node deployment.\nBut PP has a critical problem: pipeline bubbles.\nNaive PP: Massive Bubbles In the simplest scheme, the entire batch passes through stages sequentially.
While one stage computes, all others sit idle:\nTime → GPU 0: [=====Forward=====] [=====Backward=====] GPU 1: [=====Forward=====] [=====Backward=====] GPU 2: [=====Forward=====] [=====Backward=====] ↑ Massive idle time (bubble) With $P$ stages, the bubble fraction is approximately $\\frac{P-1}{P}$ — 4 stages means 75% of the time is wasted.\nGPipe: Micro-Batching GPipe\u0026rsquo;s solution is to split each mini-batch into $M$ micro-batches, letting them flow through stages like an assembly line:\nTime → (M = 4 micro-batches) GPU 0: [F0][F1][F2][F3] [B3][B2][B1][B0] GPU 1: [F0][F1][F2][F3] [B3][B2][B1][B0] GPU 2: [F0][F1][F2][F3] [B3][B2][B1][B0] GPU 3: [F0][F1][F2][F3][B3][B2][B1][B0] Bubble fraction: (P-1) / (M + P - 1) When M \u0026gt;\u0026gt; P, bubble → 0 The bubble fraction drops from $\\frac{P-1}{P}$ to $\\frac{P-1}{M+P-1}$. However, GPipe completes all forward micro-batches before starting any backward passes, which means it must store activations for all $M$ micro-batches simultaneously — a significant memory cost.\n1F1B Schedule 1F1B (One Forward, One Backward) interleaves forward and backward passes. After completing a micro-batch\u0026rsquo;s forward pass, it starts the backward pass as soon as possible, allowing activation memory to be freed earlier:\nTime → (P=4, M=8) GPU 0: [F0][F1][F2][F3][B0][F4][B1][F5][B2][F6][B3][F7][B4][B5][B6][B7] GPU 1: [F0][F1][F2][B0][F3][B1][F4][B2][F5][B3][F6][B4][F7][B5][B6][B7] GPU 2: [F0][F1][B0][F2][B1][F3][B2][F4][B3][F5][B4][F6][B5][F7]... GPU 3: [F0][B0][F1][B1][F2][B2][F3][B3][F4][B4][F5][B5]... During the steady state (after the warmup phase), each GPU simultaneously holds activations for only a bounded number of micro-batches, dramatically reducing memory compared to GPipe. The bubble ratio remains $\\frac{P-1}{M+P-1}$, but memory usage improves substantially.\nZero Bubble PP Proposed in 2024, Zero Bubble PP further reduces pipeline bubbles. 
The key insight is to decompose the backward pass into two parts: computing input gradients (B) and computing weight gradients (W). B must be communicated to the previous stage, but W does not require communication and can be scheduled to fill the bubble:\nTime → GPU 0: [F][F][F][B][F][B][W][B][W][B][W] ← W fills what used to be bubble GPU 1: [F][F][B][F][B][W][B][W][B][W] Zero Bubble PP can theoretically reduce the bubble fraction to near zero, at the cost of increased scheduling complexity.\nPP summary:\nSchedule Bubble Fraction Activation Memory Implementation Complexity Naive $(P-1)/P$ Low Low GPipe $(P-1)/(M+P-1)$ High (all micro-batches) Medium 1F1B $(P-1)/(M+P-1)$ Medium (bounded micro-batches) Medium Zero Bubble $\\approx 0$ Medium High SP (Sequence Parallel) — LayerNorm/Dropout on the Sequence Dimension Tensor Parallel splits the weight matrices of linear layers by dimension, but Transformers also contain many elementwise operations that have no weight matrices: LayerNorm, Dropout, and residual connections. 
Under TP, these operations need to operate on the full hidden dimension, which means every GPU must hold the complete activation tensor — creating redundant activation memory.\nSequence Parallel (SP) addresses this by splitting the activations of these non-TP operations along the sequence dimension instead:\nflowchart TD LN1[\u0026#34;LayerNorm — SP: seq/N\u0026#34;] AG1[\u0026#34;all_gather: full sequence\u0026#34;] ATT[\u0026#34;Attention — TP: split by head\u0026#34;] RS1[\u0026#34;reduce_scatter: split to seq dim\u0026#34;] DR1[\u0026#34;Dropout + Residual — SP: seq/N\u0026#34;] LN2[\u0026#34;LayerNorm — SP\u0026#34;] AG2[\u0026#34;all_gather: full sequence\u0026#34;] FFN[\u0026#34;FFN — TP: Column + Row\u0026#34;] RS2[\u0026#34;reduce_scatter: back to seq dim\u0026#34;] DR2[\u0026#34;Dropout + Residual — SP\u0026#34;] LN1 --\u0026gt; AG1 --\u0026gt; ATT --\u0026gt; RS1 --\u0026gt; DR1 --\u0026gt; LN2 --\u0026gt; AG2 --\u0026gt; FFN --\u0026gt; RS2 --\u0026gt; DR2 style LN1 fill:#d4edda style DR1 fill:#d4edda style LN2 fill:#d4edda style DR2 fill:#d4edda style ATT fill:#fff3cd style FFN fill:#fff3cd style AG1 fill:#cce5ff style AG2 fill:#cce5ff style RS1 fill:#cce5ff style RS2 fill:#cce5ff The key insight: TP\u0026rsquo;s all_reduce is decomposed into reduce_scatter + all_gather, placed at the exit and entrance of the TP regions respectively. The communication volume is unchanged (recall that all_reduce = reduce_scatter + all_gather), but activation memory in the SP regions drops to $1/N$.\nSP\u0026rsquo;s value is most pronounced with large batches and long sequences, where activations dominate memory usage. In these cases, SP directly divides the activation cost by the TP degree.\nMoE Parallelism EP (Expert Parallel) Mixture-of-Experts (MoE) introduces a set of \u0026ldquo;expert\u0026rdquo; subnetworks into the model, where each token activates only $k$ of them (typically $k=1$ or $2$). 
This allows MoE to dramatically increase parameter count while keeping compute roughly constant — but it also creates unique parallelism challenges.\nThe structure of a MoE layer:\nflowchart TD INPUT[\u0026#34;Input tokens\u0026#34;] ROUTER[\u0026#34;Router Decides which expert each token goes to\u0026#34;] E0[\u0026#34;Expert 0 FFN\u0026#34;] \u0026amp; E1[\u0026#34;Expert 1 FFN\u0026#34;] \u0026amp; E2[\u0026#34;Expert 2 FFN\u0026#34;] \u0026amp; E3[\u0026#34;Expert 3 FFN ...\u0026#34;] COMBINE[\u0026#34;Combine outputs\u0026#34;] INPUT --\u0026gt; ROUTER ROUTER --\u0026gt; E0 \u0026amp; E1 \u0026amp; E2 \u0026amp; E3 E0 \u0026amp; E1 \u0026amp; E2 \u0026amp; E3 --\u0026gt; COMBINE Expert Parallel (EP) assigns different experts to different GPUs. With 64 experts and 8 GPUs, each GPU is responsible for 8 experts.\nThe core communication operation is all_to_all, executed twice:\nDispatch: after the router determines each token\u0026rsquo;s destination expert, all_to_all rearranges tokens from a \u0026ldquo;partitioned by data\u0026rdquo; layout to a \u0026ldquo;grouped by expert\u0026rdquo; layout — each GPU receives all tokens destined for the experts it hosts. Combine: after each expert processes its tokens, all_to_all sends the results back to the originating GPUs. 
flowchart TD subgraph BEFORE[\u0026#34;Before: data-sharded\u0026#34;] direction LR G0B[\u0026#34;GPU 0 tokens → various Experts\u0026#34;] G1B[\u0026#34;GPU 1 tokens → various Experts\u0026#34;] end BEFORE --\u0026gt;|\u0026#34;all_to_all (dispatch)\u0026#34;| AFTER subgraph AFTER[\u0026#34;After: expert-grouped\u0026#34;] direction LR G0A[\u0026#34;GPU 0 (E0,E1) all tokens → E0,E1\u0026#34;] G1A[\u0026#34;GPU 1 (E2,E3) all tokens → E2,E3\u0026#34;] end AFTER --\u0026gt;|\u0026#34;expert compute\u0026#34;| COMPUTE[\u0026#34;Each GPU runs its experts\u0026#34;] COMPUTE --\u0026gt;|\u0026#34;all_to_all (combine)\u0026#34;| RESULT[\u0026#34;Restore original data sharding\u0026#34;] The communication cost of EP depends on the token routing distribution. If tokens are uniformly spread across all experts, all_to_all traffic is maximized. If tokens cluster on a few experts, communication is lower but load imbalance becomes a problem.\nDeepSeek MoE: All-to-All Dispatch and Token Dropping DeepSeek introduced several important refinements to the MoE architecture.\n1. Fine-Grained Experts\nTraditional MoE designs (e.g., Switch Transformer) use a small number of large experts. DeepSeek uses a large number of small experts — for example, 160 experts with each token activating 6, rather than 16 experts with each token activating 2. More, smaller experts provide exponentially more combinatorial flexibility:\n$$\\binom{160}{6} \\gg \\binom{16}{2}$$\nThe number of possible expert combinations each token can select from grows enormously, yielding greater representational power.\n2. 
Shared Experts + Routed Experts\nDeepSeek MoE introduces \u0026ldquo;shared experts\u0026rdquo; — experts that process every token — alongside the routed experts selected by the gating network:\nOutput = SharedExpert(x) + Σ Router_topk(RoutedExpert_i(x))\nEP + FSDP Combined In practice, EP covers only the expert weights. Everything else — Attention, LayerNorm, Embedding, and the non-expert parts of each MoE layer — is sharded with FSDP across all GPUs:\nflowchart TD subgraph NON_EXPERT[\u0026#34;Non-expert layers (Attention, LN, Embedding)\u0026#34;] FSDP_SHARD[\u0026#34;FSDP across all GPUs shard params/grads/opt_state\u0026#34;] end subgraph MOE_LAYERS[\u0026#34;MoE layers\u0026#34;] direction LR EP_0[\u0026#34;Expert 0-7 → GPU 0\u0026#34;] EP_1[\u0026#34;Expert 8-15 → GPU 1\u0026#34;] EP_N[\u0026#34;... (EP)\u0026#34;] end COMM[\u0026#34;Communication: FSDP: all_gather + reduce_scatter EP: all_to_all (dispatch/combine)\u0026#34;] NON_EXPERT --- MOE_LAYERS --- COMM\n┌─────────── EP + FSDP Combined ─────────────────────────────────┐ │ │ │ Non-expert layers (Attention, LN, Embedding): │ │ → FSDP across all GPUs (shard params/grads/opt_state) │ │ │ │ MoE layers: │ │ Expert 0-7 → GPU 0 (EP) │ │ Expert 8-15 → GPU 1 (EP) │ │ \u0026hellip; │ │ Non-expert parts of MoE layer → FSDP across all GPUs │ │ │ │ Communication: │ │ FSDP: all_gather + reduce_scatter (params/grads) │ │ EP: all_to_all (token dispatch/combine) │ └────────────────────────────────────────────────────────────────┘\nContext Parallel: Ring Attention For ultra-long sequences, Context Parallel distributes the attention computation itself across GPUs. Ring Attention assigns each GPU a contiguous chunk of Q and circulates KV blocks around a ring:\nflowchart LR subgraph RING[\u0026#34;Ring Attention — Sequence length S, 4 GPUs\u0026#34;] G0[\u0026#34;GPU 0 Q[0:S/4]\u0026#34;] --\u0026gt;|\u0026#34;KV\u0026#34;| G1[\u0026#34;GPU 1 Q[S/4:S/2]\u0026#34;] G1 --\u0026gt;|\u0026#34;KV\u0026#34;| G2[\u0026#34;GPU 2 Q[S/2:3S/4]\u0026#34;] G2 --\u0026gt;|\u0026#34;KV\u0026#34;| G3[\u0026#34;GPU 3 Q[3S/4:S]\u0026#34;] G3 --\u0026gt;|\u0026#34;KV\u0026#34;| G0 end KV blocks circulate around the ring; send/recv overlaps with attention computation.
┌──────── Ring Attention (sequence length S, 4 GPUs) ───────────┐ │ │ │ GPU 0: Q[0:S/4] GPU 1: Q[S/4:S/2] │ │ GPU 2: Q[S/2:3S/4] GPU 3: Q[3S/4:S] │ │ │ │ Step 1: each GPU computes partial attention with local KV │ │ Step 2: send KV to right neighbor, recv from left, continue │ │ Step 3: continue circulating\u0026hellip; │ │ Step 4: KV completes full circle — attention done │ │ │ │ ┌───┐ KV ┌───┐ KV ┌───┐ KV ┌───┐ │ │ │ 0 │────→│ 1 │────→│ 2 │────→│ 3 │─┐ │ │ └───┘ └───┘ └───┘ └───┘ │ │ │ ↑ │ │ │ └───────────────────────────────────┘ │ └───────────────────────────────────────────────────────────────┘\nA critical optimization: the `send/recv` of KV blocks can **overlap** with attention computation — while a GPU is computing attention with the current KV block, it is already transmitting the next one. **Memory.** Each GPU stores Q and KV for only $S/N$ of the sequence. Memory drops from $O(S^2)$ to $O(S^2/N)$ (or more precisely, from $O(S)$ to $O(S/N)$ under FlashAttention). #### Stripe Attention A problem with Ring Attention is **load imbalance from the causal mask**: GPUs handling the beginning of the sequence compute less attention (because the causal mask blocks later positions), while the GPU holding the end of the sequence does the most work. 
Stripe Attention solves this by **interleaving** sequence positions across GPUs: Ring Attention (contiguous assignment): GPU 0: tokens [0, 1, 2, 3] → least compute (causal) GPU 1: tokens [4, 5, 6, 7] GPU 2: tokens [8, 9, 10, 11] GPU 3: tokens [12, 13, 14, 15] → most compute\nStripe Attention (interleaved assignment): GPU 0: tokens [0, 4, 8, 12] → balanced compute GPU 1: tokens [1, 5, 9, 13] GPU 2: tokens [2, 6, 10, 14] GPU 3: tokens [3, 7, 11, 15]\nflowchart TD IN[\u0026#34;Input: each GPU has seq/N, full heads\u0026#34;] QKV[\u0026#34;QKV projection (local)\u0026#34;] A2A1[\u0026#34;all_to_all (seq/N, heads) → (seq, heads/N)\u0026#34;] ATT[\u0026#34;Attention each GPU: heads/N over full sequence\u0026#34;] A2A2[\u0026#34;all_to_all (seq, heads/N) → (seq/N, heads)\u0026#34;] OUT[\u0026#34;Output projection (local)\u0026#34;] IN --\u0026gt; QKV --\u0026gt; A2A1 --\u0026gt; ATT --\u0026gt; A2A2 --\u0026gt; OUT style A2A1 fill:#cce5ff style A2A2 fill:#cce5ff ┌──────── Ulysses SP ───────────────────────────────────────┐ │ │ │ Input: each GPU holds seq/N tokens, all heads │ │ │ │ ① QKV projection (local computation) │ │ ② all_to_all: (seq/N, heads) → (seq, heads/N) │ │ transform from \u0026ldquo;split by sequence\u0026rdquo; to \u0026ldquo;split by head\u0026rdquo; │ │ ③ Attention computation (each GPU: heads/N, full seq) │ │ ④ all_to_all: (seq, heads/N) → (seq/N, heads) │ │ transform from \u0026ldquo;split by head\u0026rdquo; back to \u0026ldquo;split by seq\u0026rdquo; │ │ ⑤ Output projection (local computation) │ └───────────────────────────────────────────────────────────┘\nflowchart TD subgraph N0[\u0026#34;Node 0 (8 GPUs) — NVLink 600 GB/s\u0026#34;] direction LR subgraph TP0A[\u0026#34;TP=4\u0026#34;] G0[\u0026#34;0\u0026#34;] ~~~ G1[\u0026#34;1\u0026#34;] ~~~ G2[\u0026#34;2\u0026#34;] ~~~ G3[\u0026#34;3\u0026#34;] end subgraph TP0B[\u0026#34;TP=4\u0026#34;] G4[\u0026#34;4\u0026#34;] ~~~ G5[\u0026#34;5\u0026#34;] ~~~ G6[\u0026#34;6\u0026#34;] ~~~ 
G7[\u0026#34;7\u0026#34;] end TP0L[\u0026#34;Stage 0 — DP/FSDP across\u0026#34;] end subgraph N1[\u0026#34;Node 1 (8 GPUs)\u0026#34;] direction LR TP1L[\u0026#34;Stage 1 — DP/FSDP across\u0026#34;] end N0 ==\u0026gt;|\u0026#34;PP: send/recv (InfiniBand)\u0026#34;| N1 style N0 fill:#f0f0f0 style N1 fill:#f0f0f0 TP: intra-node (NVLink) | PP: cross-node-group (InfiniBand) | FSDP: cross-node (InfiniBand) ┌──────── 3D Parallelism: 64 GPUs (8 nodes × 8 GPUs) ───────────┐ │ │ │ Node 0 (8 GPUs) Node 1 (8 GPUs) │ │ ┌─TP=4─┐ ┌─TP=4─┐ ┌─TP=4─┐ ┌─TP=4─┐ │ │ │0 1 2 3│ │4 5 6 7│ │0 1 2 3│ │4 5 6 7│ │ │ │Stage 0│ │Stage 0│ │Stage 1│ │Stage 1│ │ │ └───────┘ └───────┘ └───────┘ └───────┘ │ │ └── DP/FSDP across ──┘ └── DP/FSDP across ──┘ │ │ │ │ Node 2-3: Stage 2 Node 4-7: replicas for DP │ │ │ │ TP: intra-node (NVLink 600 GB/s) ← high bandwidth needed │ │ PP: across node groups (InfiniBand) ← small send/recv │ │ FSDP: across nodes (InfiniBand) ← all_gather/reduce_scatter│ └───────────────────────────────────────────────────────────────┘\nThe rationale is topology-aware placement: - **TP** has the heaviest communication (all_reduce every layer), so it goes **intra-node** on NVLink. - **PP** has the lightest communication (send/recv of activations between adjacent stages), so it can go **cross-node**. - **FSDP/DP** falls in between, and is placed **cross-node** with overlapped communication. 
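This topology-aware placement can be made concrete with a small rank-mapping sketch. Treating TP as the innermost (fastest-varying) dimension puts each TP group on consecutive ranks — that is, on the same NVLink island — while PP and DP peers land farther apart. The sizes below (TP=4, PP=2, DP=8 → 64 GPUs) are illustrative; real frameworks typically build this mapping with a device-mesh abstraction:

```python
def rank_to_coords(rank: int, tp: int = 4, pp: int = 2, dp: int = 8):
    """Decompose a global rank into (dp, pp, tp) coordinates.

    TP is the innermost dimension, so ranks 0..tp-1 form one TP group
    (consecutive ranks → same node/NVLink), a PP peer is `tp` ranks away,
    and a DP replica is `tp * pp` ranks away. Sizes are illustrative.
    """
    assert 0 <= rank < tp * pp * dp
    tp_rank = rank % tp
    pp_rank = (rank // tp) % pp
    dp_rank = rank // (tp * pp)
    return dp_rank, pp_rank, tp_rank

print(rank_to_coords(0))  # (0, 0, 0) — first TP group
print(rank_to_coords(3))  # (0, 0, 3) — same TP group as rank 0
print(rank_to_coords(4))  # (0, 1, 0) — next PP stage
print(rank_to_coords(8))  # (1, 0, 0) — next DP replica
```

Because the heaviest-communicating dimension varies fastest, the all_reduce-hungry TP groups never cross the NVLink boundary.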
#### 5D Parallelism Adding SP (Sequence Parallel) and EP (Expert Parallel) gives the so-called 5D parallelism: $$\\text{Total GPUs} = TP \\times SP \\times PP \\times DP \\times EP$$ Each strategy splits along a different dimension: | Strategy | Split Dimension | Communication Op | Volume | Best Interconnect | |----------|----------------|-------------------|--------|-------------------| | TP | Hidden dimension | all_reduce | High | Intra-node (NVLink) | | SP | Sequence dimension | reduce_scatter / all_gather | Medium | Intra-node | | PP | Layers | send/recv | Low | Intra- or cross-node | | DP/FSDP | Data | all_reduce / all_gather + reduce_scatter | Medium | Cross-node | | EP | Experts | all_to_all | Routing-dependent | Cross-node | #### How to Choose a Parallelism Strategy The right combination depends on three factors: **model size, hardware topology, and sequence length**. Here is a decision tree for reference: Does the model fit on a single GPU? ├── Yes → DDP (simplest, fastest) └── No ├── Adam state doesn\u0026rsquo;t fit, but parameters do? │ └── FSDP (SHARD_GRAD_OP) ├── Even parameters don\u0026rsquo;t fit? │ └── FSDP (FULL_SHARD) │ ├── Still OOM? │ │ └── Add TP (typically 2/4/8 within node) │ ├── Still not enough? │ │ └── Add PP (cross-node) │ └── Long sequences causing OOM? │ └── Add Context Parallel └── MoE model? └── EP for experts + FSDP for non-expert params\nSome rules of thumb: - **7B models**: 2-8 GPU DDP or FSDP is sufficient. - **13B-70B models**: FSDP + TP (intra-node TP=2 or 4). - **70B+ models**: FSDP + TP + PP — full 3D parallelism. - **MoE models (e.g., Mixtral, DeepSeek)**: EP + FSDP + TP. - **Very long sequences (128K+)**: Context Parallel + TP + FSDP. 
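The decision tree above can be mirrored in a few lines. This is a toy sketch — the thresholds, labels, and the `choose_strategy` helper itself are illustrative, not a real API, and activation memory is ignored:

```python
def choose_strategy(params_b, gpu_mem_gb=80, moe=False, seq_len=4096):
    """Toy encoding of the decision tree above (illustrative only)."""
    full_state_gb = params_b * 16   # FP16 params + grads + FP32 Adam state
    params_only_gb = params_b * 2   # FP16 parameters alone
    plan = []
    if full_state_gb <= gpu_mem_gb:
        plan.append("DDP")
    elif params_only_gb <= gpu_mem_gb:
        plan.append("FSDP (SHARD_GRAD_OP)")
    else:
        plan.append("FSDP (FULL_SHARD)")
        if params_b >= 13:
            plan.append("TP (2-8, intra-node)")
        if params_b >= 70:
            plan.append("PP (cross-node)")
    if moe:
        plan.append("EP for experts + FSDP for non-expert params")
    if seq_len >= 128_000:
        plan.append("Context Parallel")
    return plan

print(choose_strategy(1))   # ['DDP']
print(choose_strategy(70))  # FULL_SHARD + TP + PP
```

A 70B model on 80 GB GPUs falls through to FULL_SHARD plus TP plus PP — matching the 3D-parallelism rule of thumb above.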
--- ## Companion Code All code for this article is in [`code/02-parallel-strategies/`](https://github.com/mzf666/llm-infra/tree/main/code/02-parallel-strategies/): - **`ddp_example.py`** — A complete DDP training loop with `DistributedSampler`, bucketed gradient synchronization, throughput measurement, and parameter consistency verification. Run with: `torchrun --nproc_per_node=2 ddp_example.py` - **`fsdp_example.py`** — Compares NO_SHARD (= DDP), SHARD_GRAD_OP (= ZeRO-2), and FULL_SHARD (= ZeRO-3), showing real per-GPU memory differences. Run with: `torchrun --nproc_per_node=2 fsdp_example.py` - **`tensor_parallel.py`** — A from-scratch implementation of Megatron-style `ColumnParallelLinear` and `RowParallelLinear`, combined into a `TensorParallelFFN`. Includes correctness verification and memory analysis. Run with: `torchrun --nproc_per_node=2 tensor_parallel.py` All scripts support CPU mode (automatic `gloo` backend fallback), so you can study the logic without a GPU. However, memory measurements and performance numbers are only meaningful on CUDA. --- ## Summary and What Comes Next This article systematically covered all the major parallelism strategies used in large model training. Here is a recap of the key ideas. **Classic parallelism strategies:** - **DDP**: replicate the model, split the data, all_reduce the gradients — the simplest approach, but saves no memory. - **FSDP/ZeRO**: shard parameters + gradients + optimizer state, paying all_gather/reduce_scatter communication to save memory — the backbone of large model training. - **TP**: split weight matrices by dimension (Column + Row), with 2 all_reduces per layer — requires NVLink bandwidth, best for intra-node. - **PP**: partition by layers, using send/recv — low communication but pipeline bubbles, addressed by 1F1B and Zero Bubble schedules. - **SP**: split activations along the sequence dimension for non-TP ops (LayerNorm, Dropout) — complementary to TP, reduces activation memory. 
**MoE parallelism:** - **EP**: distribute experts across GPUs, using all_to_all for token dispatch — load balancing is the central challenge. - **DeepSeek MoE**: fine-grained experts + shared experts + token dropping for better utilization. **Cutting-edge approaches:** - **Context Parallel** (Ring / Stripe Attention): handles ultra-long sequences by distributing attention across GPUs. - **Ulysses SP**: uses all_to_all to transpose between sequence and head dimensions. - **Hybrid Parallelism (3D/5D)**: combines strategies based on hardware topology — TP intra-node, PP cross-node, FSDP everywhere else. The fundamental insight that connects every strategy in this article: each one is making a different **trade-off between memory and communication**. Which combination you choose depends on how large your model is, how fast your inter-GPU links are, and how long your sequences are. There is no silver bullet — only the best engineering compromise for your specific setup. In the next article, **\u0026#34;LLM Inference System Architecture,\u0026#34;** we shift our perspective from training to serving. Once the model is trained, how do we serve requests efficiently? We will dive deep into PagedAttention, RadixAttention (SGLang\u0026#39;s core innovation), Continuous Batching, and other inference optimizations — exploring a fundamentally different, memory-bound challenge. --- ## References 1. **ZeRO: Memory Optimizations Toward Training Trillion Parameter Models** — Rajbhandari et al., 2020 — the theoretical foundation for FSDP, defining the Stage 1/2/3 memory sharding scheme. 2. **Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism** — Shoeybi et al., 2020 — the Column/Row partitioning approach for Tensor Parallelism. 3. **Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM** — Narayanan et al., 2021 — 3D parallelism (TP + PP + DP) system design. 4. 
**GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism** — Huang et al., 2019 — the micro-batching approach to Pipeline Parallelism. 5. **Zero Bubble Pipeline Parallelism** — Qi et al., 2024 — eliminating pipeline bubbles by separating B and W computations. 6. **Ring Attention with Blockwise Transformers for Near-Infinite Context** — Liu et al., 2023 — the Ring Attention approach for long sequences. 7. **DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model** — DeepSeek AI, 2024 — fine-grained MoE with shared experts. 8. **DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale** — Rajbhandari et al., 2022 — EP and FSDP combined training. 9. **Reducing Activation Recomputation in Large Transformer Models** — Korthikanti et al., 2023 — Sequence Parallel for reducing activation memory. 10. **PyTorch FSDP Documentation** — [pytorch.org/docs/stable/fsdp](https://pytorch.org/docs/stable/fsdp.html) — official FSDP/FSDP2 documentation. 11. **DeepSpeed ZeRO Tutorial** — [deepspeed.ai/tutorials/zero](https://www.deepspeed.ai/tutorials/zero/) — practical guide to ZeRO Stage 1/2/3. ","permalink":"https://mzf666.github.io/llm-infra/en/posts/02-parallel-strategies/","summary":"From DDP to hybrid parallelism — a systematic guide to every parallelism strategy in large model training.","title":"The Landscape of Distributed Parallelism Strategies"},{"content":"Motivation Training a large model is expensive, but serving one can be even more so.\nTraining is a one-time cost — you throw a few thousand GPUs at the problem for a few weeks and you are done. Inference, on the other hand, runs 24/7: every user query, every API call demands a real-time response from the model. OpenAI processes billions of tokens per day, and inference accounts for the lion\u0026rsquo;s share of operational costs. 
If you can double inference efficiency, you either halve costs or serve twice as many users on the same budget.\nInference optimization is also a fundamentally different game from training. Training optimizes aggregate compute throughput — burn through the data as fast as possible. Inference must simultaneously optimize two metrics that are often in tension:\nThroughput: How many tokens per second can the system generate? This determines how many concurrent users you can serve. Latency: How long does each user wait? The key metrics are TTFT (Time To First Token) and TPOT (Time Per Output Token). More importantly, inference has completely different compute characteristics than training. As we discussed in Article 1, autoregressive decoding is essentially a matrix-vector multiply — arithmetic intensity around 1 FLOPs/byte, GPU utilization below 1%. Inference is not compute-bound. It is memory-bandwidth-bound — and the KV Cache is the single largest consumer of that memory.\nThis article uses SGLang as a case study to dissect four core techniques in modern inference systems:\nKV Cache — the dominant memory consumer during inference, and why it becomes a bottleneck PagedAttention — borrowing the OS virtual memory paradigm to eliminate KV Cache fragmentation RadixAttention — SGLang\u0026rsquo;s core innovation, using a radix tree for cross-request KV Cache reuse Continuous Batching — iteration-level scheduling to maximize GPU utilization Chunked Prefill — preventing long prompts from blocking decode operations By the end, you should understand why vLLM and SGLang achieve 2-10x throughput improvements over naive implementations, and how their design philosophies differ.\nPrerequisites GPU Memory Model and Distributed Communication Fundamentals (Article 1) — especially the HBM bandwidth bottleneck and the Roofline model Basic understanding of Transformer attention (Self-Attention, Q/K/V projections) Familiarity with autoregressive generation (token-by-token decoding) KV 
Cache Fundamentals Why KV Cache Exists When an autoregressive language model generates token $t$, the attention computation is:\n$$\\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \\text{softmax}\\left(\\frac{Q_t \\cdot K_{1:t}^T}{\\sqrt{d_k}}\\right) V_{1:t}$$\nThe key observation: the current token\u0026rsquo;s Query must attend to the Keys and Values of every previously generated token. If you recompute all K and V from scratch at each step, total computation is $O(t^2)$ — generating 1000 tokens requires $1 + 2 + \\cdots + 1000 = 500{,}500$ attention computations.\nThe KV Cache insight is simple: previous tokens\u0026rsquo; K and V never change, so compute them once and cache them for reuse.\nflowchart TD subgraph NO[\u0026#34;Autoregressive Decoding (without KV Cache)\u0026#34;] N1[\u0026#34;Step 1: Input [t1] → Compute K1,V1 → Attention → Output t2\u0026#34;] N2[\u0026#34;Step 2: Input [t1,t2] → Recompute K1,V1,K2,V2 → Attention\u0026#34;] N3[\u0026#34;Step 3: Input [t1,t2,t3] → Recompute K1,K2,K3... → Attention\u0026#34;] NN[\u0026#34;Step N: Input [t1,...,tN] → Recompute all K,V → O(N²) total\u0026#34;] N1 --\u0026gt; N2 --\u0026gt; N3 -.-\u0026gt; NN end subgraph YES[\u0026#34;Autoregressive Decoding (with KV Cache)\u0026#34;] Y1[\u0026#34;Step 1: Input [t1] → Compute K1,V1 → Cache → Attention → Output t2\u0026#34;] Y2[\u0026#34;Step 2: Input [t2] → Compute K2,V2 → Append to cache → Attention\u0026#34;] Y3[\u0026#34;Step 3: Input [t3] → Compute K3,V3 → Append to cache → Attention\u0026#34;] YN[\u0026#34;Step N: Input [tN] → Compute KN,VN → Append to cache → O(N) total\u0026#34;] Y1 --\u0026gt; Y2 --\u0026gt; Y3 -.-\u0026gt; YN YNOTE[\u0026#34;Each step computes only 1 token\u0026#39;s K,V, then attends over the cache\u0026#34;] end NO --\u0026gt; YES style NO fill:#fff3cd,stroke:#856404 style YES fill:#d4edda,stroke:#155724 KV Cache reduces total computation from $O(N^2)$ to $O(N)$ — this is the first principle of inference optimization. 
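The $O(N^2)$ vs $O(N)$ claim is easy to verify by counting how many K/V vectors must be computed to generate N tokens, with and without a cache. A minimal counting sketch — the `kv_computations` helper is hypothetical, not part of any inference engine:

```python
def kv_computations(n_tokens: int, use_cache: bool) -> int:
    """Count K/V vector computations needed to generate n_tokens autoregressively."""
    total = 0
    for step in range(1, n_tokens + 1):
        if use_cache:
            total += 1      # only the newest token's K,V; earlier ones are cached
        else:
            total += step   # recompute K,V for every token in the context
    return total

print(kv_computations(1000, use_cache=False))  # 500500 — the O(N^2) total from above
print(kv_computations(1000, use_cache=True))   # 1000   — O(N)
```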
Every modern inference engine implements it.

How Large Is the KV Cache?

The KV Cache size per token is:

$$\\text{KV per token} = 2 \\times n_{\\text{layers}} \\times n_{\\text{heads}} \\times d_{\\text{head}} \\times \\text{dtype_size}$$

The factor of 2 accounts for storing both K and V. Here are concrete numbers for several popular models (FP16, 2 bytes per element):

| Model | $n_{\\text{layers}}$ | $n_{\\text{heads}}$ | $d_{\\text{head}}$ | KV per token | 2048 tokens |
|------------|----|----|-----|--------|--------|
| LLaMA-7B   | 32 | 32 | 128 | 0.5 MB | 1.0 GB |
| LLaMA-13B  | 40 | 40 | 128 | 0.8 MB | 1.6 GB |
| LLaMA-70B  | 80 | 64 | 128 | 2.5 MB | 5.0 GB |
| GPT-3 175B | 96 | 96 | 128 | 4.5 MB | 9.2 GB |

A single LLaMA-70B request generating 2048 tokens consumes 5 GB of KV Cache alone. If you want to serve 16 such requests concurrently, the KV Cache requires 80 GB — the entire memory of an A100. And the model weights themselves take another 140 GB (in FP16).

This is the fundamental tension of inference memory management: model weights are a fixed cost, KV Cache is a dynamic cost, and KV Cache grows linearly with both the number of concurrent requests and the sequence length — easily becoming the bottleneck.

How many requests you can batch simultaneously depends on how much KV Cache you can fit — and batch size directly determines throughput. So the efficiency of KV Cache memory management sets the performance ceiling for the entire inference system.

The Waste Problem with Naive Allocation

The most straightforward approach is to pre-allocate a contiguous memory region of size max_seq_len for each request's KV Cache. The problem is that you do not know how many tokens a request will actually generate — it could be 10 or it could be 2000.

```python
class NaiveKVCache:
    """
    Pre-allocated KV cache for a single sequence.

    Problem: if max_seq_len=2048 but the sequence only generates 50 tokens,
    we waste 2048 - 50 = 1998 slots of memory. Multiply by batch_size and
    num_layers, and you're throwing away most of your GPU memory.
    """
    def __init__(self, max_seq_len: int, num_heads: int, head_dim: int):
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.current_len = 0
        # Pre-allocate full tensors — this is the waste!
        self.k_cache = torch.zeros(max_seq_len, num_heads, head_dim, device=device)
        self.v_cache = torch.zeros(max_seq_len, num_heads, head_dim, device=device)
```

The companion code demonstrates that when request lengths are non-uniform (which is nearly always the case in practice), naive pre-allocation wastes 60-80% of GPU memory:

```
5 requests, max_seq_len=2048
Actual lengths: [37, 152, 8, 420, 91]
Memory allocated:     160.0 MB
Memory actually used:  22.1 MB
Waste:                 86.2%  --> This is why we need PagedAttention!
```

Wasting 86% of memory means a GPU that could handle 7 concurrent requests is limited to just 1. What you are wasting is not just memory — it is throughput and revenue.

PagedAttention

Borrowing Wisdom from Operating Systems

PagedAttention is the core contribution of vLLM (SOSP'23), and its inspiration comes from a technology that has been battle-tested for 50 years — OS virtual memory.

In an operating system, each process sees a virtual address space that appears contiguous and private. But physical memory is allocated in fixed-size pages, and a page table maps virtual addresses to physical addresses.
The process never needs contiguous physical memory — the OS stitches together scattered physical pages into what looks like a contiguous virtual space behind the scenes.\nPagedAttention transplants this idea to KV Cache management:\nflowchart LR subgraph OS[\u0026#34;OS Virtual Memory\u0026#34;] A1[\u0026#34;Process\u0026#34;] A2[\u0026#34;Virtual page\u0026#34;] A3[\u0026#34;Physical page frame\u0026#34;] A4[\u0026#34;Page table\u0026#34;] A5[\u0026#34;Demand paging (page fault)\u0026#34;] A6[\u0026#34;Fixed page size (4 KB)\u0026#34;] A7[\u0026#34;CoW (fork)\u0026#34;] A8[\u0026#34;Page swap\u0026#34;] end subgraph PA[\u0026#34;PagedAttention\u0026#34;] B1[\u0026#34;Request (sequence)\u0026#34;] B2[\u0026#34;Logical KV block\u0026#34;] B3[\u0026#34;Physical KV block\u0026#34;] B4[\u0026#34;Block table\u0026#34;] B5[\u0026#34;On-demand allocation (new token)\u0026#34;] B6[\u0026#34;Fixed block size (16 tokens)\u0026#34;] B7[\u0026#34;CoW (beam search)\u0026#34;] B8[\u0026#34;Preemption (swap out request)\u0026#34;] end A1 -.-\u0026gt; B1 A2 -.-\u0026gt; B2 A3 -.-\u0026gt; B3 A4 -.-\u0026gt; B4 A5 -.-\u0026gt; B5 A6 -.-\u0026gt; B6 A7 -.-\u0026gt; B7 A8 -.-\u0026gt; B8 style OS fill:#cce5ff,stroke:#004085 style PA fill:#d4edda,stroke:#155724 Block Allocation Mechanism PagedAttention divides GPU KV Cache memory into a pool of fixed-size blocks. Each block stores the KV data for a fixed number of tokens (typically 16 or 32). 
Each request maintains a block table that maps its logical blocks to physical blocks.

```mermaid
flowchart TD
    subgraph POOL["Physical Block Pool (GPU Memory)"]
        B0["B0<br/>Seq0"]
        B1["B1<br/>Seq1"]
        B2["B2<br/>Free"]
        B3["B3<br/>Seq0"]
        B4["B4<br/>Seq1"]
        B5["B5<br/>Free"]
        B6["B6<br/>Seq0"]
        B7["B7<br/>Free"]
    end
    subgraph T0["Seq 0 Block Table"]
        L0_0["Logic 0 → B0"]
        L0_1["Logic 1 → B3"]
        L0_2["Logic 2 → B6"]
    end
    subgraph T1["Seq 1 Block Table"]
        L1_0["Logic 0 → B1"]
        L1_1["Logic 1 → B4"]
    end
    L0_0 --> B0
    L0_1 --> B3
    L0_2 --> B6
    L1_0 --> B1
    L1_1 --> B4
    NOTE["Physical blocks need NOT be contiguous — only logical order matters"]
    style POOL fill:#cce5ff,stroke:#004085
    style B2 fill:#f8f9fa,stroke:#6c757d
    style B5 fill:#f8f9fa,stroke:#6c757d
    style B7 fill:#f8f9fa,stroke:#6c757d
```

The workflow:

1. New request arrives: allocate an empty block table — no blocks pre-allocated.
2. First token generated: grab a block from the free pool, store the KV, update the block table.
3. Current block fills up (16 tokens worth of KV written): grab another block from the free pool.
4. Request completes: return all blocks to the free pool, immediately available for other requests.

```python
class BlockManager:
    """
    PagedAttention-style block manager.

    Key ideas:
    - Physical KV cache is a pool of fixed-size blocks.
    - Each sequence has a block table mapping logical -> physical blocks.
    - Blocks are allocated on demand as sequences grow.
    - Free blocks are recycled when sequences finish.
    - Copy-on-Write: shared blocks are copied only when modified.
    """
    def __init__(self, num_blocks, num_layers, num_heads, head_dim, block_size=16):
        self.num_blocks = num_blocks
        self.block_size = block_size
        # Physical block pool — the actual GPU memory
        self.k_pool = torch.zeros(
            num_layers, num_blocks, block_size, num_heads, head_dim, device=device
        )
        self.v_pool = torch.zeros(
            num_layers, num_blocks, block_size, num_heads, head_dim, device=device
        )
        # Track physical block metadata
        self.physical_blocks = [PhysicalBlock(block_id=i) for i in range(num_blocks)]
        self.free_block_ids: list[int] = list(range(num_blocks))
        self.seq_tables: dict[int, SequenceBlockTable] = {}
```

From 86% Waste to Under 3%

With PagedAttention, memory waste comes from exactly one source: internal fragmentation in each request's last block. If block_size=16, the last block is on average half-full, wasting roughly $\\frac{1}{2} \\times \\frac{1}{N_{\\text{blocks}}}$ of total memory.

The companion code shows the comparison:

```
Config: 32 heads, head_dim=128, max_seq=2048, block_size=16
10 requests with lengths: [47, 183, 12, 891, 256, 5, 1024, 73, 330, 15]

Metric                      Naive      Paged      Actual
--------------------------------------------------------
KV cache memory (MB)        320.0       45.5        44.3
Memory utilization          13.8%      97.4%      100.0%
Waste (MB)                  275.7        1.2         0.0
Waste (%)                   86.2%       2.6%        0.0%
```

Naive allocation wastes 86%; PagedAttention wastes just 2.6%. This means the same GPU can batch 2-4x more requests, directly translating to a 2-4x throughput increase.

Copy-on-Write: Efficient Sequence Forking

In beam search and parallel sampling, a request forks into multiple candidate sequences.
These candidates share the prefix\u0026rsquo;s KV Cache — copying the prefix KV for every candidate would multiply memory consumption.\nPagedAttention solves this with Copy-on-Write (CoW), again borrowing from the OS fork() system call:\nflowchart TD subgraph BEFORE[\u0026#34;Before fork\u0026#34;] S0a[\u0026#34;Seq 0: Block A → Block B → Block C (ref=1)\u0026#34;] end subgraph AFTER[\u0026#34;After fork (Seq 0 + Seq 1)\u0026#34;] S0b[\u0026#34;Seq 0: Block A → Block B → Block C\u0026#34;] S1b[\u0026#34;Seq 1: Block A → Block B → Block C\u0026#34;] REF[\u0026#34;Block C ref=2 (shared, no copy!)\u0026#34;] end subgraph WRITE[\u0026#34;Seq 1 writes new token to Block C\u0026#34;] S0c[\u0026#34;Seq 0: Block A → Block B → Block C (ref=1)\u0026#34;] S1c[\u0026#34;Seq 1: Block A → Block B → Block C\u0026#39; (ref=1, copied now)\u0026#34;] NOTE2[\u0026#34;Only the modified block is copied; earlier blocks remain shared\u0026#34;] end BEFORE --\u0026gt; AFTER --\u0026gt; WRITE style BEFORE fill:#cce5ff,stroke:#004085 style AFTER fill:#fff3cd,stroke:#856404 style WRITE fill:#d4edda,stroke:#155724 The core logic lives in append_token: before writing, check the block\u0026rsquo;s reference count. 
If it exceeds 1 (the block is shared), copy it first, then write:

```python
def append_token(self, seq_id, k, v, layer_idx):
    table = self.seq_tables[seq_id]
    # Next write position; the caller advances this counter once per token,
    # after the KV has been written for every layer.
    token_idx = table.num_tokens
    logical_block_idx = token_idx // self.block_size
    slot_in_block = token_idx % self.block_size

    # Allocate a new block if the current one is full
    if logical_block_idx >= len(table.logical_to_physical):
        phys_id = self._allocate_block()
        table.logical_to_physical.append(phys_id)
    phys_id = table.logical_to_physical[logical_block_idx]

    # CoW check: if this block is shared, copy before writing
    if self.physical_blocks[phys_id].ref_count > 1:
        new_phys_id = self._allocate_block()
        self.k_pool[:, new_phys_id] = self.k_pool[:, phys_id]
        self.v_pool[:, new_phys_id] = self.v_pool[:, phys_id]
        self._free_block(phys_id)  # decrement the old block's ref count
        table.logical_to_physical[logical_block_idx] = new_phys_id
        phys_id = new_phys_id

    # Write KV to the physical block
    self.k_pool[layer_idx, phys_id, slot_in_block] = k
    self.v_pool[layer_idx, phys_id, slot_in_block] = v
```

CoW pays off most dramatically in beam search: with beam_size=4, four candidate sequences share the vast majority of their prefix KV. Only the last block (the part being actively generated) needs independent storage. Memory cost approaches that of 1 sequence rather than 4.

RadixAttention (SGLang's Core Innovation)

From PagedAttention to Prefix Caching

PagedAttention solves memory fragmentation within a single request.
But it misses a larger optimization opportunity: KV Cache reuse across requests.\nIn real-world LLM applications, many requests share identical prefixes:\nflowchart TD subgraph S1[\u0026#34;Scenario 1: System Prompt\u0026#34;] SP[\u0026#34;Shared prefix: You are a helpful assistant...\u0026#34;] UA[\u0026#34;+ User A: What is machine learning?\u0026#34;] UB[\u0026#34;+ User B: Write a Python function for...\u0026#34;] UC[\u0026#34;+ User C: Translate this text...\u0026#34;] SP --\u0026gt; UA SP --\u0026gt; UB SP --\u0026gt; UC end subgraph S2[\u0026#34;Scenario 2: Few-shot Learning\u0026#34;] FS[\u0026#34;Shared prefix: Here are some examples x5\u0026#34;] Q1[\u0026#34;+ Input: new query 1\u0026#34;] Q2[\u0026#34;+ Input: new query 2\u0026#34;] FS --\u0026gt; Q1 FS --\u0026gt; Q2 end subgraph S3[\u0026#34;Scenario 3: Multi-turn Chat\u0026#34;] MT[\u0026#34;Shared prefix: System + Turn 1 + Turn 2\u0026#34;] T3[\u0026#34;+ User Turn 3 (new message)\u0026#34;] MT --\u0026gt; T3 end style S1 fill:#cce5ff,stroke:#004085 style S2 fill:#d4edda,stroke:#155724 style S3 fill:#fff3cd,stroke:#856404 If a system prompt is 500 tokens long and every request recomputes and stores those 500 tokens\u0026rsquo; KV Cache independently, then 100 concurrent requests store 100 identical copies — a massive waste.\nThe prefix caching insight: if two requests share the same token prefix, their KV Cache at those positions is identical and can be shared directly, with no recomputation needed.\nvLLM eventually added prefix caching too, but its implementation relies on predefined prefix matching — you need to explicitly declare which requests share prefixes. SGLang\u0026rsquo;s RadixAttention offers a more flexible and elegant approach.\nRadix Tree: A Data Structure for Indexing Arbitrary Prefixes SGLang (LMSYS, 2024) introduces a Radix Tree (radix trie) to index and manage all cached KV data.\nA radix tree is a compressed trie. 
In a standard trie, each node represents a single character; a radix tree compresses single-child paths into a single node, reducing tree depth. In SGLang, each node represents a token sequence segment, and nodes store pointers to the corresponding KV Cache blocks.\ngraph TD ROOT[\u0026#34;Root\u0026#34;] SYS[\u0026#34;System Prompt\u0026lt;br/\u0026gt;KV blocks: B0, B1, B2\u0026#34;] ML[\u0026#34;User: What is ML?\u0026lt;br/\u0026gt;KV blocks: B3, B4\u0026#34;] PY[\u0026#34;User: Write Python\u0026lt;br/\u0026gt;KV blocks: B5, B6\u0026#34;] AST[\u0026#34;Asst: ML is...\u0026lt;br/\u0026gt;KV blocks: B7\u0026#34;] ROOT --\u0026gt; SYS SYS --\u0026gt; ML SYS --\u0026gt; PY ML --\u0026gt; AST REQD[\u0026#34;New Req D: System Prompt + What is ML? + new query\u0026lt;br/\u0026gt;Matches B0-B4 directly, only new query needs fresh KV\u0026#34;] style SYS fill:#cce5ff,stroke:#004085 style ML fill:#d4edda,stroke:#155724 style PY fill:#d4edda,stroke:#155724 style AST fill:#fff3cd,stroke:#856404 style REQD fill:#f8f9fa,stroke:#6c757d Workflow When a new request arrives, SGLang\u0026rsquo;s scheduler executes the following steps:\nflowchart TD subgraph FLOW[\u0026#34;RadixAttention: Processing a New Request\u0026#34;] A[\u0026#34;1. PREFIX MATCHING\u0026lt;br/\u0026gt;Search Radix Tree for longest matching prefix\u0026lt;br/\u0026gt;e.g., sys_prompt KV blocks are reused directly\u0026#34;] B[\u0026#34;2. REUSE CACHED KV\u0026lt;br/\u0026gt;Skip prefill for the matched prefix\u0026lt;br/\u0026gt;Only compute KV for the unmatched suffix\u0026#34;] C[\u0026#34;3. 
INSERT INTO TREE\u0026lt;br/\u0026gt;After completion, insert new KV into Radix Tree\u0026lt;br/\u0026gt;Future requests with the same prefix can reuse it\u0026#34;] A --\u0026gt; B --\u0026gt; C end subgraph TIMELINE[\u0026#34;Timeline Comparison\u0026#34;] T1[\u0026#34;No cache: ====== prefill 500 tokens ====== decode\u0026#34;] T2[\u0026#34;With cache: == prefill 100 tokens == decode\u0026lt;br/\u0026gt;400 tokens KV reused from cache, skipped\u0026#34;] end FLOW --\u0026gt; TIMELINE style FLOW fill:#cce5ff,stroke:#004085 style T1 fill:#fff3cd,stroke:#856404 style T2 fill:#d4edda,stroke:#155724 In the system prompt scenario (where nearly every request shares the same prompt), prefix caching skips the bulk of prefill computation, reducing TTFT (Time To First Token) by 50-90%.\nLRU Eviction GPU memory is finite — you cannot cache every prefix you have ever seen. When the block pool is full, some cached KV data must be evicted to make room for new requests.\nSGLang uses LRU (Least Recently Used) eviction: the KV Cache node that has gone the longest without being accessed is evicted first.\nflowchart TD subgraph NODES[\u0026#34;Radix Tree Nodes (sorted by last access time)\u0026#34;] N1[\u0026#34;sys_prompt + user_A — 10 sec ago — keep\u0026#34;] N2[\u0026#34;sys_prompt + user_B — 30 sec ago — keep\u0026#34;] N3[\u0026#34;sys_prompt + user_C — 120 sec ago — eviction candidate\u0026#34;] N4[\u0026#34;old_prompt + old_query — 300 sec ago — evict first\u0026#34;] end subgraph EVICT[\u0026#34;Eviction Process\u0026#34;] E1[\u0026#34;1. Take least-recently-used node from LRU list\u0026#34;] E2[\u0026#34;2. Free that node\u0026#39;s KV blocks\u0026#34;] E3[\u0026#34;3. Remove node from Radix Tree\u0026#34;] E4[\u0026#34;4. 
If parent becomes childless leaf, recursively clean up\u0026#34;] E1 --\u0026gt; E2 --\u0026gt; E3 --\u0026gt; E4 end style N1 fill:#d4edda,stroke:#155724 style N2 fill:#d4edda,stroke:#155724 style N3 fill:#fff3cd,stroke:#856404 style N4 fill:#f8d7da,stroke:#721c24 LRU works well here because of temporal locality: recently accessed prefixes are likely to be accessed again (the same user continuing a multi-turn conversation, or a popular system prompt being reused across many concurrent requests).\nCache-Aware Scheduling With the Radix Tree in place, the scheduler can make a remarkably effective optimization: prioritize requests that have longer prefix matches in the cache.\nflowchart TD subgraph QUEUE[\u0026#34;Waiting Queue\u0026#34;] RA[\u0026#34;Req A: Matches 500 cached tokens → prefill only 100 tokens\u0026#34;] RB[\u0026#34;Req B: Matches 0 tokens → prefill needs 600 tokens\u0026#34;] RC[\u0026#34;Req C: Matches 300 cached tokens → prefill only 200 tokens\u0026#34;] end subgraph SCHED[\u0026#34;Scheduling Strategy Comparison\u0026#34;] FCFS[\u0026#34;Naive FCFS: A → B → C\u0026#34;] CACHE[\u0026#34;Cache-aware: A → C → B (prioritize high cache hits)\u0026#34;] end subgraph BENEFIT[\u0026#34;Benefits\u0026#34;] B1[\u0026#34;1. High cache-hit requests prefill faster → lower TTFT\u0026#34;] B2[\u0026#34;2. KV stays in tree → similar future requests also hit → positive feedback\u0026#34;] B3[\u0026#34;3. 
Less total prefill compute → more GPU cycles for decode\u0026#34;] end QUEUE --\u0026gt; SCHED --\u0026gt; BENEFIT style RA fill:#d4edda,stroke:#155724 style RB fill:#f8d7da,stroke:#721c24 style RC fill:#fff3cd,stroke:#856404 style CACHE fill:#d4edda,stroke:#155724

This scheduling strategy gives SGLang a significant edge in workloads with shared prefixes — few-shot learning, multi-turn chat, and any application using system prompts.

Comparison with vLLM

vLLM (v0.3+) later added prefix caching as well, but with a different design philosophy:

| Feature | SGLang (RadixAttention) | vLLM (Prefix Caching) |
|---------|-------------------------|-----------------------|
| Data structure | Radix Tree | Hash Map |
| Prefix matching | Automatic matching of any shared prefix | Token-block hash-based matching |
| Matching granularity | Token-level | Block-level (block_size-aligned) |
| Eviction policy | LRU on tree nodes | LRU on blocks |
| Cache-aware scheduling | Native support | Added in later versions |
| Multi-turn conversation | Natural fit (tree structure) | Requires hash matching |

SGLang's Radix Tree approach has clear advantages in flexibility:

- Arbitrary prefix sharing: not limited to predefined prompt templates — any two requests with a common prefix are automatically matched.
- The tree structure naturally fits multi-turn dialogue: conversation history forms a tree, with different branches representing different follow-up exchanges.
- Token-level precision: matching can occur at exact token boundaries, not just at block-aligned positions.

vLLM's hash-based approach has advantages in implementation simplicity and stability at scale.
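To make the prefix-matching idea concrete, here is a toy token-level prefix cache in the spirit of RadixAttention. It is a simplified sketch, not SGLang's implementation: real RadixAttention compresses single-child paths (a true radix tree), stores KV block pointers and access timestamps per node, and handles eviction; the `PrefixCache` and `RadixNode` names are illustrative assumptions.

```python
class RadixNode:
    def __init__(self):
        self.children = {}    # token_id -> RadixNode
        self.kv_block = None  # would point to cached KV blocks in a real engine

class PrefixCache:
    """Toy token-level prefix cache (one node per token, uncompressed)."""
    def __init__(self):
        self.root = RadixNode()

    def insert(self, tokens):
        """Record a completed sequence so future requests can reuse its prefix."""
        node = self.root
        for t in tokens:
            node = node.children.setdefault(t, RadixNode())

    def match_prefix(self, tokens) -> int:
        """Return how many leading tokens already have cached KV."""
        node, matched = self.root, 0
        for t in tokens:
            if t not in node.children:
                break
            node = node.children[t]
            matched += 1
        return matched

cache = PrefixCache()
cache.insert([1, 2, 3, 4, 5])            # e.g., system prompt + first query
print(cache.match_prefix([1, 2, 3, 9]))  # -> 3: only the unmatched suffix needs prefill
```

The scheduler would skip prefill for the matched prefix and compute KV only for the remaining suffix — exactly the workflow described above.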
In practice, the performance gap depends heavily on the specific workload characteristics.\nContinuous Batching The Problem with Static Batching Now that we understand KV Cache memory management, let us look at a different dimension of optimization: scheduling.\nThe most straightforward inference approach is static batching: assemble a group of requests into a batch, run prefill together, then decode together, and wait until every request in the batch finishes before processing the next batch.\nThe problem: output lengths vary wildly across requests.\nflowchart TD subgraph STATIC[\u0026#34;Static Batching — Batch = Req 0-3, Output lengths: 3, 8, 2, 5\u0026#34;] direction LR R0[\u0026#34;Req 0: ## ## ## .. .. .. .. ..\u0026lt;br/\u0026gt;Done at step 3\u0026#34;] R1[\u0026#34;Req 1: ## ## ## ## ## ## ## ##\u0026lt;br/\u0026gt;Longest — everyone waits\u0026#34;] R2[\u0026#34;Req 2: ## ## .. .. .. .. .. ..\u0026lt;br/\u0026gt;Done at step 2\u0026#34;] R3[\u0026#34;Req 3: ## ## ## ## ## .. .. ..\u0026lt;br/\u0026gt;Done at step 5\u0026#34;] end WASTE[\u0026#34;## = useful compute · .. = GPU idle\u0026lt;br/\u0026gt;GPU utilization: 18/32 = 56.25%\u0026lt;br/\u0026gt;43.75% of GPU cycles wasted\u0026lt;br/\u0026gt;Req 4, 5 queued — must wait for Req 1 to finish\u0026#34;] style STATIC fill:#fff3cd,stroke:#856404 style WASTE fill:#f8d7da,stroke:#721c24 When output length variance is high (and it nearly always is in practice — one user asks \u0026ldquo;what time is it?\u0026rdquo; while another requests a long essay), static batching can drive GPU utilization as low as 20-50%.\nIteration-Level Scheduling Continuous Batching (Orca, OSDI'22) makes one fundamental change: it drops the scheduling granularity from batch level to iteration level.\nAfter every single decode step, the scheduler checks:\nHave any requests finished (generated EOS or hit max_length)? If so, evict them and free their KV Cache. Are there waiting requests that can be admitted? 
If so, admit them and begin their prefill.

```mermaid
flowchart TD
    subgraph CB["Continuous Batching — Iteration-Level Scheduling"]
        direction LR
        R0["Req 0: ## ## ## — done step 3"]
        R1["Req 1: ## ## ## ## ## ## ## ## — done step 8"]
        R2["Req 2: ## ## — done step 2"]
        R3["Req 3: ## ## ## ## ## — done step 5"]
        R4["Req 4: admitted step 3 → ## ## ## ##"]
        R5["Req 5: admitted step 4 → ## ## ## ## ## ##"]
    end
    RESULT["Active count: 4 4 4 4 4 3 3 3 1<br/>GPU stays nearly full! Finished requests immediately replaced"]
    style CB fill:#d4edda,stroke:#155724
    style RESULT fill:#cce5ff,stroke:#004085
```

The companion code demonstrates that continuous batching typically delivers a 2-3x throughput improvement:

```python
def run_continuous_batching(requests, max_batch_size, total_blocks=256, ...):
    """
    After EVERY decode step:
    1. Remove finished sequences (free their KV blocks).
    2. If there are free slots AND enough KV blocks, admit waiting requests.
    3. If out of KV blocks, preempt the newest running sequence (LIFO).
    """
    # ... per-step scheduling ...
    for req in running:
        req.tokens_generated += 1
        if req.tokens_generated >= req.target_output_length:
            req.status = SequenceStatus.FINISHED
            finished_this_step.append(req)

    # Remove finished requests and free their blocks
    for req in finished_this_step:
        running.remove(req)
        block_pool.free(req.num_kv_blocks)

    # Try to admit waiting requests
    while waiting and len(running) < max_batch_size:
        req = waiting[0]
        if block_pool.can_allocate(blocks_needed):
            waiting.popleft()
            running.append(req)
        else:
            break  # not enough KV blocks for the next request — stop admitting
```

Preemption: Graceful Handling of Memory Pressure

Continuous batching introduces another critical mechanism: preemption.
When there are not enough free KV Cache blocks to allocate for all running requests, the scheduler can:\nPause one or more running requests (typically the most recently admitted — LIFO strategy, since they have done the least work) Free the paused request\u0026rsquo;s KV Cache blocks Use the reclaimed space to continue serving other requests Resume the paused request later (re-prefill to rebuild its KV Cache) flowchart TD A[\u0026#34;Normal operation\u0026lt;br/\u0026gt;Running: Req 0, 1, 2, 3\u0026lt;br/\u0026gt;KV blocks: 200/256 (78%)\u0026#34;] B[\u0026#34;Req 1 generates long response, KV keeps growing...\u0026lt;br/\u0026gt;KV blocks: 253/256 (99%) — about to OOM!\u0026#34;] C[\u0026#34;Preemption triggered\u0026lt;br/\u0026gt;Pause Req 3 (LIFO: newest, least work done)\u0026lt;br/\u0026gt;Free Req 3\u0026#39;s 20 blocks\u0026#34;] D[\u0026#34;KV blocks: 233/256 (91%)\u0026lt;br/\u0026gt;Room to continue\u0026#34;] E[\u0026#34;Later, Req 0/1 finishes, space freed\u0026lt;br/\u0026gt;→ Resume Req 3, re-prefill to rebuild KV\u0026#34;] F[\u0026#34;Alternative: no preemption → OOM crash → ALL requests fail!\u0026#34;] A --\u0026gt; B --\u0026gt; C --\u0026gt; D --\u0026gt; E B -.-\u0026gt; F style A fill:#d4edda,stroke:#155724 style B fill:#f8d7da,stroke:#721c24 style C fill:#fff3cd,stroke:#856404 style D fill:#d4edda,stroke:#155724 style F fill:#f8d7da,stroke:#721c24 Preemption has a cost: the paused request must re-prefill (rebuild its KV Cache) when resumed, increasing that request\u0026rsquo;s latency. But compared to an OOM crash that kills every request in flight, this is a very reasonable trade-off.\nThe companion code\u0026rsquo;s preemption demo shows that even under extremely constrained memory (32 blocks), preemption ensures all requests eventually complete:\nConfig: 20 requests, 32 KV blocks, max_batch=4 Preemptions triggered: 12 Preemption allows the system to handle load spikes without OOM crashes. 
All 20 requests completed: True Chunked Prefill Long Prompts Blocking Decode The prefill phase (processing the user\u0026rsquo;s input prompt) and the decode phase (generating tokens one at a time) have fundamentally different compute characteristics:\nPrefill: Processes the entire prompt at once (potentially thousands of tokens) — a compute-bound large GEMM operation Decode: Generates 1 token per step — a memory-bound GEMV operation The problem: if a request\u0026rsquo;s prompt is 8000 tokens, prefill might take several hundred milliseconds. During that time, every decode request is blocked — they are all waiting for the GPU to finish this one long prefill.\nflowchart TD subgraph NOCHUNK[\u0026#34;Without Chunked Prefill\u0026#34;] GPU1[\u0026#34;GPU: ====== Prefill 8K tokens (200ms) ====== → D D D D\u0026#34;] RA1[\u0026#34;Req A: waiting 200ms... then decode\u0026#34;] RB1[\u0026#34;Req B: waiting 200ms... then decode\u0026#34;] end PROBLEM[\u0026#34;Problem: Req A, B TPOT spikes from 10ms to 200ms+\u0026lt;br/\u0026gt;User experience: streaming text suddenly freezes for 200ms\u0026#34;] style NOCHUNK fill:#fff3cd,stroke:#856404 style PROBLEM fill:#f8d7da,stroke:#721c24 For a user in the middle of a streaming conversation, TPOT jumping from 10ms to 200ms produces a very noticeable stutter — the text was flowing smoothly, then it just\u0026hellip; stops.\nSplitting Prefill into Chunks Chunked Prefill takes a straightforward approach: break the long prompt\u0026rsquo;s prefill into multiple smaller chunks (e.g., 512 tokens each), and interleave decode steps between chunks.\nflowchart TD subgraph CHUNK[\u0026#34;Chunked Prefill — 8192 tokens, chunk_size=512, 16 chunks\u0026#34;] GPU2[\u0026#34;GPU: P1 → D → P2 → D → P3 → D → P4 → D → P5 ...\u0026#34;] LEGEND[\u0026#34;P = prefill chunk (512 tokens, ~12ms) · D = decode step (~3ms)\u0026#34;] end subgraph DECODE[\u0026#34;From decode requests\u0026#39; perspective\u0026#34;] RA2[\u0026#34;Req A: D — D — D — D — D 
(every ~15ms)\u0026#34;] RB2[\u0026#34;Req B: D — D — D — D — D (every ~15ms)\u0026#34;] end RESULT2[\u0026#34;TPOT: ~15ms (stable!) vs 200ms spike without chunking\u0026#34;] CHUNK --\u0026gt; DECODE DECODE --\u0026gt; RESULT2 style CHUNK fill:#d4edda,stroke:#155724 style RESULT2 fill:#cce5ff,stroke:#004085 The trade-off:\nTotal prefill time increases: Because decode steps are interleaved between chunks, an 8K-token prefill goes from 200ms to roughly 240ms (TTFT goes up slightly) Decode latency becomes stable: TPOT drops from a 200ms spike to a consistent ~15ms This is a TTFT vs TPOT trade-off:\nMetric Without Chunked Prefill With Chunked Prefill TTFT (first token) 200ms ~240ms (+20%) TPOT (worst case) 200ms (blocked) ~15ms (stable) TPOT (P99) Very poor Well-controlled User experience Occasional severe stutter Smooth and consistent For most online serving scenarios, stable TPOT matters far more than slightly lower TTFT — users care more about \u0026ldquo;smooth streaming\u0026rdquo; than \u0026ldquo;seeing the first character 40ms sooner.\u0026rdquo;\nChoosing the Chunk Size Chunk size requires balancing competing concerns:\nToo small (e.g., 64 tokens): Prefill GPU utilization suffers (small matrix multiplies are inefficient), and TTFT increases significantly Too large (e.g., 4096 tokens): Approaches the unchunked case, and decode requests still experience long blocks Typical range: 512-1024 tokens, balancing prefill efficiency with decode stability Both SGLang and vLLM support chunked prefill and can dynamically adjust chunk size based on the current running batch — if no decode requests are active, there is no need to chunk at all, so the system does a full prefill to minimize TTFT.\nSystem Architecture Overview Putting all the pieces together, here is the overall architecture of a modern LLM inference engine like SGLang:\nflowchart TD HTTP[\u0026#34;HTTP Server\u0026lt;br/\u0026gt;Receives user requests\u0026#34;] subgraph SCHED[\u0026#34;Scheduler (the 
brain)\u0026#34;] WQ[\u0026#34;Waiting Queue\u0026#34;] RT[\u0026#34;Radix Tree\u0026lt;br/\u0026gt;(Prefix Cache Index)\u0026#34;] STEPS[\u0026#34;Every iteration:\u0026lt;br/\u0026gt;1. Evict completed requests\u0026lt;br/\u0026gt;2. Match new request prefixes\u0026lt;br/\u0026gt;3. Cache-aware ranking + admit\u0026lt;br/\u0026gt;4. Preempt if memory insufficient\u0026lt;br/\u0026gt;5. Chunked prefill scheduling\u0026#34;] end subgraph BM[\u0026#34;Block Manager (the memory)\u0026#34;] POOL2[\u0026#34;Physical Block Pool: B0 B1 B2 B3 B4 B5 ...\u0026#34;] BMOPS[\u0026#34;On-demand alloc/free · CoW · LRU eviction\u0026#34;] end subgraph ME[\u0026#34;Model Executor (the compute)\u0026#34;] ATTN[\u0026#34;Attention kernel (FlashAttention / FlashInfer)\u0026#34;] KV[\u0026#34;Reads KV from Block Table (non-contiguous phys.)\u0026#34;] PAR[\u0026#34;TP / PP parallel inference\u0026#34;] end HTTP --\u0026gt; SCHED --\u0026gt; BM --\u0026gt; ME style SCHED fill:#cce5ff,stroke:#004085 style BM fill:#fff3cd,stroke:#856404 style ME fill:#d4edda,stroke:#155724 The Scheduler is the brain — it decides which requests to run at each step, whether to prefill or decode, and whether preemption is needed. The Block Manager is the memory manager — it allocates, frees, and shares KV Cache blocks. The Model Executor is the compute engine — it takes the scheduler\u0026rsquo;s decisions and the block table mappings and executes the actual attention computation.\nThese three components work in concert to form an efficient inference pipeline.\nKey Takeaways KV Cache is the memory bottleneck of inference. Each token requires $2 \\times n_{\\text{layers}} \\times n_{\\text{heads}} \\times d_{\\text{head}} \\times \\text{dtype_size}$ bytes of KV data. For LLaMA-70B, that is ~2.5 MB per token. With 100 concurrent requests at 2048 tokens each, the KV Cache alone needs 500 GB — far beyond any single GPU. 
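As a sanity check, the per-token arithmetic above can be reproduced in a few lines (a sketch assuming standard multi-head attention in FP16; models using GQA store proportionally fewer KV heads per token):

```python
def kv_bytes_per_token(n_layers: int, n_heads: int, d_head: int, dtype_size: int = 2) -> int:
    """KV Cache bytes per token: a K and a V vector for every layer and head."""
    return 2 * n_layers * n_heads * d_head * dtype_size

# LLaMA-70B-like shape (80 layers, 64 heads, head dim 128, FP16) -- illustrative
per_token = kv_bytes_per_token(80, 64, 128)
print(per_token / 2**20)               # 2.5 MiB per token

# 100 concurrent requests x 2048 tokens each
print(per_token * 100 * 2048 / 2**30)  # 500.0 GiB total
```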
How many requests you can batch and how long your sequences can be depends entirely on how efficiently you manage the KV Cache.\nPagedAttention applies OS virtual memory to eliminate memory fragmentation. Fixed-size blocks are allocated on demand, a block table provides logical-to-physical mapping, and CoW supports efficient sequence forking. Memory waste drops from 60-80% to under 5%, allowing the same GPU to batch 2-4x more requests.\nRadixAttention enables cross-request KV Cache reuse. A Radix Tree indexes all cached token prefixes; new requests automatically match their longest cached prefix and reuse the corresponding KV. Combined with LRU eviction and cache-aware scheduling, this dramatically reduces TTFT in shared-prefix workloads (system prompts, few-shot examples, multi-turn conversations).\nContinuous Batching drops scheduling granularity from batch level to iteration level. Finished requests are evicted and waiting requests are admitted at every decode step, pushing GPU utilization from 20-50% to 80-95%. Preemption provides a safety valve against OOM crashes under memory pressure.\nChunked Prefill prevents long prompts from blocking decode. Splitting prefill into smaller chunks with interleaved decode steps trades a modest TTFT increase for dramatically more stable TPOT. Chunk size (typically 512-1024 tokens) balances prefill efficiency against decode smoothness.\nThese techniques are not alternatives — they stack. Modern inference engines (SGLang, vLLM) simultaneously employ PagedAttention + Prefix Caching + Continuous Batching + Chunked Prefill. Each technique addresses a different dimension of efficiency, and together they deliver 2-10x throughput over naive implementations.\nCompanion Code The companion code for this article is in code/03-inference-sglang/:\nkv_cache_from_scratch.py — KV Cache and PagedAttention Block Manager built from scratch. 
Contains four parts: (1) demonstrating naive KV Cache waste; (2) full PagedAttention implementation with block allocation, deallocation, and CoW; (3) a mini decode loop using the block-managed KV cache; (4) head-to-head memory efficiency comparison of naive vs paged allocation. Run with: python kv_cache_from_scratch.py\ncontinuous_batching.py — Static batching vs continuous batching comparison simulation. Contains four parts: (1) throughput comparison between static and continuous batching; (2) ASCII scheduling timeline visualization; (3) preemption behavior under memory pressure; (4) scaling analysis across different batch sizes. Run with: python continuous_batching.py\nAll code runs in CPU mode (no GPU required), making it easy to study the core logic in any environment.\nReferences Efficient Memory Management for Large Language Model Serving with PagedAttention — Kwon et al., SOSP'23 (vLLM) — The original PagedAttention paper, bringing OS virtual memory concepts to KV Cache management. SGLang: Efficient Execution of Structured Language Model Programs — Zheng et al., 2024 (LMSYS) — The original paper introducing RadixAttention and cache-aware scheduling. Orca: A Distributed Serving System for Transformer-Based Generative Models — Yu et al., OSDI'22 — The pioneering work on continuous batching (iteration-level scheduling). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — Dao et al., NeurIPS'22 — The core attention kernel optimization used in inference engines. vLLM Documentation — docs.vllm.ai — Official vLLM documentation. SGLang Documentation — sgl-project.github.io — Official SGLang documentation. Sarathi-Serve: Chunked-Prefills for Efficient LLM Serving — Agrawal et al., 2024 — Systematic analysis of chunked prefill strategies. 
","permalink":"https://mzf666.github.io/llm-infra/en/posts/03-inference-sglang/","summary":"A deep dive into PagedAttention and RadixAttention — understanding the core design of modern LLM inference engines.","title":"LLM Inference System Architecture (SGLang as Case Study)"},{"content":"Motivation If you have been following this series, you now understand GPU memory hierarchies, collective communication primitives, distributed parallelism strategies, and inference serving architectures. RLHF (Reinforcement Learning from Human Feedback) brings all of these concerns together in a single training pipeline — and adds several new ones.\nThe common misconception is that RLHF is \u0026ldquo;just add a reward model to your training loop.\u0026rdquo; In practice, RLHF requires running four large neural networks simultaneously, orchestrating a complex data flow between them, managing a generation phase that is fundamentally an inference problem embedded inside a training loop, and keeping weight copies synchronized across GPU groups. The algorithmic challenge of PPO is real, but the dominant engineering difficulty is systems design.\nA 7B-parameter RLHF setup consumes roughly 112 GB of GPU memory for weights and optimizer states alone — nearly 8x what supervised fine-tuning (SFT) requires for the same model. At 70B parameters, you need a minimum of 18 A100-80GB GPUs just to hold the weights, and realistically 64-128 GPUs for reasonable training throughput. This is not an algorithm you can prototype on a single GPU; it is a distributed systems problem from day one.\nThis article takes a systems perspective on RLHF. We start with the four-model architecture, walk through the PPO data flow step by step, quantify the memory and compute costs, and then examine how production frameworks — particularly verl (Volcano Engine RL) — solve these challenges through hybrid parallelism strategies. 
By the end, you will understand not just what RLHF does, but why it is hard to run efficiently at scale.\nPrerequisites GPU Memory Model and Distributed Communication Fundamentals (Article 1) A Panorama of Distributed Parallelism Strategies (Article 2) LLM Inference System Architecture (Article 3) Basic reinforcement learning concepts (Policy, Reward, PPO) The Four Models of RLHF RLHF is unique among training methods in that it requires four distinct models to be loaded and executed during every training iteration. Understanding their roles, update rules, and memory footprints is the first step to understanding why RLHF is a systems problem.\nActor (Policy Model) The Actor is the LLM you are actually training. It generates responses to prompts and is updated via PPO to maximize reward while staying close to its original behavior. In production, this is your LLaMA, Qwen, or Mistral checkpoint.\nThe Actor participates in two fundamentally different phases:\nGeneration phase: autoregressive sampling — this is an inference workload (memory-bound, benefits from KV cache and tensor parallelism). Training phase: forward pass to compute new log probabilities, followed by a backward pass to update weights via PPO — this is a standard training workload (compute-bound, benefits from FSDP). This dual nature is one of the core reasons RLHF is a systems challenge: the Actor needs different parallelism strategies for different phases.\nReference Model The Reference Model is a frozen copy of the Actor\u0026rsquo;s initial weights, taken before any RLHF training begins. It is never updated, but it must run a full forward pass on every batch to produce per-token log probabilities.\nIts purpose is to compute a KL divergence penalty that prevents the Actor from drifting too far from its original behavior. 
Without this constraint, the Actor quickly learns to \u0026ldquo;hack\u0026rdquo; the Reward Model — generating degenerate outputs that achieve high reward scores but are incoherent or repetitive. The KL penalty acts as an anchor, keeping the Actor\u0026rsquo;s distribution close to a known-good baseline.\nDespite being frozen, the Reference Model still consumes the same memory as a full model copy. It requires no optimizer states, but its parameters must be resident in GPU memory for forward passes.\nReward Model The Reward Model takes a (prompt, response) pair and outputs a scalar reward score. It is trained separately on human preference data — given pairs of responses where a human annotator has indicated which one is better, the model learns to assign higher scores to preferred outputs.\nArchitecturally, the Reward Model typically shares the same Transformer backbone as the base LLM, but replaces the language modeling head with a value head — a linear projection from the final hidden state to a single scalar:\n1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class RewardModel(nn.Module): def __init__(self, vocab_size, d_model, ...): # Same Transformer backbone as the Actor self.token_embed = nn.Embedding(vocab_size, d_model) self.pos_embed = nn.Embedding(max_seq_len, d_model) self.blocks = nn.ModuleList([...]) self.ln_f = nn.LayerNorm(d_model) # Value head replaces the language modeling head self.value_head = nn.Linear(d_model, 1, bias=False) def forward(self, input_ids): # ... Transformer forward pass ... last_hidden = x[:, -1, :] # Sequence representation reward = self.value_head(last_hidden).squeeze(-1) # Scalar per sequence return reward The Reward Model is frozen during RLHF training — it only performs inference.\nCritic (Value Model) The Critic estimates $V(s)$ — the expected future reward at each token position in the response. 
This is essential for computing advantages via GAE (Generalized Advantage Estimation), which tell the PPO algorithm \u0026ldquo;how much better (or worse) was this action compared to what we expected?\u0026rdquo;\nUnlike the Reward Model, which produces one scalar per sequence, the Critic produces a per-token value estimate:\n1 2 3 4 5 6 7 8 9 class CriticModel(nn.Module): def __init__(self, vocab_size, d_model, ...): # Same backbone, but per-token output self.value_head = nn.Linear(d_model, 1, bias=False) def forward(self, input_ids): # ... Transformer forward pass ... values = self.value_head(x).squeeze(-1) # (batch, seq_len) return values The Critic is trained alongside the Actor during RLHF — it has its own optimizer and receives gradient updates. It is often initialized from the Reward Model\u0026rsquo;s weights, since both models estimate a form of \u0026ldquo;how good is this partial response.\u0026rdquo;\nMemory Breakdown The memory cost of running all four models simultaneously is what makes RLHF so resource-intensive. Here is the breakdown for FP16 weights:\nModel Weights Optimizer States Total ───────────────────────────────────────────────────────── Actor 2 * P 2 * 2 * P (Adam) 6P Critic 2 * P 2 * 2 * P (Adam) 6P Reference 2 * P 0 (frozen) 2P Reward Model 2 * P 0 (frozen) 2P ───────────────────────────────────────────────────────── TOTAL 8 * P 8 * P 16P P = parameter count; FP16 stores 2 bytes per value, so each model copy of weights occupies 2P bytes For a 7B model:\nComponent Memory Per model (FP16 weights) 14 GB 4 models (weights only) 56 GB + Adam states for Actor + Critic (28 GB each) 112 GB vs SFT (1 model + optimizer) ~28 GB RLHF requires roughly 4x the memory of SFT at the same model scale. 
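The table's accounting can be checked with a tiny helper (a sketch that counts only FP16 weights and FP16 Adam moments; real deployments add activations, gradients, KV cache, and communication buffers on top):

```python
def rlhf_memory_gb(n_params: float, dtype_size: int = 2) -> float:
    """Total bytes (in GB) for the four RLHF models under FP16 accounting."""
    weights = dtype_size * n_params            # 2P bytes per model copy
    adam = 2 * dtype_size * n_params           # two moment buffers: 4P bytes
    trained = weights + adam                   # Actor or Critic: 6P
    frozen = weights                           # Reference or Reward Model: 2P
    return (2 * trained + 2 * frozen) / 1e9    # 16P bytes in total

print(rlhf_memory_gb(7e9))    # 112.0 GB for a 7B model
print(rlhf_memory_gb(70e9))   # 1120.0 GB for a 70B model
```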
At 70B, you need over 1,100 GB just for weights and optimizer states — a minimum of 14 A100-80GB GPUs, and realistically 64+ for reasonable throughput with activation memory and communication buffers.\nPPO Data Flow Every PPO iteration in RLHF consists of two distinct phases: a rollout phase (experience collection) and a training phase (policy update). Understanding this data flow is critical for reasoning about where system bottlenecks arise.\nPhase 1: Rollout (Experience Collection) The rollout phase collects \u0026ldquo;experience\u0026rdquo; — the data that the PPO algorithm will learn from. It involves all four models and proceeds in a strict sequential order:\nflowchart TD subgraph Rollout[\u0026#34;Rollout Phase (all inference, no gradients)\u0026#34;] S1[\u0026#34;**Step 1: GENERATION**\u0026lt;br/\u0026gt;Prompts → Actor (autoregressive) → Responses\u0026#34;] S1 --\u0026gt; S2 \u0026amp; S3a \u0026amp; S3b \u0026amp; S4 S2[\u0026#34;**Step 2: REWARD SCORING**\u0026lt;br/\u0026gt;(Prompt+Response) → Reward Model → Scalar reward\u0026#34;] S3a[\u0026#34;**Step 3a: ACTOR LOG PROBS**\u0026lt;br/\u0026gt;(Prompt+Response) → Actor → log probs π_θ\u0026#34;] S3b[\u0026#34;**Step 3b: REFERENCE LOG PROBS**\u0026lt;br/\u0026gt;(Prompt+Response) → Reference → log probs π_ref\u0026#34;] S4[\u0026#34;**Step 4: VALUE ESTIMATION**\u0026lt;br/\u0026gt;(Prompt+Response) → Critic → per-token V(s)\u0026#34;] S2 --\u0026gt; S5 S3a --\u0026gt; S5 S3b --\u0026gt; S5 S4 --\u0026gt; S5 S5[\u0026#34;**Step 5: ADVANTAGE COMPUTATION**\u0026lt;br/\u0026gt;(Rewards, KL penalties, Values) → GAE → Advantages\u0026#34;] end style S1 fill:#cce5ff,stroke:#007bff style S2 fill:#fff3cd,stroke:#ffc107 style S3a fill:#fff3cd,stroke:#ffc107 style S3b fill:#fff3cd,stroke:#ffc107 style S4 fill:#fff3cd,stroke:#ffc107 style S5 fill:#d4edda,stroke:#28a745 Notice that every step in the rollout phase is an inference workload. No gradients are computed; all four models run in torch.no_grad() mode. 
The generation step (Step 1) is particularly expensive because it is autoregressive — each token requires a full forward pass through the Actor.\nThe companion code demonstrates this pipeline:\n1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 # Step 1: Generate responses (inference — no gradients) with torch.no_grad(): full_ids = generate_responses(actor, prompt_ids, response_len, temperature) # Step 2-3: Score with Reward Model + compute KL with torch.no_grad(): rewards = reward_model(full_ids) old_log_probs = actor.get_log_probs(full_ids)[:, prompt_len - 1:] ref_log_probs = reference.get_log_probs(full_ids)[:, prompt_len - 1:] kl_penalties = kl_coeff * (old_log_probs - ref_log_probs) # Step 4-5: Compute values and GAE advantages with torch.no_grad(): values = critic(full_ids)[:, prompt_len - 1:-1] advantages, returns = compute_advantages_gae(rewards, kl_penalties, values) The KL Penalty The KL divergence between the Actor and Reference distributions is the critical safety mechanism of RLHF. Without it, the Actor degenerates within a few hundred iterations. The per-token KL penalty is computed as:\n$$D_{KL}(\\pi_\\theta | \\pi_\\text{ref}) \\approx \\log \\pi_\\theta(a_t | s_t) - \\log \\pi_\\text{ref}(a_t | s_t)$$\nThis is scaled by a coefficient $\\beta$ (typically 0.01 - 0.2) and subtracted from the reward:\n$$r_t^{\\text{adjusted}} = r_t - \\beta \\cdot D_{KL}$$\nThe coefficient $\\beta$ is often adaptively tuned during training: if the KL divergence exceeds a target threshold, $\\beta$ increases to pull the Actor back; if KL is below the target, $\\beta$ decreases to allow more exploration.\nGeneralized Advantage Estimation (GAE) GAE computes per-token advantages that tell PPO how to update the policy. 
The advantage $\\hat{A}_t$ at token position $t$ answers: \u0026ldquo;how much better was the action taken here compared to the Critic\u0026rsquo;s expectation?\u0026rdquo;\n$$\\hat{A}_t = \\sum_{l=0}^{T-t} (\\gamma \\lambda)^l \\delta_{t+l}$$\nwhere $\\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t)$ is the TD error.\nThe $\\lambda$ parameter controls a bias-variance tradeoff:\n$\\lambda = 1$: high variance, low bias (equivalent to Monte Carlo returns) $\\lambda = 0$: low variance, high bias (one-step TD) $\\lambda = 0.95$: the standard choice, a good balance In practice, only the last token of the response receives the Reward Model\u0026rsquo;s scalar reward. All other tokens receive only the KL penalty. GAE then propagates this terminal reward backward through the sequence:\n1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 # Per-token rewards: KL penalty at every position, reward at the last token per_token_rewards = -kl_penalties.clone() per_token_rewards[:, -1] += rewards # Sequence reward at final token # GAE: propagate rewards backward through time advantages = torch.zeros_like(values) last_gae = torch.zeros(B, device=device) for t in reversed(range(T)): next_value = values[:, t + 1] if t \u0026lt; T - 1 else torch.zeros(B, device=device) delta = per_token_rewards[:, t] + gamma * next_value - values[:, t] last_gae = delta + gamma * lam * last_gae advantages[:, t] = last_gae returns = advantages + values # Training target for the Critic Phase 2: PPO Update (Training) With advantages computed, the training phase updates the Actor and Critic. 
PPO\u0026rsquo;s key innovation is that you can reuse the same rollout data for multiple gradient updates, as long as you clip the policy ratio to prevent the Actor from changing too drastically in one step.\nActor Loss (Clipped Surrogate Objective):\n$$L^{CLIP}(\\theta) = \\mathbb{E}\\left[\\min\\left(r_t(\\theta)\\hat{A}_t, ; \\text{clip}(r_t(\\theta), 1-\\epsilon, 1+\\epsilon)\\hat{A}_t\\right)\\right]$$\nwhere $r_t(\\theta) = \\frac{\\pi_\\theta(a_t | s_t)}{\\pi_{\\theta_\\text{old}}(a_t | s_t)} = \\exp(\\log \\pi_\\theta - \\log \\pi_{\\theta_\\text{old}})$ is the probability ratio.\nThe clipping mechanism is elegant in its simplicity:\nWhen $\\hat{A}_t \u0026gt; 0$ (good action): we want to increase $r_t$, but clipping caps it at $1 + \\epsilon$, preventing overcommitment. When $\\hat{A}_t \u0026lt; 0$ (bad action): we want to decrease $r_t$, but clipping caps it at $1 - \\epsilon$, preventing overcorrection. 1 2 3 4 5 6 7 8 9 10 11 12 def ppo_actor_loss(actor, full_ids, prompt_len, old_log_probs, advantages, clip_eps): new_log_probs = actor.get_log_probs(full_ids)[:, prompt_len - 1:] ratio = torch.exp(new_log_probs - old_log_probs.detach()) # Normalize advantages for stability adv = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Clipped surrogate loss surr1 = ratio * adv surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv loss = -torch.min(surr1, surr2).mean() return loss, ratio Critic Loss: a straightforward MSE between the Critic\u0026rsquo;s value predictions and the GAE returns (advantages + old values):\n$$L_\\text{critic} = \\frac{1}{T}\\sum_t \\left(V_\\phi(s_t) - \\hat{R}_t\\right)^2$$\nTraining Loop: for each PPO iteration, we typically run 2-4 gradient update epochs on the same rollout data:\n1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 for epoch in range(ppo_epochs): # Actor update actor_optim.zero_grad() a_loss, ratio = ppo_actor_loss(actor, full_ids, prompt_len, old_log_probs, advantages, clip_eps) a_loss.backward() 
torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=1.0) actor_optim.step() # Critic update critic_optim.zero_grad() c_loss = ppo_critic_loss(critic, full_ids, prompt_len, returns) c_loss.backward() torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0) critic_optim.step() Only the Actor and Critic receive gradient updates. The Reference and Reward Model remain frozen throughout.\nWhy RLHF Is a Systems Problem With the data flow clear, let us now enumerate the specific systems challenges that distinguish RLHF from standard SFT training.\nChallenge 1: 4x Memory Footprint As quantified earlier, RLHF needs roughly 4x the GPU memory of SFT at the same model scale. This is not just a matter of buying more GPUs — it fundamentally changes the parallelism strategy. For a 7B model:\nSFT: fits on a single A100-80GB with basic mixed precision. No sharding required. RLHF: requires at minimum 4 A100-80GB GPUs with FSDP sharding across all four models. At 70B, the gap widens further. SFT can work with 4-8 GPUs using FSDP; RLHF needs 64+ GPUs with careful placement and scheduling.\nChallenge 2: Compute Heterogeneity The rollout phase and training phase have fundamentally different compute characteristics:\nPhase Nature Bottleneck Best Parallelism ───────────────────────────────────────────────────────────────────────── Generation Inference Memory bandwidth Tensor Parallelism (autoregressive) (sequential) (KV cache I/O) (low latency) Reward scoring Inference Memory bandwidth Tensor Parallelism Reference log probs Inference Memory bandwidth Tensor Parallelism Actor training Training Compute (matmul) FSDP / DDP Critic training Training Compute (matmul) FSDP / DDP A naive approach — using the same parallelism strategy for both phases — leaves significant performance on the table. Generation with FSDP requires an all-gather of full weights before every layer, which is wasteful for the small batch sizes typical of autoregressive decoding. 
Training with TP wastes communication bandwidth on all-reduce operations that are unnecessary when you can shard optimizer states instead.\nThe optimal strategy is to switch parallelism modes between phases: use TP for generation (low latency per token) and FSDP for training (memory-efficient parameter sharding). This is exactly what verl\u0026rsquo;s hybrid engine does, and it is the key architectural insight of modern RLHF systems.\nChallenge 3: Complex Data Dependencies The data flow through four models creates a strict dependency chain:\nflowchart LR Gen[\u0026#34;Generation\u0026#34;] --\u0026gt; Score[\u0026#34;Reward Scoring\u0026#34;] Gen --\u0026gt; |\u0026#34;needs Actor weights\u0026#34;| Gen Score --\u0026gt; KL[\u0026#34;KL Computation\u0026#34;] KL --\u0026gt; |\u0026#34;needs Ref weights\u0026#34;| KL KL --\u0026gt; GAE GAE --\u0026gt; |\u0026#34;needs Critic values\u0026#34;| GAE GAE --\u0026gt; PPO[\u0026#34;PPO Update\u0026#34;] style Gen fill:#cce5ff,stroke:#007bff style Score fill:#fff3cd,stroke:#ffc107 style KL fill:#fff3cd,stroke:#ffc107 style GAE fill:#d4edda,stroke:#28a745 style PPO fill:#d4edda,stroke:#28a745 You cannot score responses before they are generated. You cannot compute advantages before you have rewards, KL penalties, and Critic values. And you cannot start the next PPO iteration until the Actor weights have been updated and synchronized to the generation engine.\nIn a distributed setting, this means some GPU groups are idle while others work. If the Reward Model runs on separate GPUs from the Actor, those GPUs sit idle during generation and training. If the Actor generates on the same GPUs it trains on, you need to manage memory carefully — generation may require gathering full model weights, which temporarily doubles the memory requirement.\nChallenge 4: Weight Synchronization After each PPO update to the Actor, the generation engine must use the new weights for the next rollout. 
In a colocated setup (all models on the same GPUs), this happens automatically. But in a separated setup, or when using a dedicated inference engine like vLLM for generation, the updated weights must be explicitly transferred.\nFor a 7B model in FP16, this is a 14 GB transfer — trivial over NVLink but costly over PCIe or network interconnects. For 70B models, it is 140 GB, which at 25 GB/s (PCIe Gen4 x16) takes over 5 seconds. Systems must overlap this transfer with other computation or use architectural choices (colocation) to avoid it entirely.\nCompute Cost Comparison Putting it all together, here is how the compute cost of one RLHF iteration compares to one SFT step:\nOperation SFT RLHF Ratio ────────────────────────────────────────────────────────────── Forward pass (Actor) 1x 2x 2x Backward pass (Actor) 1x 1x 1x Autoregressive generation 0 N tokens Nx Forward pass (Reference) 0 1x +1x Forward pass (Reward Model) 0 1x +1x Forward pass (Critic) 0 2x +2x Backward pass (Critic) 0 1x +1x GAE computation 0 1x +1x ────────────────────────────────────────────────────────────── TOTAL (approximate) ~2x ~10-16x 5-8x The autoregressive generation step dominates. Without KV caching, generating $T$ response tokens requires $T$ sequential forward passes through the Actor. Even with KV caching, each step is memory-bandwidth-bound with poor compute utilization. This single phase can account for 50-70% of total RLHF iteration time.\nThe verl Architecture verl (Volcano Engine RL) is an open-source RLHF framework that addresses the systems challenges above through a hybrid engine design. Its core insight is that you can colocate all four models on the same GPU group and dynamically switch between parallelism strategies depending on the current phase.\nDesign Principles verl\u0026rsquo;s architecture rests on three key ideas:\nColocation over separation: all four models share the same GPUs, eliminating cross-group data transfer. 
Hybrid parallelism: FSDP for training phases, TP for generation phases, with seamless switching between modes. Worker-based scheduling: a central controller dispatches work to model-specific workers, managing the phase transitions. flowchart TD subgraph GPU[\u0026#34;GPU Group (N GPUs) — verl Colocated Mode\u0026#34;] subgraph Models[\u0026#34;All 4 Models (FSDP sharded)\u0026#34;] direction LR Actor[\u0026#34;Actor\u0026lt;br/\u0026gt;(FSDP sharded)\u0026#34;] Critic[\u0026#34;Critic\u0026lt;br/\u0026gt;(FSDP sharded)\u0026#34;] Reward[\u0026#34;Reward\u0026lt;br/\u0026gt;(FSDP sharded)\u0026#34;] Reference[\u0026#34;Reference\u0026lt;br/\u0026gt;(FSDP sharded)\u0026#34;] end subgraph Phases[\u0026#34;Phase Transitions\u0026#34;] Train[\u0026#34;**Training mode:**\u0026lt;br/\u0026gt;FSDP (all-gather params,\u0026lt;br/\u0026gt;reduce-scatter grads)\u0026#34;] Gen[\u0026#34;**Generation mode:**\u0026lt;br/\u0026gt;Gather full weights →\u0026lt;br/\u0026gt;Tensor Parallelism\u0026#34;] end Models --\u0026gt; Phases end style Actor fill:#d4edda,stroke:#28a745 style Critic fill:#d4edda,stroke:#28a745 style Reward fill:#fff3cd,stroke:#ffc107 style Reference fill:#fff3cd,stroke:#ffc107 style Train fill:#d4edda,stroke:#28a745 style Gen fill:#cce5ff,stroke:#007bff The Hybrid Engine: FSDP-TP Mode Switching The most important architectural decision in verl is how it switches between FSDP and TP modes for the Actor model. The challenge: FSDP shards parameters across GPUs by flattening and partitioning them, while TP shards parameters by slicing weight matrices along specific dimensions (column-parallel for the first linear layer, row-parallel for the second).\nverl handles this by:\nFSDP mode (training): parameters are sharded using PyTorch\u0026rsquo;s FSDP (ZeRO Stage 3). Each GPU holds $1/N$ of the flattened parameter tensor. 
During forward/backward passes, parameters are all-gathered on demand and gradients are reduce-scattered after the backward pass.\nTP mode (generation): FSDP shards are gathered to reconstruct full parameters, then resharded along TP dimensions. For a weight matrix $W \\in \\mathbb{R}^{d \\times 4d}$ in an FFN layer, GPU $i$ holds columns $W[:, i \\cdot 4d/N : (i+1) \\cdot 4d/N]$. Generation proceeds with standard TP communication (all-reduce after each layer).\nTransition cost: the FSDP-to-TP reshape requires an all-gather of the full parameters followed by a local slice. For a 7B model, this is a 14 GB all-gather — on the order of tens of milliseconds over NVLink (14 GB at ~600 GB/s is roughly 25 ms), which is negligible compared to the generation time.\nflowchart LR subgraph FSDP[\u0026#34;FSDP Shards\u0026#34;] direction TB S0[\u0026#34;GPU 0: shard_0\u0026#34;] S1[\u0026#34;GPU 1: shard_1\u0026#34;] S2[\u0026#34;GPU 2: shard_2\u0026#34;] S3[\u0026#34;GPU 3: shard_3\u0026#34;] end FSDP --\u0026gt; |\u0026#34;all-gather\u0026#34;| Full[\u0026#34;full_W\u0026#34;] Full --\u0026gt; |\u0026#34;slice\u0026#34;| TP subgraph TP[\u0026#34;TP Slices (for generation)\u0026#34;] direction TB T0[\u0026#34;GPU 0: W_col_0\u0026#34;] T1[\u0026#34;GPU 1: W_col_1\u0026#34;] T2[\u0026#34;GPU 2: W_col_2\u0026#34;] T3[\u0026#34;GPU 3: W_col_3\u0026#34;] end Note[\u0026#34;After generation, discard TP slices\u0026lt;br/\u0026gt;and return to FSDP shards\u0026#34;] TP --\u0026gt; Note style FSDP fill:#d4edda,stroke:#28a745 style Full fill:#fff3cd,stroke:#ffc107 style TP fill:#cce5ff,stroke:#007bff Weight Update Flow The weight update cycle in verl for one PPO iteration looks like this:\nGather Actor weights (FSDP all-gather → TP slice) for generation Generate responses using TP-sharded Actor Discard TP slices, return to FSDP shards Forward passes through Reward Model, Reference, Critic (FSDP mode, no TP needed for non-autoregressive inference) Compute advantages (local computation, no communication) PPO update on Actor and Critic (FSDP mode: 
all-gather for forward, reduce-scatter for backward) Actor weights are now updated in FSDP shards — no explicit sync needed because generation will re-gather them next iteration The key optimization here is that weight synchronization is free in the colocated design. The generation engine and training engine share the same parameter storage, so after the PPO update, the next generation phase simply re-gathers the (now updated) FSDP shards.\nResource Scheduling verl uses a single-controller, multi-worker architecture. The controller orchestrates the PPO iteration by sending commands to workers:\nflowchart TD subgraph Controller[\u0026#34;Controller (single process) — per PPO iteration\u0026#34;] S1[\u0026#34;1. Send \u0026#39;generate\u0026#39; to Actor workers\u0026#34;] S2[\u0026#34;2. Send \u0026#39;score\u0026#39; to Reward workers\u0026#34;] S3[\u0026#34;3. Send \u0026#39;log_probs\u0026#39; to Reference workers\u0026#34;] S4[\u0026#34;4. Send \u0026#39;value\u0026#39; to Critic workers\u0026#34;] S5[\u0026#34;5. Compute advantages (local)\u0026#34;] S6[\u0026#34;6. Send \u0026#39;train\u0026#39; to Actor workers\u0026lt;br/\u0026gt;(multiple PPO epochs)\u0026#34;] S7[\u0026#34;7. Send \u0026#39;train\u0026#39; to Critic workers\u0026lt;br/\u0026gt;(multiple PPO epochs)\u0026#34;] S1 --\u0026gt; S2 --\u0026gt; S3 --\u0026gt; S4 --\u0026gt; S5 --\u0026gt; S6 --\u0026gt; S7 end style S1 fill:#cce5ff,stroke:#007bff style S2 fill:#fff3cd,stroke:#ffc107 style S3 fill:#fff3cd,stroke:#ffc107 style S4 fill:#fff3cd,stroke:#ffc107 style S5 fill:#d4edda,stroke:#28a745 style S6 fill:#d4edda,stroke:#28a745 style S7 fill:#d4edda,stroke:#28a745 In colocated mode, each worker manages multiple model roles on the same GPU. The controller ensures that only one model is active at a time, preventing memory contention. 
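Steps 1-3 of the weight update flow above (gather FSDP shards, slice for TP, discard) can be sketched with NumPy standing in for the collectives. This is a toy model of the data movement only; `fsdp_shards`, `tp_slice`, and the shapes are illustrative names, not verl's API:

```python
import numpy as np

N = 4                      # number of GPUs in the group
d = 8                      # toy hidden size (a real FFN weight is d x 4d)
W = np.arange(d * 4 * d, dtype=np.float32).reshape(d, 4 * d)

# FSDP (ZeRO-3): each rank owns a flat 1/N shard of the parameter.
fsdp_shards = np.split(W.reshape(-1), N)

# Generation phase, executed on every rank:
# 1) all-gather the flat shards and restore the full matrix ...
full_W = np.concatenate(fsdp_shards).reshape(d, 4 * d)

# 2) ... then keep only this rank's TP column slice and free the rest.
def tp_slice(full, rank, n):
    cols = full.shape[1] // n
    return full[:, rank * cols:(rank + 1) * cols]

slices = [tp_slice(full_W, r, N) for r in range(N)]

# Column slices tile the matrix: concatenating them recovers W exactly.
assert np.array_equal(np.concatenate(slices, axis=1), W)
```

Because each rank keeps only its column slice after step 2, per-rank parameter memory returns to 1/N of the full matrix once the temporary `full_W` is freed.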
In separated mode, each worker owns a specific model and the controller handles data routing between groups.\nverl supports micro-batching within each phase to handle cases where the full batch does not fit in GPU memory. For example, if the generation batch size is 512 but each GPU can only hold 64 sequences for generation (due to KV cache memory), verl splits the batch into 8 micro-batches and processes them sequentially.\nDistributed RLHF Strategies At scale, RLHF requires distributed execution across many GPUs. There are two fundamental approaches to placing the four models, each with distinct tradeoffs.\nColocated Strategy In the colocated approach, all four models reside on the same set of GPUs, each sharded via FSDP:\nflowchart TD subgraph Colocated[\u0026#34;Colocated Deployment (8 GPUs)\u0026#34;] Config[\u0026#34;GPU 0-7: Actor (FSDP) + Critic (FSDP) + Reward (FSDP) + Reference (FSDP)\u0026#34;] subgraph Timeline[\u0026#34;Timeline of one PPO iteration\u0026#34;] direction LR P1[\u0026#34;Generation\u0026lt;br/\u0026gt;(Actor, TP)\u0026#34;] P2[\u0026#34;Score\u0026lt;br/\u0026gt;(RM)\u0026#34;] P3[\u0026#34;Ref LP\u0026lt;br/\u0026gt;(Ref)\u0026#34;] P4[\u0026#34;Values\u0026lt;br/\u0026gt;(Crit)\u0026#34;] P5[\u0026#34;Actor PPO\u0026lt;br/\u0026gt;(FSDP)\u0026#34;] P6[\u0026#34;Critic PPO\u0026lt;br/\u0026gt;(FSDP)\u0026#34;] P1 --\u0026gt; P2 --\u0026gt; P3 --\u0026gt; P4 --\u0026gt; P5 --\u0026gt; P6 end Config --\u0026gt; Timeline Note[\u0026#34;All 8 GPUs active in every phase\u0026lt;br/\u0026gt;(different model active each phase)\u0026#34;] end style P1 fill:#cce5ff,stroke:#007bff style P2 fill:#fff3cd,stroke:#ffc107 style P3 fill:#fff3cd,stroke:#ffc107 style P4 fill:#fff3cd,stroke:#ffc107 style P5 fill:#d4edda,stroke:#28a745 style P6 fill:#d4edda,stroke:#28a745 Advantages:\nNo cross-group data transfer — responses, rewards, and log probabilities stay on the same GPUs All GPUs are utilized in every phase Simple scheduling — one sequential 
pipeline Disadvantages:\nPeak memory pressure is high — all four models must have their FSDP shards resident simultaneously Cannot independently tune parallelism per model (e.g., more TP for Actor, less for Critic) If one model is much larger than the others, memory allocation is unbalanced Separated Strategy In the separated approach, each model gets its own dedicated GPU group:\nflowchart TD subgraph Separated[\u0026#34;Separated Deployment (32 GPUs)\u0026#34;] subgraph Groups[\u0026#34;GPU Allocation\u0026#34;] direction LR AG[\u0026#34;GPUs 0-15: Actor\u0026lt;br/\u0026gt;(FSDP + TP hybrid)\u0026#34;] CG[\u0026#34;GPUs 16-23: Critic\u0026lt;br/\u0026gt;(FSDP)\u0026#34;] RG[\u0026#34;GPUs 24-27: Reward\u0026lt;br/\u0026gt;(TP, inference only)\u0026#34;] RefG[\u0026#34;GPUs 28-31: Reference\u0026lt;br/\u0026gt;(TP, inference only)\u0026#34;] end subgraph Timeline[\u0026#34;Timeline of one PPO iteration\u0026#34;] direction TB AT[\u0026#34;Actor GPUs: Generation → idle → PPO Training\u0026#34;] CT[\u0026#34;Critic GPUs: idle → Value compute → PPO Training\u0026#34;] RT[\u0026#34;Reward GPUs: idle → Score → idle\u0026#34;] RefT[\u0026#34;Ref GPUs: idle → Log probs → idle\u0026#34;] end Groups --\u0026gt; Timeline end style AG fill:#d4edda,stroke:#28a745 style CG fill:#d4edda,stroke:#28a745 style RG fill:#fff3cd,stroke:#ffc107 style RefG fill:#fff3cd,stroke:#ffc107 style AT fill:#cce5ff,stroke:#007bff Advantages:\nEach model has full GPU memory available — no sharing pressure Can use different parallelism strategies optimized for each model\u0026rsquo;s workload The Actor can use more GPUs for faster generation Disadvantages:\nSignificant GPU idle time — Reward and Reference GPUs sit idle during generation and training Data must be transferred between GPU groups (responses, rewards, log probabilities) More complex scheduling and synchronization Communication Patterns The communication requirements differ significantly between phases:\nPhase | Communication Pattern | Cost Profile\nGeneration (TP) | All-reduce per layer | Latency-bound\nGeneration (FSDP→TP) | All-gather for weight reshape | One-time cost\nActor training (FSDP) | All-gather + reduce-scatter | Bandwidth-bound\nCritic training (FSDP) | All-gather + reduce-scatter | Bandwidth-bound\nReward scoring | Broadcast prompts+responses | One-time cost\nReference log probs | Broadcast prompts+responses | One-time cost\nWeight sync (separated) | Broadcast/all-gather new params | After each iteration\nIn the colocated strategy, the dominant communication cost is the FSDP all-gather during generation (to reconstruct full weights for TP) and the all-gather + reduce-scatter during training. In the separated strategy, the additional cost of transferring rollout data between GPU groups can be substantial: every phase boundary becomes a cross-group transfer and synchronization point, and for a batch of 512 sequences of length 2048 the per-token FP16 tensors (log probabilities, values) add a few megabytes of traffic per phase, all sitting on the critical path of the iteration.\nHybrid Approaches Modern RLHF systems increasingly use hybrid strategies that combine elements of both approaches:\nverl\u0026rsquo;s default: colocated with FSDP↔TP switching. All models share GPUs but use different parallelism modes per phase. Partial separation: the Actor gets its own GPU group for generation (where it needs maximum memory for KV cache), but shares GPUs with other models during training. Asymmetric allocation: allocate more GPUs to the Actor (which dominates compute time) and fewer to the frozen models (which only perform inference). The right strategy depends on model size, cluster topology, and the relative cost of generation vs. training. For models up to 13B, colocation is usually sufficient on 8-16 GPUs. For 70B+, some form of separation or asymmetric allocation becomes necessary to manage memory.\nKey Takeaways RLHF needs four models simultaneously: Actor (generates responses, trained via PPO), Critic (estimates value, trained alongside Actor), Reward Model (scores responses, frozen), Reference (KL anchor, frozen). 
This is roughly 10x the memory of SFT when including optimizer states for Actor and Critic.\nThe generation phase is an inference problem inside a training loop. Autoregressive sampling needs KV caching, tensor parallelism, and continuous batching for efficiency — but it happens between gradient updates. This dual nature is why RLHF requires a \u0026ldquo;hybrid engine\u0026rdquo; that can switch between inference and training parallelism modes.\nPPO stabilizes training through two mechanisms: the KL penalty prevents reward hacking by keeping the Actor close to the Reference distribution, and the clipping objective prevents catastrophic policy updates by bounding the probability ratio to $[1-\\epsilon, 1+\\epsilon]$.\nThe core systems challenge is data flow orchestration. Each PPO iteration passes data through all four models in a strict dependency chain. In distributed settings, this means careful model placement, data routing between GPU groups, and scheduling to minimize idle time.\nverl\u0026rsquo;s key insight: colocate models and reshape parallelism per phase. Use FSDP (ZeRO-3) for memory-efficient training, but gather weights and switch to TP for fast generation. This avoids the tradeoff between separated (idle GPUs) and naive colocated (memory pressure) strategies.\nAt scale (70B+), RLHF requires 64-128 GPUs with careful parallelism choices. The generation phase dominates wall-clock time (50-70% of each iteration), making inference optimization — KV caching, continuous batching, efficient TP — just as important as training optimization.\nCompanion Code The companion code for this article is located at code/04-rlhf-system/:\nminimal_rlhf.py — A self-contained RLHF training loop implementing all four models, the PPO rollout pipeline, GAE advantage estimation, and the clipped surrogate objective. Uses small model dimensions (d_model=256) so everything runs on CPU, but the architecture and data flow are identical to production RLHF systems. 
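For orientation before running it, the clipped surrogate at the heart of the PPO training demo has this general shape (a generic NumPy sketch, not the script's actual code; `ppo_clip_loss` is an illustrative name):

```python
import numpy as np

def ppo_clip_loss(logp_new, logp_old, advantages, eps=0.2):
    """Clipped PPO surrogate: bounds the ratio to [1 - eps, 1 + eps]."""
    ratio = np.exp(logp_new - logp_old)          # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantages
    # Pessimistic bound: element-wise minimum, negated to form a loss.
    return -np.mean(np.minimum(unclipped, clipped))

# When the new and old policies coincide, the ratio is 1 everywhere,
# so the loss reduces to -mean(advantages).
adv = np.array([1.0, -2.0, 0.5])
logp = np.log(np.array([0.3, 0.2, 0.5]))
assert np.isclose(ppo_clip_loss(logp, logp, adv), -adv.mean())
```

Once the ratio moves outside the clip range in the direction favored by the advantage, the clipped term wins the minimum and the gradient through the ratio vanishes, halting further movement away from the old policy.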
Run it with:\ncd code/04-rlhf-system\npython minimal_rlhf.py\nThe script runs four demonstrations:\nFour-model architecture: instantiates all models and displays parameter counts and memory estimates at production scales. PPO data generation: walks through the rollout phase step by step, showing data shapes and model interactions. PPO training: runs multiple PPO iterations, reporting reward, loss, KL divergence, and clip fraction. System challenges: prints the complete data flow diagram, compute cost comparison, and distributed strategy analysis. References Ouyang, L., et al. \u0026ldquo;Training language models to follow instructions with human feedback.\u0026rdquo; NeurIPS 2022. — The InstructGPT paper that established the RLHF pipeline. Schulman, J., et al. \u0026ldquo;Proximal Policy Optimization Algorithms.\u0026rdquo; arXiv:1707.06347, 2017. — The PPO algorithm. Sheng, G., et al. \u0026ldquo;HybridFlow: A Flexible and Efficient RLHF Framework.\u0026rdquo; arXiv:2409.19256, 2024. — The verl system paper. Zheng, L., et al. \u0026ldquo;SGLang: Efficient Execution of Structured Language Model Programs.\u0026rdquo; arXiv:2312.07104, 2023. — Inference engine used by verl for generation. Schulman, J., et al. \u0026ldquo;High-Dimensional Continuous Control Using Generalized Advantage Estimation.\u0026rdquo; ICLR 2016. — The GAE algorithm. Rajbhandari, S., et al. \u0026ldquo;ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.\u0026rdquo; SC 2020. — The FSDP/ZeRO foundation used by verl. Stiennon, N., et al. \u0026ldquo;Learning to summarize from human feedback.\u0026rdquo; NeurIPS 2020. — Early RLHF application to summarization. ","permalink":"https://mzf666.github.io/llm-infra/en/posts/04-rlhf-system/","summary":"From the four-model RLHF architecture to verl\u0026rsquo;s system design — understanding why RLHF is fundamentally a systems problem.","title":"Introduction to RLHF System Design"}]