接口功能:aclnnBlockSparseAttention稀疏注意力反向计算,支持灵活的块级稀疏模式,通过BlockSparseMask指定每个Q块选择的KV块,实现高效的稀疏注意力计算。
计算公式: 稀疏块大小:,BlockSparseMask指定稀疏模式。
已知正向计算公式为:
为方便表达,以变量和表示计算公式:
则反向计算公式为:
BlockSparseAttentionGrad输入dout、 query、key、value, attentionOut的数据排布格式支持从多种维度排布解读,可通过qInputLayout和kvInputLayout传入。为了方便理解后续支持的具体排布格式(如BNSD、TND等),此处先对排布格式中各缩写字母所代表的维度含义进行统一说明:
- B:表示输入样本批量大小(Batch)
- T:B和S合轴紧密排列的长度(Total tokens)
- S:表示输入样本序列长度(Seq-Length)
- H:表示隐藏层的大小(Head-Size)
- N:表示多头数(Head-Num)
- D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)
当前支持的布局:
- qInputLayout: "TND" "BNSD"
- kvInputLayout: "TND" "BNSD"
每个算子分为,必须先调用"aclnnBlockSparseAttentionGradGetWorkspaceSize"接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用"aclnnBlockSparseAttentionGrad"接口执行计算。
[object Object]
[object Object]
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- actualSeqLengthsOptional在qInputLayout为“TND”时必选;actualSeqLengthsKvOptional在kvInputLayout为“TND”时必选。
- 根据算子支持的输入Layout,query张量Shape中对应的head维度大小记为N1,key和value张量Shape中对应的head维度大小记为N2。必须满足N1 >= N2 且N1 % N2 == 0。(例如:在BNSD布局下,N1 对应query的第2 维,N2 对应key/value的第2 维)
- headdim=128。
[object Object]