开发者
资源
[object Object][object Object][object Object]undefined
[object Object]
  • ​接口功能​:aclnnBlockSparseAttention稀疏注意力反向计算,支持灵活的块级稀疏模式,通过BlockSparseMask指定每个Q块选择的KV块,实现高效的稀疏注意力计算。

  • ​计算公式​: 稀疏块大小:blockShapeX×blockShapeYblockShapeX×blockShapeY,BlockSparseMask指定稀疏模式。

    已知正向计算公式为:

    attentionOut=Softmax(Mask(scalequerykeysparseT,atten_mask))valuesparseattentionOut=Softmax(Mask(scale⋅query⋅key_{sparse}^{T}, atten\_mask))⋅value_{sparse}

    为方便表达,以变量SSPP表示计算公式:

    S=Mask(scalequerykeysparseT,atten_mask)S = Mask(scale⋅query⋅key_{sparse}^{T},atten\_mask) P=SoftMax(S)P = SoftMax(S) V=valuesparseV = value_{sparse} Out=PVOut = PV

    则反向计算公式为:

    softmax_grad=softmaxGrad(dOut,attentionOut)softmax\_grad = softmaxGrad(dOut, attentionOut) dP=dOutVTdP=dOut * V^T dS=P(dPsoftmax_grad)dS = P * (dP-softmax\_grad) dV=PTdOutdV=P^T * dOut dQ=(dSK)scaledQ=(dS*K)*scale dK=(dSTQ)scaledK=(dS^T*Q)*scale

BlockSparseAttentionGrad输入dout、 query、key、value, attentionOut的数据排布格式支持从多种维度排布解读,可通过qInputLayout和kvInputLayout传入。为了方便理解后续支持的具体排布格式(如 BNSD、TND 等),此处先对排布格式中各缩写字母所代表的维度含义进行统一说明:

  • B:表示输入样本批量大小(Batch)
  • T:B和S合轴紧密排列的长度(Total tokens)
  • S:表示输入样本序列长度(Seq-Length)
  • H:表示隐藏层的大小(Head-Size)
  • N:表示多头数(Head-Num)
  • D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)

当前支持的布局:

  • qInputLayout: "TND" "BNSD"
  • kvInputLayout: "TND" "BNSD"
[object Object]

每个算子分为,必须先调用"aclnnBlockSparseAttentionGradGetWorkspaceSize"接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用"aclnnBlockSparseAttentionGrad"接口执行计算。

[object Object]
[object Object]
[object Object]
  • 参数说明:
[object Object]
  • 返回值

    aclnnStatus:返回状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]
[object Object]
  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus:返回状态码,具体参见

[object Object]
  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • actualSeqLengthsOptional在qInputLayout为“TND”时必选;actualSeqLengthsKvOptional在kvInputLayout为“TND”时必选。
  • 根据算子支持的输入 Layout,query 张量 Shape 中对应的 head 维度大小记为 N1,key 和 value 张量 Shape 中对应的 head 维度大小记为 N2。必须满足 N1 >= N2 且 N1 % N2 == 0。(例如:在 BNSD 布局下,N1 对应 query 的第 2 维,N2 对应 key/value 的第 2 维)
  • headdim=128。
  • 当前只支持 BNSD 和 MHA(N1==N2)。
[object Object]

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]