昇腾社区首页
中文
注册
开发者
下载

aclnnSparseFlashAttentionGrad

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:根据topkIndices对key和value选取大小为selectedBlockSize的数据重排,接着进行训练场景下计算注意力的反向输出。

  • 计算公式:根据传入的topkIndice对keyIn和value选取数量为selectedBlockCount个大小为selectedBlockSize的数据重排,公式如下:

selectedKey=Gather(key,topkIndices[i]),0<=i<selectBlockCountselectedKey\text{ }=\text{ }Gather \left( key,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right. selectedValue=Gather(value,topkIndices[i]),0<=i<selectBlockCountselectedValue\text{ }=\text{ }Gather \left( value,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right. [object Object]

阶段1:根据矩阵乘法导数规则,计算dPdPdVdV:

[object Object]dPt,:=dOt,:@VTdP\mathop{{}}\nolimits_{{t,:}}=dO\mathop{{}}\nolimits_{{t,:}}\text{@}V\mathop{{}}\nolimits^{{T}} dV[u]=PTt,:@dOt,:dV \left[ u \left] =P\mathop{{}}\nolimits_{{T}}^{{t,:}}\text{@}dO\mathop{{}}\nolimits_{{t,:}}\right. \right. [object Object]

阶段2:计算dSdS:

[object Object]dSt,:=[Pt,:@(dPt,:FlashSoftmaxGrad(dO,O))]d\mathop{{S}}\nolimits_{{t,:}}= \left[ P\mathop{{}}\nolimits_{{t,:}}@ \left( dP\mathop{{}}\nolimits_{{t,:}}-FlashSoftmaxGrad \left( dO,O \left) \left) \right] \right. \right. \right. \right. [object Object]

阶段3:计算dQdQdKdK:

[object Object]dQt,:=dSt,:@K[u]:t,:/dk,:d\mathop{{Q}}\nolimits_{{t,:}}=d\mathop{{S}}\nolimits_{{t,:}}@K \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}/\sqrt{{d\mathop{{}}\nolimits_{{k,:}}}}\right. \right. dK[u]:t,:=dSt,:tT@Q/dt,:dK \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}=dS\mathop{{}}\nolimits_{{t,:t}}\mathop{{}}\nolimits^{{T}}\text{@}Q/\sqrt{{d\mathop{{}}\nolimits_{{t,:}}}}\right. \right.

函数原型

每个算子分为,必须先调用“aclnnSparseFlashAttentionGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnSparseFlashAttentionGrad”接口执行计算。

[object Object]
[object Object]

aclnnSparseFlashAttentionGradGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnSparseFlashAttentionGrad

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 确定性计算:

    • aclnnSparseFlashAttentionGrad默认非确定性实现,不支持通过aclrtCtxSetSysParamOpt开启确定性。
  • 公共约束

    • 入参为空的场景处理:
      • query为空Tensor:直接返回。
  • Mask

    [object Object]
  • 规格约束

    [object Object]

调用示例

调用示例代码如下(以[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]为例),仅供参考,具体编译和执行过程请参考

[object Object]