开发者
资源
[object Object][object Object][object Object]undefined
[object Object]
  • 接口功能:BlockSparseAttention稀疏注意力计算,支持灵活的块级稀疏模式,通过BlockSparseMask指定每个Q块选择的KV块,实现高效的稀疏注意力计算。

  • 计算公式:稀疏块大小:blockShapeX×blockShapeYblockShapeX \times blockShapeY,selectIdx指定稀疏模式

    attentionOut=Softmax(scalequerykeysparseT+atten_mask)valuesparseattentionOut = Softmax(scale \cdot query \cdot key_{sparse}^T + atten\_mask) \cdot value_{sparse}

    BlockSparseAttention输入query、key、value的数据排布格式支持从多种维度排布解读,可通过qInputLayout和kvInputLayout传入。

    • B:表示输入样本批量大小(Batch)
    • T:B和S合轴紧密排列的长度(Total tokens)
    • S:表示输入样本序列长度(Seq-Length)
    • H:表示隐藏层的大小(Head-Size)
    • N:表示多头数(Head-Num)
    • D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)

    当前支持的布局:

    • qInputLayout: "TND" "BNSD"
    • kvInputLayout: "TND" "BNSD"
[object Object]

每个算子分为两段式接口,必须先调用"aclnnBlockSparseAttentionGetWorkspaceSize"接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用"aclnnBlockSparseAttention"接口执行计算。

[object Object]
[object Object]
[object Object]
  • 参数说明

    [object Object]
  • 返回值

    aclnnStatus:返回状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]
[object Object]
  • 参数说明

    [object Object]
  • 返回值

返回aclnnStatus状态码,具体参见

[object Object]
  • 确定性计算:
    • aclnnBlockSparseAttention默认确定性实现。
  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • qInputLayout当前仅支持"TND"和"BNSD"。
  • kvInputLayout当前仅支持"TND"和"BNSD"。
  • 输入query、key、value的数据类型必须一致,支持FLOAT16和BFLOAT16。
  • blockShapeOptional如果传入,则必须包含至少两个元素[blockShapeX, blockShapeY],且值必须大于0,blockShapeY必须为128的倍数。
  • blockSparseMaskOptional当前必须传入,且shape必须为[batch, headNum, ceilDiv(maxQS, blockShapeX), ceilDiv(maxKVS, blockShapeY)]。
  • attentionMaskOptional当前只支持传入nullptr。
  • actualSeqLengthsOptional在qInputLayout为“TND”时必选;actualSeqLengthsKvOptional在kvInputLayout为“TND”时必选。
  • blockTableOptional当前只支持传入nullptr,表示不开启PagedAttention特性。
  • innerPrecise必须为0(float32 softmax)或1(fp16 softmax),query输入为BFLOAT16时,只能配置为0。
  • qSeqlen和kvSeqlen不需要被blockShape整除,支持非对齐场景,实际分块数通过向上取整计算。
  • 输入query的headNum为N1,输入key和value的headNum为N2,则N1 >= N2 && N1 % N2 == 0。
  • maskType当前只支持输入0,表示不加mask。
[object Object]

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。

[object Object]