昇腾社区首页
中文
注册
开发者
下载

aclnnNsaCompress

产品支持情况

[object Object]undefined

功能说明

  • 接口功能:训练场景下,使用NSA Compress算法减轻long-context的注意力计算,实现在KV序列维度进行压缩。

  • 计算公式:

    Nsa Compress正向计算公式如下:

K~tcmp=fKcmp(k:t)={φ(kid+1:id+l)0itld}\tilde{K}_t^{\text{cmp}} = f_K^{\text{cmp}}(k_{:t}) = \left\{ \varphi(k_{id+1:id+l}) \bigg| 0 \leq i \leq \left\lfloor \frac{t-l}{d} \right\rfloor \right\}

函数原型

每个算子分为,必须先调用“aclnnNsaCompressGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnNsaCompress”接口执行计算。

[object Object]
[object Object]

aclnnNsaCompressGetWorkspaceSize

  • 参数说明:

    [object Object]
    • input数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N; 其中T是B和S合轴紧密排列的数据(每个batch的actSeqLen)、B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnNsaCompress

  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnNsaCompress默认确定性实现。
  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • input和weight需要满足broadcast关系,input.shape[1]=weight.shape[1],不支持input、weight为空输入。
  • actSeqLenType目前仅支持取值0,即actSeqLenOptional需要是前缀和模式。
  • layoutOptional目前仅支持TND,此时input.shape[0]必须等于actSeqLenOptional[-1]。
  • input.shape[1]=weight.shape[1],需要小于等于128。
  • input.shape[2]必须是16的倍数,上限256。
  • weight.shape[0]=compressBlockSize,必须是16的倍数,上限128。
  • compressStride必须是16的整数倍,并且compressBlockSize>=compressStride。

调用示例

调用示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]