接口功能:计算
[object Object]训练场景下注意力的反向输出,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention。计算公式:
阶段一:根据不同cmp_ratio场景,对输入ori_kv与cmp_kv进行选择
- 当cmp_ratio = 1 (SWA):
- 当cmp_ratio = 4 (SCFA):
- else (CFA):
阶段二:计算P、dP、dS
阶段三:计算dQ, dKV, dSinks
每个算子分为,必须先调用“aclnnSparseFlashMlaGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnSparseFlashMlaGrad”接口执行计算。
[object Object]
[object Object]
确定性计算:
- aclnnSparseFlashMlaGrad默认非确定性实现,暂不支持通过aclrtCtxSetSysParamOpt开启确定性。
公共约束
- 入参为空的场景处理:
- query为空Tensor:直接返回。
- 入参为空的场景处理:
Mask
[object Object]规格约束
[object Object]
调用示例代码如下(以[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]为例),仅供参考,具体编译和执行过程请参考。
[object Object]