大规模并行训练的混合精度选择
大规模分布式并行训练中使用半精度浮点数,一般有两种选择:float16或者bfloat16。这两种格式的区别见半精度浮点数。float16相对于bfloat16有更高的精度,但是表示范围更小。混合精度训练首先需要把模型中适合的参数转移到半精度浮点类型。
如果混合精度训练中选择float16,为了避免表示范围小引起的浮点上溢和下溢,混合精度要结合动态Loss缩放机制,以下是采用float16的混合精度训练典型流程概述:
- 保留一份FP32格式的权重主备份,同时优化器状态也应以FP32格式存储。
- 将Loss缩放因子S初始化为一个较大的值。
- 对每一个训练step:
- 将权重复制一份到FP16格式。
- 使FP16格式的权重和激活值进行前向传播。
- 将最终的Loss乘以缩放因子S。
- 使用FP16格式进行后向传播,包括权重、激活值及其对应的梯度。
- 若检测到权重梯度中出现Inf或NaN:
- 减小S值。
- 跳过当前权重更新步骤,重新开始下一个训练步骤。
- 将权重梯度乘以1/S。
- 梯度累积或者梯度累积足够步后使用FP32更新主权重。
- 如果之前N步都没有看到Inf或者NaN,增加S值。
- 在上述过程中,有几处计算必须要以FP32完成。比如主权重的更新,因为累加能够导致精度误差积累,所以必须要以FP32计算。缩放因子必须是FP32类型,甚至1/S的计算要将S转成双精度数求倒数再转回到FP32。
使用BF16格式的半精度数时,因为BF16有更大的表示范围,所以一般无需使用Loss缩放机制。但是BF16数值精度比FP16更差,所以在步骤三的第七点做梯度累积的时候需要使用FP32,否则有可能会因为梯度累积误差导致模型不收敛。另外BF16比FP16多15%的运行时内存,主要原因在于梯度累积时需要转FP32。
PyTorch提供了自动混合精度(AMP)的机制,AMP按需自动调整张量的数据类型(dtype)。例如在AMP autocast上下文时,矩阵乘法matmul的输入张量会被自动转化为半精度浮点类型。AMP也提供了GradScaler,通过自动调整Loss的缩放来防止梯度的下溢和上溢。PyTorch的AMP优化级别使用apex.amp的O1级,这意味着PyTorch AMP使用黑白名单自动决定使用FP16、BF16还是FP32进行计算,但还有一些特定模型相关的精度敏感的运算并不在AMP的自动upcast名单中,需要开发者手动干预。所以使用AMP时,开发者需要对AMP的黑白名单有一定了解。
父主题: 混合精度配置选择