深度解析: 分布式数据并行 (DDP)
本文档旨在为希望深入理解本框架中分布式数据并行(DDP)实现细节的用户提供背景知识.
核心概念
首先, 理解以下几个核心概念至关重要:
-
world_size(世界大小): 参与分布式训练的总进程数. 在我们的框架中, 它通常等于num_nodes * gpus_per_node. -
rank(进程排名): 分配给每个独立进程的唯一ID, 范围从0到world_size - 1.rank 0的进程通常被视为主进程, 负责日志记录和保存检查点等任务. -
process_group(进程组): 所有参与训练的进程的集合.torch.distributed.init_process_group负责初始化这个组, 让进程之间可以相互通信. -
backend(后端): 实现进程间通信的具体库. 最常用的两个是:nccl: NVIDIA Collective Communications Library. 这是用于 GPU 分布式训练的推荐后端, 因为它经过了高度优化, 能提供最佳性能.gloo: 用于 CPU 分布式训练的后端.
DDP 是如何工作的?
分布式数据并行(DDP)的目标是通过在多个 GPU(或机器)上并行处理数据来加速训练. 其工作流程可以概括为以下几个步骤:
-
模型复制: 在训练开始时,
rank 0进程的模型状态(权重和偏置)被广播到所有其他进程. 这确保了所有进程都从完全相同的模型开始. -
数据分片:
DistributedDataSampler是实现这一步的关键. 它不是将整个数据集分发给每个进程, 而是确保每个进程在每个epoch中都得到数据集的一个互不重叠的子集(分片). 这就是“数据并行”中“数据”的含义. -
独立的前向传播: 每个进程使用自己的数据分片, 独立地在其本地模型上执行前向传播, 并计算损失.
-
梯度同步 (
All-Reduce): 这是 DDP 的核心魔法. 在每个进程计算出本地梯度后, DDP 会自动触发一个All-Reduce操作. 在这个操作中, 所有进程的梯度被相加并平均, 然后结果被分发回每个进程. 最终, 每个进程都拥有了在所有数据上计算出的平均梯度. -
相同的权重更新: 因为所有进程都从相同的模型开始, 并且接收到了完全相同的平均梯度, 所以当它们调用
optimizer.step()时, 它们会以完全相同的方式更新自己的模型权重. 这保证了在每次迭代后, 所有进程上的模型权重保持同步.
关键代码解析
sampler.set_epoch(epoch)
- 为什么需要它?
DistributedDataSampler需要确保在每个epoch中, 数据的分片方式都是不同的, 否则模型每个epoch都会看到完全相同的数据子集, 这会损害模型的泛化能力. 通过调用sampler.set_epoch(epoch), 我们改变了sampler内部的随机种子, 从而保证了每个epoch都有一个新的、随机的数据排列和分片方式. - 调用时机: 必须在每个
epoch开始时, 创建DataLoader之前调用.
find_unused_parameters=False
- 这是什么? 在
DDP(model, find_unused_parameters=False)中, 这个参数告诉 DDP 不要去检查模型中哪些参数在前向传播中没有被使用. - 为什么设置为
False? 检查未使用的参数会带来一些开销. 如果您的模型所有参数都在forward中被使用(这是绝大多数情况), 将此项设置为False可以略微提高性能. 如果您的模型中有一些参数只在特定的代码路径下被使用(例如, 在if语句中), 您可能需要将其设置为True, 否则 DDP 在反向传播时可能会因为找不到这些参数的梯度而报错.
DistributedManager.reduce_mean(tensor)
- 这是什么? 这是一个辅助函数, 用于在所有进程中计算一个张量的平均值. 例如, 在验证结束时, 每个进程都计算出了其本地数据分片上的平均损失. 为了得到全局的平均验证损失, 我们需要调用
reduce_mean. 它首先使用dist.all_reduce将所有进程的损失相加, 然后除以world_size.
单节点 vs 多节点
-
单节点多 GPU: 这是最常见的使用场景. 您只需要设置
gpus_per_node(或让它自动检测), 框架会自动处理mp.spawn和 DDP 的设置. -
多节点多 GPU: 这种场景下, 配置变得更加复杂. 您必须手动为每个节点设置正确的环境变量 (
MASTER_ADDR,MASTER_PORT,NUM_NODES,NODE_RANK), 以便它们可以找到彼此并加入同一个进程组. 通常, 这由集群管理工具(如 SLURM)自动完成.