从「按条等分」到「按 token 动态打包」——一次讲清 verl 训练侧两条 micro-batch 切分路径,以及为什么 dynamic 模式更容易踩到 NCCL hang。
verl 的训练侧把一个 mini-batch 切成 micro-batch 时有两条路径:默认按 micro_batch_size_per_gpu 等条数切,开了 use_dynamic_bsz=True 后改为按 max_token_len_per_gpu 等 token 数动态打包。后者用 Karmarkar-Karp 算法做 partition,并在 DP 维度用 all_reduce(MAX) 把 micro-batch 数量拉齐——这两步是它能跑通的关键,也是它最容易引发 NCCL hang 的根源。
verl 里 "batch size" 这个词在不同位置含义不同。在讨论怎么切之前,先看清数据是怎么从一个 step 流到一张 GPU 的。
_balance_batch 全局重排,让 chunk 后各 rank 总 token 接近use_dynamic_bsz=True)use_dynamic_bsz)。本文会分别拆解这两层的作用和局限。关键参数:ppo_micro_batch_size_per_gpu / log_prob_micro_batch_size_per_gpu。逻辑非常简单——按 样本条数 等分。
# verl/workers/engine/utils.py if use_dynamic_bsz: ... else: total_data_size = len(data) micro_batch_size_per_gpu = data["micro_batch_size_per_gpu"] assert total_data_size % (force_group_size * micro_batch_size_per_gpu) == 0 micro_batches = tu.chunk_tensordict( data, total_data_size // (micro_batch_size_per_gpu * force_group_size) )
特性
micro_batch_size_per_gpu 条样本,无论它们多长。关键参数:ppo_max_token_len_per_gpu。不再按条数切,而是给每张卡一个 token 预算,由算法去打包。
入口代码
# verl/workers/engine/utils.py — prepare_micro_batches() if use_dynamic_bsz: assert "max_token_len_per_gpu" in data.keys() max_token_len_per_gpu = data["max_token_len_per_gpu"] max_token_len = max_token_len_per_gpu * sp_size # SP 倍率 micro_batches, batch_idx_list = rearrange_micro_batches( data, max_token_len=max_token_len, dp_group=dp_group, same_micro_num_in_dp=same_micro_num_in_dp, # 关键 ... )
三步切分流程
步骤 1:算每张卡至少要切几个 micro-batch
total_seqlen = attention_mask.sum() num_micro_batches = ceil(total_seqlen / max_token_len)
步骤 2:DP 维度同步 micro-batch 数(这步是 dynamic 模式独有的关键)
if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None: num_micro_batches = torch.tensor([num_micro_batches], device=...) dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item()
步骤 3:用 Karmarkar-Karp 算法把样本按工作量平均分进 num_micro_batches 个桶
# 工作量估算:模拟 attention FLOPs ≈ 12·h²·L + 2·h·L²
workload = 24576 * seqlen + seqlen ** 2
partitions = karmarkar_karp(workload, k_partitions=num_micro_batches, equal_size=False)
第 03 节的 dynamic_bsz 只管 rank 内的 micro-batch 切分。那 rank 之间的总 token 数谁负责?答案是:在 dynamic_bsz 进各 rank 之前,verl 在 single controller 上先做了一次全局 KK 重排,叫 _balance_batch。由 trainer.balance_batch=True 默认开启。
# verl/trainer/ppo/ray_trainer.py — _balance_batch() global_seqlen_lst = batch.batch["attention_mask"].sum(-1) workload_lst = calculate_workload(global_seqlen_lst) dp_size = self._get_dp_size(self.actor_rollout_wg, "actor") # equal_size=True:每个 rank 拿到等条数样本 global_partition_lst = get_seqlen_balanced_partitions( workload_lst, k_partitions=dp_size, equal_size=True ) # 把 batch 按 KK 给的分组重排,再 chunk 出去 global_idx = torch.tensor([j for partition in global_partition_lst for j in partition]) batch.reorder(global_idx)
这层和 dynamic_bsz 的关键区别
equal_size=True:每个 rank 必须拿到等条数样本(PPO 后续 mini-batch 数对齐需要),KK 在这个硬约束下做平衡。12·h²·L + 2·h·L² 估算 attention FLOPs,让长样本被特别"重视"。同一组 16 条不同长度样本,按原顺序 chunk vs 经过 _balance_batch 重排后再 chunk。
equal_size=True 让 KK 不能完全自由,比如「单条特别长的样本」总会落到某个 rank 上、把那个 rank 的 token 总数推高。这是为什么 log 里 balanced_max - balanced_min 不会真的变成 0,只是从原始的 ~10% 极差压到 ~3-5%。
log 里 global_seqlen/min,max 是未重排时的统计;global_seqlen/balanced_min,balanced_max 是KK 重排后的统计。在自己的训练 metric 里直接对比这两组数字就能看 _balance_batch 在你的数据上有多大效果。
假设这张 GPU 拿到 8 条样本,token 长度依次为 。共 token。
按 batch 顺序每 2 条切一个 micro-batch。token 数完全失衡。
按 token 预算切,KK 算法做平衡,每桶 token 数接近预算。
桶里的色块代表一条样本,宽度正比于该样本 token 数。预算线表示 max_token_len_per_gpu;超过预算的桶会被高亮(动态模式不会出现,静态可能会 OOM)。
动态模式下,4 张卡各自算出来的 num_micro_batches 大概率不一样。verl 用 all_reduce(MAX) 把它们拉齐,否则 FSDP 集合通信对不上号会直接 hang。点击按钮看演示。
assert num_micro_batches ≤ num_groups,每个桶至少 1 条真实样本,没有空 batch。短 rank 后续 forward 循环里跑的是稀疏的真样本而非 padding。_balance_batch,由 trainer.balance_batch=True 默认开),会先重排样本让各 rank 总 token 数尽量接近(log 里 global_seqlen/balanced_min/max 就是它的产物)。但它有 equal_size=True 的硬约束(每 rank 等条数),且只平衡总和不平衡每个 micro-batch 的工作量。所以单条特别长的样本仍可能让某 rank 的某个 micro-batch 慢得多,触发 NCCL watchdog timeout(issue #5750 的根因之一)。遇到「我设了这个参数怎么没用」的时候回来看一眼。
| 参数 | use_dynamic_bsz=False | use_dynamic_bsz=True | 说明 |
|---|---|---|---|
| *.ppo_micro_batch_size_per_gpu | 生效 | 忽略 | 每卡 micro-batch 的样本条数 |
| *.log_prob_micro_batch_size_per_gpu | 生效 | 忽略 | compute_log_prob / ref_log_prob 用 |
| actor.ppo_max_token_len_per_gpu | 忽略 | 生效 | actor 训练 micro-batch 的 token 上限 |
| ref / rollout.log_prob_max_token_len_per_gpu | 忽略 | 生效 | 前向算 logprob 时的 token 上限 |
| critic.ppo_max_token_len_per_gpu | 忽略 | 生效 | critic 用 dynamic 时才生效 |
| data.max_prompt_length / max_response_length | 生效 | 生效 | 跟 dynamic_bsz 无关,是单条样本的硬上限 |
| rollout.max_model_len / max_num_batched_tokens | 生效 | 生效 | vLLM 引擎参数,跟训练侧切分无关 |