跳转至

深度解析: 分布式数据并行 (DDP)

本文档旨在为希望深入理解本框架中分布式数据并行(DDP)实现细节的用户提供背景知识.

核心概念

首先, 理解以下几个核心概念至关重要:

  • world_size (世界大小): 参与分布式训练的总进程数. 在我们的框架中, 它通常等于 num_nodes * gpus_per_node.

  • rank (进程排名): 分配给每个独立进程的唯一ID, 范围从 0world_size - 1. rank 0 的进程通常被视为主进程, 负责日志记录和保存检查点等任务.

  • process_group (进程组): 所有参与训练的进程的集合. torch.distributed.init_process_group 负责初始化这个组, 让进程之间可以相互通信.

  • backend (后端): 实现进程间通信的具体库. 最常用的两个是:

    • nccl: NVIDIA Collective Communications Library. 这是用于 GPU 分布式训练的推荐后端, 因为它经过了高度优化, 能提供最佳性能.
    • gloo: 用于 CPU 分布式训练的后端.

DDP 是如何工作的?

分布式数据并行(DDP)的目标是通过在多个 GPU(或机器)上并行处理数据来加速训练. 其工作流程可以概括为以下几个步骤:

  1. 模型复制: 在训练开始时, rank 0 进程的模型状态(权重和偏置)被广播到所有其他进程. 这确保了所有进程都从完全相同的模型开始.

  2. 数据分片: DistributedDataSampler 是实现这一步的关键. 它不是将整个数据集分发给每个进程, 而是确保每个进程在每个 epoch 中都得到数据集的一个互不重叠的子集(分片). 这就是“数据并行”中“数据”的含义.

  3. 独立的前向传播: 每个进程使用自己的数据分片, 独立地在其本地模型上执行前向传播, 并计算损失.

  4. 梯度同步 (All-Reduce): 这是 DDP 的核心魔法. 在每个进程计算出本地梯度后, DDP 会自动触发一个 All-Reduce 操作. 在这个操作中, 所有进程的梯度被相加平均, 然后结果被分发回每个进程. 最终, 每个进程都拥有了在所有数据上计算出的平均梯度.

  5. 相同的权重更新: 因为所有进程都从相同的模型开始, 并且接收到了完全相同的平均梯度, 所以当它们调用 optimizer.step() 时, 它们会以完全相同的方式更新自己的模型权重. 这保证了在每次迭代后, 所有进程上的模型权重保持同步.

关键代码解析

sampler.set_epoch(epoch)

  • 为什么需要它? DistributedDataSampler 需要确保在每个 epoch 中, 数据的分片方式都是不同的, 否则模型每个 epoch 都会看到完全相同的数据子集, 这会损害模型的泛化能力. 通过调用 sampler.set_epoch(epoch), 我们改变了 sampler 内部的随机种子, 从而保证了每个 epoch 都有一个新的、随机的数据排列和分片方式.
  • 调用时机: 必须在每个 epoch 开始时, 创建 DataLoader 之前调用.

find_unused_parameters=False

  • 这是什么?DDP(model, find_unused_parameters=False) 中, 这个参数告诉 DDP 不要去检查模型中哪些参数在前向传播中没有被使用.
  • 为什么设置为 False 检查未使用的参数会带来一些开销. 如果您的模型所有参数都在 forward 中被使用(这是绝大多数情况), 将此项设置为 False 可以略微提高性能. 如果您的模型中有一些参数只在特定的代码路径下被使用(例如, 在 if 语句中), 您可能需要将其设置为 True, 否则 DDP 在反向传播时可能会因为找不到这些参数的梯度而报错.

DistributedManager.reduce_mean(tensor)

  • 这是什么? 这是一个辅助函数, 用于在所有进程中计算一个张量的平均值. 例如, 在验证结束时, 每个进程都计算出了其本地数据分片上的平均损失. 为了得到全局的平均验证损失, 我们需要调用 reduce_mean. 它首先使用 dist.all_reduce 将所有进程的损失相加, 然后除以 world_size.

单节点 vs 多节点

  • 单节点多 GPU: 这是最常见的使用场景. 您只需要设置 gpus_per_node(或让它自动检测), 框架会自动处理 mp.spawn 和 DDP 的设置.

  • 多节点多 GPU: 这种场景下, 配置变得更加复杂. 您必须手动为每个节点设置正确的环境变量 (MASTER_ADDR, MASTER_PORT, NUM_NODES, NODE_RANK), 以便它们可以找到彼此并加入同一个进程组. 通常, 这由集群管理工具(如 SLURM)自动完成.