昇腾社区首页
中文
注册
开发者
下载

NsaSelectedAttentionGrad

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:根据topkIndices对key和value选取大小为selectedBlockSize的数据重排,接着进行训练场景下计算注意力的反向输出。

  • 计算公式:

    根据传入的topkIndice对keyIn和value选取数量为selectedBlockCount个大小为selectedBlockSize的数据重排,公式如下:

    selectedKey=Gather(key,topkIndices[i]),0<=i<selectedBlockCountselectedValue=Gather(value,topkIndices[i]),0<=i<selectedBlockCountselectedKey = Gather(key, topkIndices[i]),0<=i<selectedBlockCount \\ selectedValue = Gather(value, topkIndices[i]),0<=i<selectedBlockCount

    接着,进行注意力机制的反向计算,计算公式为:

    V=PTdYV=P^TdY Q=((dS)K)dQ=\frac{((dS)*K)}{\sqrt{d}} K=((dS)TQ)dK=\frac{((dS)^T*Q)}{\sqrt{d}}

函数原型

每个算子分为,必须先调用“aclnnNsaSelectedAttentionGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnNsaSelectedAttentionGrad”接口执行计算。

[object Object]
[object Object]

aclnnNsaSelectedAttentionGradGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnNsaSelectedAttentionGrad

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnNsaSelectedAttentionGrad默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。
  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • 输入query、key、value、attentionOut、attentionOutGrad的B(batchsize)必须相等。
  • 输入key、value的N(numHead)必须一致。
  • 输入query、attentionOut、attentionOutGrad的N(numHead)必须一致。
  • 输入value、attentionOut、attentionOutGrad的D(HeadDim)必须一致。
  • 输入query、key、value、attentionOut、attentionOutGrad的inputLayout必须一致。
  • 关于数据shape的约束,以inputLayout的TND举例。其中:
    • T1:取值范围为1~2M。T1表示query所有batch下S的和。
    • T2:取值范围为1~2M。T2表示key、value所有batch下S的和。
    • B:取值范围为1~2M。
    • N1:取值范围为1~128。表示query的headNum。N1必须为N2的整数倍。
    • N2:取值范围为1~128。表示key、value的headNum。
    • G:取值范围为1~32。G = N1 / N2
    • S:取值范围为1~128K。对于key、value的S 必须大于等于selectedBlockSize * selectedBlockCount, 且必须为selectedBlockSize的整数倍。
    • D:取值范围为192或128,支持K和V的D(HeadDim)不相等。
    • selectedBlockSize支持<=128且满足16的整数倍。
    • selectBlockCount:支持[1~128]。 总计选择的大小[object Object] < 128*64(8K)
    • Layout为TND时,每个Batch的S2都要大于总计选择的大小[object Object]
  • 关于softmaxMax与softmaxSum参数shape的约束:[T1, N1, 8]。
  • 关于topkIndices参数shape的约束:[T1, N2, selectedBlockCount]。

调用示例

调用示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]