MLP层融合优化

基本原理

基于LLaMA模型结构进行分布式层面的算子融合,通过对MLP中up-proj层和gate层融合,可以在分布式训练场景下获得以下收益:

图1 MLP层融合示意

使用场景

LLaMA系列结构通用,当前在Atlas A2 训练系列产品上验证: LLaMA-16B性能提升3.3%,LLaMA-32B性能提升4.0%,LLaMA2-7B上提升9%。

操作步骤

使用AscendSpeed,添加 --mlp-layer-fusion标志。