接口功能:aclnnBlockSparseAttention稀疏注意力反向计算,支持灵活的块级稀疏模式,通过BlockSparseMask指定每个Q块选择的KV块,实现高效的稀疏注意力计算。
计算公式: 稀疏块大小:,BlockSparseMask指定稀疏模式。
已知正向计算公式为:
为方便表达,以变量和表示计算公式:
则反向计算公式为:
BlockSparseAttentionGrad输入dout、 query、key、value, attentionOut的数据排布格式支持从多种维度排布解读,可通过qInputLayout和kvInputLayout传入。为了方便理解后续支持的具体排布格式(如 BNSD、TND 等),此处先对排布格式中各缩写字母所代表的维度含义进行统一说明:
- B:表示输入样本批量大小(Batch)
- T:B和S合轴紧密排列的长度(Total tokens)
- S:表示输入样本序列长度(Seq-Length)
- H:表示隐藏层的大小(Head-Size)
- N:表示多头数(Head-Num)
- D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)
当前支持的布局:
- qInputLayout: "TND" "BNSD"
- kvInputLayout: "TND" "BNSD"
每个算子分为,必须先调用"aclnnBlockSparseAttentionGradGetWorkspaceSize"接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用"aclnnBlockSparseAttentionGrad"接口执行计算。
[object Object]
[object Object]
- 参数说明:
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- actualSeqLengthsOptional在qInputLayout为“TND”时必选;actualSeqLengthsKvOptional在kvInputLayout为“TND”时必选。
- 根据算子支持的输入 Layout,query 张量 Shape 中对应的 head 维度大小记为 N1,key 和 value 张量 Shape 中对应的 head 维度大小记为 N2。必须满足 N1 >= N2 且 N1 % N2 == 0。(例如:在 BNSD 布局下,N1 对应 query 的第 2 维,N2 对应 key/value 的第 2 维)
- headdim=128。
- 当前只支持 BNSD 和 MHA(N1==N2)。
[object Object]