昇腾社区首页
中文
注册

Swap Attention

背景与挑战

在大模型训练过程中,使用重计算功能可以显著减少内存使用,但会延长训练时间,从而降低执行效率。

解决方案

新增Swap Attention功能,利用设备内存和CPU内存存储激活值,在反向传播梯度时从CPU内存预取激活值以减少重计算,充分利用H2D高带宽的优势,通过网络补充存储、增强计算,提升MFU,加速大模型训练。

图1 Swap Attention原理示意

使用场景

  • 优化性能

    在需要启用全重计算的场景中,可以通过开启--swap-attention和设置--recompute-num-layers N来替代全重计算,从而提升性能。

  • 内存节省

    对于不需要重计算的场景,仅启用--swap-attention,可以在几乎不损失性能的情况下节省内存,从而支持更大模型的配置。

使用方法

启用Swap Attention,需在训练脚本中加入以下参数配置:

--use-flash-attn     # 启用前提需开启flash attention融合算子
--swap-attention

可选参数--swap-modules:参数类型为string,默认值为"input_norm,self_attention,post_attention_norm",可根据模型自行配置module,在Mcore场景下默认仅预取self_attention module。

  • 仅开启预取功能:--swap-attention

    开启后,将对每一层的attention层的激活值进行预取,提高计算效率。

    图2 仅开启预取功能
  • 开启预取功能并且指定重计算层数:--swap-attention和--recompute-num-layers N

    开启后,将对每一层的attention层的激活值进行预取,并对前N层的全连接层进行重计算。

    图3 开启预取功能并且指定重计算层数
  • --recompute-num-layers N中的N层数指的是每一个pp stage的层数。N的取值应该小于等于num-layers/pipeline-model-parallel-size。
  • 暂不兼容自适应选择重计算特性。
  • 若出现性能严重劣化,可能是跨NUMA内存访问引起,可尝试通过进程绑核缓解,实现方法请参见绑核工具
  • --swap-attention暂不兼容LoRA微调。

使用效果

与完全重计算相比 ,有性能收益; 与不重计算相比,有内存收益。