自适应选择重计算

基本原理

重计算是在模型训练的forward计算后,不保存输出,在backward计算需要的时候再重新计算一遍所需的forward输出值,以此降低内存峰值,避免训练过程中发生OOM。

由于forward在backward都重计算,导致模型整体的训练性能下降,而且内存也没有充分利用。因此,通过选择部分forward进行重计算,可以很好的平衡训练性能和内存使用。

传统的选择重计算需要根据模型定制,流程比较复杂,而且得到一个比较理想的策略也比较困难。自适应选择重计算仅仅只需要通过调整训练内存阈值,自适应选择重计算的forward。如果发生OOM,可以通过调小训练内存阈值的方式重启训练,直到找到一个合适的选择重计算策略。

图1 自适应选择重计算训练流程

使用场景

适用于基于transformer模型的大模型训练。实验在LLAMA2-7B和LLAMA-65B上测试性能提升5%以上。

操作步骤

使用AscendSpeed启用自适应选择重计算,请使用--auto-recompute-device-size标志指定自适应选择重计算策略的内存大小,一般值越大越容易发生OOM。请注意,如果要使用--auto-recompute-device-size标志,请删除标志--checkpoint-activations。如果中途退出训练可以使用Ctrl+C退出,或者kill -9 PID结束进程。

自适应选择重计算根据profiling前N步的训练内存信息进行策略选择,可以通过使用--auto-recompute-profiling-step标志设置停止profiling的步数。 默认在第10步停止profiling,最小设置为5步,建议在训练内存平稳后停止profiling,这样可以获得更佳的选择重计算策略。