FlashAttention算子性能调优案例

案例介绍

本案例中的算子FlashAttentionScoreGrad,用于训练场景下计算注意力的反向输出,即FlashAttentionScore算子的反向计算。

已知注意力的正向计算公式为:

为方便表达,以变量S和P表示计算公式:

则注意力的反向计算公式为:

计算流程图如下:

图1 算子计算流程

按照FlashAttention反向计算流程的实现,简介整体计算流程如下。对本算子的算法感兴趣的用户可简单了解,无需重点关注。

  1. 重计算p,本步骤重计算了FlashAttention流程中的softmax结果p,计算结果保存在ub中。

  2. 计算dp,该计算包含matmul计算和dropout计算,matmul计算中,左矩阵为dy,右矩阵为转置后的value。

  3. 计算ds,本计算中,FlashSoftmaxGrad计算的入参为dy、正向输出attention_in,该结果与dp做减操作,最终的结果与p相乘得到结果ds。

  4. 计算dq,本计算将ds结果与key做matmul计算,并将结果与scale相乘得到结果dq。

  5. 计算dk,本计算将转置后的ds结果与query做matmul计算,并将结果与scale相乘得到结果dk。

  6. 计算dv,本计算将p的结果做drop计算,转置后与dy做matmul计算。

本案例的验证平台为 Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 ,以两个场景为例,第一个场景的输入维度信息为:B=1,N1=12,N2=12,S1=6144,S2=6144,D=128,causal场景,即atten_mask的形状为下三角,如图2。第二个场景的输入维度信息为:B=24,N1=5,N2=5,S1=9216,S2=9216,D=64,不带atten_mask和drop_mask输入。主要涉及的优化手段包括tiling基本块大小调整,核间负载均衡,CV流水并行,MTE2流水优化以及FixPipe流水优化等优化手段。

图2 causal场景atten_mask形状

获取性能数据

流水优化分析工具包括CAModel和Profiling工具,分别从两个方面分析:第一个是从Profiling工具生成的Profiling数据中分析各项流水的占比,第二个是从CAModel工具生成的打点图分析各流水并行情况。

分析主要瓶颈点

通过观察分析流水图和Profiling数据,结合优化经验来判断性能瓶颈点。在优化过程中不同阶段可能会出现不同的瓶颈点,需要不断优化以达到最佳性能。

设计优化方案

验证优化方案性能收益

总结

融合算子场景,可参考此优化。