混合精度训练中,需要保存参数副本、梯度副本、优化器状态等多种状态张量,占据了大量的静态内存(16N,N为参数量),而实际参与前反向计算的参数和梯度(4N,N为参数量)相比之下占比很小,优化以上状态张量可以带来极大的显存收益。本算法希望通过深入分析每部分状态张量的实际使用实现机制的显存复用,最终得到一个集成多个算法模块的多级优化器内存优化方案。当前特性包括:内存复用O1——梯度副本去冗余。
优势:完全等价、支持多种优化器、性能无损。
算法原理:将原本需要持久保存的FP32梯度副本的静态内存,复用FP16梯度的内存,在需要时通过Foreach+Cast操作转换成FP32的形式,可节省4N的空间。
所有优化器适用,当前在Atlas A2 训练系列产品上验证:LLaMA-7B/LLaMA-13B/LLaMA2-7B/Bloom-7B/LLaMA-32B模型上,训练内存端到端平均缩减13.59%,不影响精度,性能无劣化。
使用AscendSpeed,开启--release-fp32-grad标志。