NsaSelectedAttentionGrad
产品支持情况
功能说明
接口功能:根据topkIndices对key和value选取大小为selectedBlockSize的数据重排,接着进行训练场景下计算注意力的反向输出。
计算公式:
根据传入的topkIndice对keyIn和value选取数量为selectedBlockCount个大小为selectedBlockSize的数据重排,公式如下:
接着,进行注意力机制的反向计算,计算公式为:
函数原型
每个算子分为,必须先调用“aclnnNsaSelectedAttentionGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnNsaSelectedAttentionGrad”接口执行计算。
[object Object]
[object Object]
aclnnNsaSelectedAttentionGradGetWorkspaceSize
aclnnNsaSelectedAttentionGrad
约束说明
- 确定性计算:
- aclnnNsaSelectedAttentionGrad默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- 输入query、key、value、attentionOut、attentionOutGrad的B(batchsize)必须相等。
- 输入key、value的N(numHead)必须一致。
- 输入query、attentionOut、attentionOutGrad的N(numHead)必须一致。
- 输入value、attentionOut、attentionOutGrad的D(HeadDim)必须一致。
- 输入query、key、value、attentionOut、attentionOutGrad的inputLayout必须一致。
- 关于数据shape的约束,以inputLayout的TND举例。其中:
- T1:取值范围为1~2M。T1表示query所有batch下S的和。
- T2:取值范围为1~2M。T2表示key、value所有batch下S的和。
- B:取值范围为1~2M。
- N1:取值范围为1~128。表示query的headNum。N1必须为N2的整数倍。
- N2:取值范围为1~128。表示key、value的headNum。
- G:取值范围为1~32。G = N1 / N2
- S:取值范围为1~128K。对于key、value的S 必须大于等于selectedBlockSize * selectedBlockCount, 且必须为selectedBlockSize的整数倍。
- D:取值范围为192或128,支持K和V的D(HeadDim)不相等。
- selectedBlockSize支持<=128且满足16的整数倍。
- selectBlockCount:支持[1~128]。 总计选择的大小
[object Object]< 128*64(8K) - Layout为TND时,每个Batch的S2都要大于总计选择的大小
[object Object]
- 关于softmaxMax与softmaxSum参数shape的约束:[T1, N1, 8]。
- 关于topkIndices参数shape的约束:[T1, N2, selectedBlockCount]。
调用示例
[object Object]