verl · 训练机制笔记

batch size 是怎么被切开的

从「按条等分」到「按 token 动态打包」——一次讲清 verl 训练侧两条 micro-batch 切分路径,以及为什么 dynamic 模式更容易踩到 NCCL hang。

verl-project / verl · 原文:Performance Tuning Guide ↗

TL;DR

verl 的训练侧把一个 mini-batch 切成 micro-batch 时有两条路径:默认按 micro_batch_size_per_gpu 等条数切,开了 use_dynamic_bsz=True 后改为按 max_token_len_per_gpu 等 token 数动态打包。后者用 Karmarkar-Karp 算法做 partition,并在 DP 维度用 all_reduce(MAX) 把 micro-batch 数量拉齐——这两步是它能跑通的关键,也是它最容易引发 NCCL hang 的根源。

2
条切分路径,由 use_dynamic_bsz 二选一
3
dynamic 模式的关键步骤:算切数 → DP 同步 → KK 分桶
5+
由 dynamic batch 直接/间接引发的已知 hang issue
01

先把三层 batch 概念分清楚

verl 里 "batch size" 这个词在不同位置含义不同。在讨论怎么切之前,先看清数据是怎么从一个 step 流到一张 GPU 的。

train_batch_size
↓ 切成若干个 mini-batch(PPO 多个 epoch 复用)
ppo_mini_batch
KK 第 1 层 _balance_batch 全局重排,让 chunk 后各 rank 总 token 接近
per-GPU shard
KK 第 2 层 dynamic_bsz 按 token 预算切 micro-batch(仅 use_dynamic_bsz=True
micro-batch
verl 在数据流上做了两层 KK 平衡:第 1 层在 single controller 上跨 DP rank 重排(默认开),第 2 层在每个 rank 内切 micro-batch(要开 use_dynamic_bsz)。本文会分别拆解这两层的作用和局限。
02

路径 A:静态切分(默认)

关键参数:ppo_micro_batch_size_per_gpu / log_prob_micro_batch_size_per_gpu。逻辑非常简单——按 样本条数 等分。

# verl/workers/engine/utils.py
if use_dynamic_bsz:
    ...
else:
    total_data_size = len(data)
    micro_batch_size_per_gpu = data["micro_batch_size_per_gpu"]
    assert total_data_size % (force_group_size * micro_batch_size_per_gpu) == 0
    micro_batches = tu.chunk_tensordict(
        data, total_data_size // (micro_batch_size_per_gpu * force_group_size)
    )
verl/workers/engine/utils.py:88–95

特性

  • 条数固定:每个 micro-batch 恰好 micro_batch_size_per_gpu 条样本,无论它们多长。
  • 显存不可控:极端情况一个 micro-batch 全是长样本,token 总数远超预期,OOM。
  • DP 间天然对齐:每张卡的 micro-batch 数完全一致,FSDP/Megatron 的集合通信节奏天然同步。
  • 失衡浪费算力:短样本 padding 到 batch 里最长那条,长样本拉低吞吐,GPU 利用率低。
03

路径 B:动态切分(use_dynamic_bsz=True)

关键参数:ppo_max_token_len_per_gpu。不再按条数切,而是给每张卡一个 token 预算,由算法去打包。

入口代码

# verl/workers/engine/utils.py — prepare_micro_batches()
if use_dynamic_bsz:
    assert "max_token_len_per_gpu" in data.keys()
    max_token_len_per_gpu = data["max_token_len_per_gpu"]
    max_token_len = max_token_len_per_gpu * sp_size   # SP 倍率
    micro_batches, batch_idx_list = rearrange_micro_batches(
        data, max_token_len=max_token_len,
        dp_group=dp_group,
        same_micro_num_in_dp=same_micro_num_in_dp,   # 关键
        ...
    )
verl/workers/engine/utils.py:74–87

三步切分流程

步骤 1:算每张卡至少要切几个 micro-batch

total_seqlen     = attention_mask.sum()
num_micro_batches = ceil(total_seqlen / max_token_len)

步骤 2:DP 维度同步 micro-batch 数(这步是 dynamic 模式独有的关键

if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None:
    num_micro_batches = torch.tensor([num_micro_batches], device=...)
    dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
    num_micro_batches = num_micro_batches.cpu().item()
verl/utils/seqlen_balancing.py:402–405

步骤 3:用 Karmarkar-Karp 算法把样本按工作量平均分进 num_micro_batches 个桶

# 工作量估算:模拟 attention FLOPs ≈ 12·h²·L + 2·h·L²
workload = 24576 * seqlen + seqlen ** 2
partitions = karmarkar_karp(workload, k_partitions=num_micro_batches, equal_size=False)
verl/utils/seqlen_balancing.py:46 + get_seqlen_balanced_partitions()
为什么 all_reduce(MAX) 不可省:DP 各 rank 拿到的样本长度分布不同 → 各自算出的 num_micro_batches 不一样 → FSDP forward 循环里调用 all_gather 的次数不一样 → 集合通信永远配不上对,直接 NCCL hang。verl 用 MAX 把所有 rank 拉齐到全局最大值,短的 rank 通过把样本切成更稀疏的桶来对齐(不是补空 padding)。
04

先于 dynamic_bsz:DP 间的全局 KK 重排

第 03 节的 dynamic_bsz 只管 rank 内的 micro-batch 切分。那 rank 之间的总 token 数谁负责?答案是:在 dynamic_bsz 进各 rank 之前,verl 在 single controller 上先做了一次全局 KK 重排,叫 _balance_batch。由 trainer.balance_batch=True 默认开启。

# verl/trainer/ppo/ray_trainer.py — _balance_batch()
global_seqlen_lst = batch.batch["attention_mask"].sum(-1)
workload_lst = calculate_workload(global_seqlen_lst)
dp_size = self._get_dp_size(self.actor_rollout_wg, "actor")

# equal_size=True:每个 rank 拿到等条数样本
global_partition_lst = get_seqlen_balanced_partitions(
    workload_lst, k_partitions=dp_size, equal_size=True
)

# 把 batch 按 KK 给的分组重排,再 chunk 出去
global_idx = torch.tensor([j for partition in global_partition_lst for j in partition])
batch.reorder(global_idx)
verl/trainer/ppo/ray_trainer.py:1102–1170

这层和 dynamic_bsz 的关键区别

  • equal_size=True:每个 rank 必须拿到等条数样本(PPO 后续 mini-batch 数对齐需要),KK 在这个硬约束下做平衡。
  • 只重排顺序,不重切样本:原样本不变,只是 chunk 时被分到不同 rank。
  • 对 workload 而非 token 数平衡:用 12·h²·L + 2·h·L² 估算 attention FLOPs,让长样本被特别"重视"。
  • 每个 mini-batch 跑一次,rollout 完成后、进 forward 前。

演示:16 条样本下发到 4 个 DP rank

同一组 16 条不同长度样本,按原顺序 chunk vs 经过 _balance_batch 重排后再 chunk。

rank 数 4
每 rank 样本数 4
最大 rank token
最小 rank token
极差比
限制equal_size=True 让 KK 不能完全自由,比如「单条特别长的样本」总会落到某个 rank 上、把那个 rank 的 token 总数推高。这是为什么 log 里 balanced_max - balanced_min 不会真的变成 0,只是从原始的 ~10% 极差压到 ~3-5%。

log 里 global_seqlen/min,max未重排时的统计;global_seqlen/balanced_min,balanced_maxKK 重排后的统计。在自己的训练 metric 里直接对比这两组数字就能看 _balance_batch 在你的数据上有多大效果。

05

同一组样本,两种 micro-batch 切法的对比

假设这张 GPU 拿到 8 条样本,token 长度依次为 。共 token。

静态:micro_batch_size_per_gpu = 2

按 batch 顺序每 2 条切一个 micro-batch。token 数完全失衡。

桶数
最大桶
最小桶
极差比

动态:max_token_len_per_gpu = 13312

按 token 预算切,KK 算法做平衡,每桶 token 数接近预算。

桶数
最大桶
最小桶
极差比

桶里的色块代表一条样本,宽度正比于该样本 token 数。预算线表示 max_token_len_per_gpu;超过预算的桶会被高亮(动态模式不会出现,静态可能会 OOM)。

06

DP 维度同步:从「各算各的」到「整齐划一」

动态模式下,4 张卡各自算出来的 num_micro_batches 大概率不一样。verl 用 all_reduce(MAX) 把它们拉齐,否则 FSDP 集合通信对不上号会直接 hang。点击按钮看演示。

每个 rank 按自己的样本算出 num_micro_batches,目前不同步。
关键事实澄清:「拉齐」靠的不是补 dummy padding,而是让短的 rank 把自己手里的样本切成更多、更稀疏的桶——代码里的硬约束是 assert num_micro_batches ≤ num_groups,每个桶至少 1 条真实样本,没有空 batch。短 rank 后续 forward 循环里跑的是稀疏的真样本而非 padding。
遗留隐患:verl 在进 dynamic_bsz 之前其实还有一层全局 KK 平衡(_balance_batch,由 trainer.balance_batch=True 默认开),会先重排样本让各 rank 总 token 数尽量接近(log 里 global_seqlen/balanced_min/max 就是它的产物)。但它有 equal_size=True 的硬约束(每 rank 等条数),且只平衡总和不平衡每个 micro-batch 的工作量。所以单条特别长的样本仍可能让某 rank 的某个 micro-batch 慢得多,触发 NCCL watchdog timeout(issue #5750 的根因之一)。
07

速查:哪些参数随 use_dynamic_bsz 开关生效

遇到「我设了这个参数怎么没用」的时候回来看一眼。

参数 use_dynamic_bsz=False use_dynamic_bsz=True 说明
*.ppo_micro_batch_size_per_gpu 生效 忽略 每卡 micro-batch 的样本条数
*.log_prob_micro_batch_size_per_gpu 生效 忽略 compute_log_prob / ref_log_prob 用
actor.ppo_max_token_len_per_gpu 忽略 生效 actor 训练 micro-batch 的 token 上限
ref / rollout.log_prob_max_token_len_per_gpu 忽略 生效 前向算 logprob 时的 token 上限
critic.ppo_max_token_len_per_gpu 忽略 生效 critic 用 dynamic 时才生效
data.max_prompt_length / max_response_length 生效 生效 跟 dynamic_bsz 无关,是单条样本的硬上限
rollout.max_model_len / max_num_batched_tokens 生效 生效 vLLM 引擎参数,跟训练侧切分无关