aclnnFusedFloydAttentionGrad
产品支持情况
功能说明
接口功能:训练场景下,FloydAttn相较于传统FA主要是计算qk/pv注意力时会额外将seq作为batch轴从而转换为batchMatmul
计算公式:
已知注意力的正向计算公式为:
则注意力的反向计算公式为:
函数原型
每个算子分为,必须先调用“aclnnFusedFloydAttentionGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnFusedFloydAttentionGrad”接口执行计算。
[object Object]
[object Object]
aclnnFusedFloydAttentionGradGetWorkspaceSize
aclnnFusedFloydAttentionGrad
约束说明
该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配
关于数据shape的约束,其中:
- B:取值范围为1~2K。
- H:取值范围为1~256。
- N:取值范围为16~1M且N%16==0。
- M:取值范围为128~1M且M%128==0。
- K:取值范围为128~1M且K%128==0。
- D:取值范围为16~128。
query与key1的第0/2/4轴需相同。
key1与value1 shape需相同。
key2与value2 shape需相同。
query与dy/attentionIn shape需相同。
softmaxMax与softmaxSum shape需相同。
调用示例
[object Object]