昇腾社区首页
中文
注册
开发者
下载

MLP层融合优化

基本原理

基于LLaMA模型结构进行分布式层面的算子融合,通过对MLP中up-proj层和gate层融合,可以在分布式训练场景下获得以下收益:

  • 通信:前向和反向各减少一次allreduce;
  • 计算:前向将2次矩阵乘融合成1次大矩阵乘;反向将4次矩阵乘融合成2次大矩阵乘,减少算子下发。
图1 MLP层融合示意

使用场景

LLaMA系列结构通用,当前在Atlas A2 训练系列产品上验证: LLaMA-16B性能提升3.3%,LLaMA-32B性能提升4.0%,LLaMA2-7B上提升9%。

操作步骤

使用AscendSpeed,添加 --mlp-layer-fusion标志。