昇腾社区首页
中文
注册
开发者
下载

aclnnNsaCompressAttention

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:NSA中compress attention以及select topk索引计算。论文:

  • 计算公式:压缩block大小:ll,select block大小:ll',压缩stride大小:dd

Pcmp=Softmax(querykeyT)P_{cmp} = Softmax(query*key^T) \\ attentionOut=Softmax(atten_mask(scalequerykeyT,atten_mask)))valueattentionOut = Softmax(atten\_mask(scale*query*key^T, atten\_mask)))*value Pslc[j]=m=0l/d1n=0l/d1Pcmp[l/djmn],P_{slc}[j] = \sum_{m=0}^{l'/d-1}\sum_{n=0}^{l/d-1}P_{cmp} [l'/d*j-m-n], Pslc=h=1HPslchP_{slc'} = \sum_{h=1}^{H}P_{slc}^{h} Pslc=topk_mask(Pslc)P_{slc'} = topk\_mask(P_{slc'}) topkIndices=topk(Pslc)topkIndices = topk(P_{slc'})

NsaCompressAttention输入query、key、value的数据排布格式支持从多种维度排布解读,可通过inputLayout传入,当前仅支持TND。

  • B:表示输入样本批量大小(Batch)
  • T:B和S合轴紧密排列的长度
  • S:表示输入样本序列长度(Seq-Length)
  • H:表示隐藏层的大小(Head-Size)
  • N:表示多头数(Head-Num)
  • D:表示隐藏层最小的单元尺寸,需满足D=H/N(Head-Dim)

函数原型

每个算子分为,必须先调用“aclnnNsaCompressAttentionGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnNsaCompressAttention”接口执行计算。

[object Object]
[object Object]

aclnnNsaCompressAttentionGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnNsaCompressAttention

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnNsaCompressAttention默认确定性实现。
  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • compressBlockSize、compressStride、selectBlockSize必须是16的整数倍,并且满足:compressBlockSize>=compressStride && selectBlockSize>=compressBlockSize && selectBlockSize%compressStride==0
  • actualSeqQLenOptional, actualCmpSeqKvLenOptional, actualSelSeqKvLenOptional需要是前缀和模式;且TND格式下必须传入。
  • 由于UB限制,CmpSkv需要满足以下约束:CmpSkv <= 14000
  • SelSkv = CeilDiv(CmpSkv, selectBlockSize // compressStride)
  • 输入query、key、value的约束如下:
    • 数据类型必须一致。
    • batchSize必须相等。
    • headDim必须满足:qD == kD && kD >= vD
    • inputLayout必须一致。
  • 输入query的headNum为N1,输入key和value的headNum为N2,则N1 >= N2 && N1 % N2 == 0
  • 设G = N1 / N2,G需要满足以下约束:G < 128 && 128 % G == 0
  • attenMask和topkMask的使用需符合论文描述。

调用示例

调用示例代码如下(以[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]为例),仅供参考,具体编译和执行过程请参考

[object Object]