昇腾社区首页
中文
注册

aclnnMlaPrologV2WeightNz

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]
[object Object]Atlas 200I/500 A2 推理产品[object Object] ×
[object Object]Atlas 推理系列产品[object Object] ×
[object Object]Atlas 训练系列产品[object Object] ×

功能说明

  • 算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为五路;

    • 首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}WUKW^{UK}经过两次上采样后得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R
    • 第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后传入Cache中得到kCk^C
    • 第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R
    • 第五路是输出qNq^N经过DynamicQuant后得到的量化参数。
    • 权重参数WeightDq、WeightUqQr和WeightDkvKr需要以NZ格式传入
  • 计算公式

    RmsNorm公式

    RmsNorm(x)=γxiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)} RMS(x)=1Ni=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}

    Query在计算公式,包括下采样,RmsNorm和两次上采样

    cQ=RmsNorm(xWDQ)c^Q = RmsNorm(x \cdot W^{DQ}) qC=cQWUQq^C = c^Q \cdot W^{UQ} qN=qCWUKq^N = q^C \cdot W^{UK}

    对Query的进行ROPE旋转位置编码

    qR=ROPE(cQWQR)q^R = ROPE(c^Q \cdot W^{QR})

    Key计算公式,包括下采样和RmsNorm,将计算结果存入cache

    cKV=RmsNorm(xWDKV)c^{KV} = RmsNorm(x \cdot W^{DKV}) kC=Cache(cKV)k^C = Cache(c^{KV})

    对Key进行ROPE旋转位置编码,并将结果存入cache

    kR=Cache(ROPE(xWKR))k^R = Cache(ROPE(x \cdot W^{KR}))

    Dequant Scale Query Nope 计算公式:

    dequantScaleQNope=RowMax(abs(qN))/127dequantScaleQNope = {RowMax(abs(q^{N})) / 127} qN=round(qN/dequantScaleQNope)q^{N} = {round(q^{N} / dequantScaleQNope)}

函数原型

每个算子分为undefined,必须先调用“aclnnMlaPrologV2WeightNzGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPrologV2WeightNz”接口执行计算。

  • aclnnStatus aclnnMlaPrologV2WeightNzGetWorkspaceSize(const aclTensor *tokenX, const aclTensor *weightDq, const aclTensor *weightUqQr, const aclTensor *weightUk, const aclTensor *weightDkvKr, const aclTensor *rmsnormGammaCq, const aclTensor *rmsnormGammaCkv, const aclTensor *ropeSin, const aclTensor *ropeCos, const aclTensor *cacheIndex, aclTensor *kvCacheRef, aclTensor *krCacheRef, const aclTensor *dequantScaleXOptional, const aclTensor *dequantScaleWDqOptional, const aclTensor *dequantScaleWUqQrOptional, const aclTensor *dequantScaleWDkvKrOptional, const aclTensor *quantScaleCkvOptional, const aclTensor *quantScaleCkrOptional, const aclTensor *smoothScalesCqOptional, double rmsnormEpsilonCq, double rmsnormEpsilonCkv, char *cacheModeOptional, const aclTensor *queryOut, const aclTensor *queryRopeOut, const aclTensor *dequantScaleQNopeOutOptional, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnMlaPrologV2WeightNz(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

aclnnMlaPrologV2WeightNzGetWorkspaceSize

  • 参数说明:

    • tokenX(aclTensor*,计算输入):表示输入的tensor,用于计算Query和Key的x,Device侧的aclTensor。shape支持2维和3维,格式为(T,He)和(B,S,He),dtype支持BFLOAT16和INT8,undefined支持ND格式。

    • weightDq(aclTensor*,计算输入):表示用于计算Query的下采样权重矩阵,对应公式中的WDQW^{DQ},Device侧的aclTensor。其shape支持2维,格式为(He,Hcq),dtype支持BFLOAT16和INT8,undefined支持FRACTAL_NZ格式。

    • weightUqQr(aclTensor*,计算输入):表示用于计算Query的上采样权重矩阵和Query的位置编码权重矩阵,对应公式中的WUQW^{UQ}WQRW^{QR},Device侧的aclTensor。其shape支持2维,格式为(Hcq,N*(D+Dr)),dtype支持BFLOAT16和INT8,undefined支持FRACTAL_NZ格式。

    • weightUk(aclTensor*,计算输入):表示用于计算Key的上采样权重,对应公式中的WUKW^{UK},Device侧的aclTensor。其shape支持3维,格式为(N,D,Hckv),dtype支持BFLOAT16,undefined支持ND格式

    • weightDkvKr(aclTensor*,计算输入):表示用于计算Key的下采样权重矩阵和Key的位置编码权重矩阵,对应公式中的WDKVW^{DKV}WKRW^{KR},Device侧的aclTensor。其shape支持2维,格式为(He,Hckv+Dr),dtype支持BFLOAT16和INT8,undefined支持FRACTAL_NZ格式。

    • rmsnormGammaCq(aclTensor*,计算输入):表示计算cQc^Q的RmsNorm公式中的γ\gamma参数,Device侧的aclTensor。其shape支持1维,格式为(Hcq),dtype支持BFLOAT16,undefined支持ND格式。

    • rmsnormGammaCkv(aclTensor*,计算输入):表示计算cKVc^{KV}的RmsNorm公式中的γ\gamma参数,Device侧的aclTensor。其shape支持1维,格式为(Hckv),dtype支持BFLOAT16,undefined支持ND格式。

    • ropeSin(aclTensor*,计算输入):表示用于计算旋转位置编码的正弦参数矩阵,Device侧的aclTensor。其shape支持2维和3维,格式为(T,Dr)和(B,S,Dr),dtype支持BFLOAT16,undefined支持ND格式。

    • ropeCos(aclTensor*,计算输入):表示用于计算旋转位置编码的余弦参数矩阵,Device侧的aclTensor。其shape支持2维和3维,格式为(T,Dr)和(B,S,Dr),dtype支持BFLOAT16,undefined支持ND格式。

    • cacheIndex(aclTensor*,计算输入):表示用于存储kvCache和krCache的索引,Device侧的aclTensor。其shape支持1维和2维,格式为(T)和(B,S),dtype支持INT64,undefined支持ND格式。

      • cacheIndex的取值范围为[0,BlockNum*BlockSize),当前不会对cacheIndex传入值的合法性进行校验,需用户自行保证。
    • kvCacheRef(aclTensor*,计算输入):表示用于cache索引的aclTensor。其shape支持4维,格式为(BlockNum,BlockSize,Nkv,Hckv),dtype支持BFLOAT16和INT8,undefined支持ND格式。计算结果原地更新,更新后结果对应公式中的kCk^C

    • krCacheRef(aclTensor*,计算输入):表示用于key位置编码的cache,Device侧的aclTensor。其shape支持4维,格式为(BlockNum,BlockSize,Nkv,Dr),dtype支持BFLOAT16和INT8,undefined支持ND格式。计算结果原地更新,更新后结果对应公式中的kRk^R

    • dequantScaleXOptional(aclTensor*,计算输入):表示用于当输入tokenX为INT8类型时,下采样后进行反量化操作时的参数,tokenX量化方式为pertoken。其shape支持2维,格式为(T, 1)和(B*S, 1),dtype支持FLOAT,undefined支持ND格式。

    • dequantScaleWDqOptional(aclTensor*,计算输入):表示用于当tokenX输入为INT8类型时,下采样后进行反量化操作时的参数,tokenX量化方式为perchannel。其shape支持2维,格式为(1, N*(D+Dr)),dtype支持FLOAT,undefined支持ND格式。

    • dequantScaleWUqQrOptional(aclTensor*,计算输入):表示用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数为per-channel,Device侧的aclTensor。其shape支持2维,格式为(1,N*(D+Dr)),dtype支持FLOAT,undefined支持ND格式。

    • dequantScaleWDkvKrOptional(aclTensor*,计算输入):表示用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数为per-channel,Device侧的aclTensor。其shape支持2维,格式为(1,Hckv+Dr),dtype支持FLOAT,undefined支持ND格式。

    • quantScaleCkvOptional(aclTensor*,计算输入):用于对输出到KVCache中的数据做量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Hckv),dtype支持FLOAT,undefined支持ND格式。

    • quantScaleCkrOptional(aclTensor*,计算输入):用于对输出到KRCache中的数据做量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Dr),dtype支持FLOAT,undefined支持ND格式。

    • smoothScalesCqOptional(aclTensor*,计算输入):用于对RmsNormCq输出做动态量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Hcq),dtype支持FLOAT,undefined支持ND格式。

    • rmsnormEpsilonCq(double,计算输入):表示计算cQc^Q的RmsNorm公式中的ϵ\epsilon参数,用户不特意指定时建议传入1e-05。

    • rmsnormEpsilonCkv(double,计算输入):表示计算cKVc^{KV}的RmsNorm公式中的ϵ\epsilon参数,用户不特意指定时建议传入1e-05。

    • cacheModeOptional(char*,计算输入):表示kvCache的模式,支持"PA_BSND","PA_NZ",用户不特意指定时建议传入"PA_BSND"。

    • queryOut(aclTensor*,计算输出):表示Query的输出tensor,对应公式中的qNq^N,Device侧的aclTensor。shape支持3维和4维,格式为(T,N,Hckv)和(B,S,N,Hckv),dtype支持BFLOAT16,undefined支持ND格式。

    • queryRopeOut(aclTensor*,计算输出):表示Query位置编码的输出tensor,对应公式中的qRq^R,Device侧的aclTensor。shape支持3维和4维,格式为(T,N,Dr)和(B,S,N,Dr),dtype支持BFLOAT16,undefined支持ND格式。

    • dequantScaleQNopeOutOptional(aclTensor*,计算输出):表示Query的输出tensor的量化参数。shape支持3维,全量化kv_cache量化场景下,其shape为(B*S, N, 1)和(T, N, 1),dtype支持FLOAT,undefined支持ND格式;其他场景下,不修改此出参,所以不做任何约束且允许输入空指针,返回与输入一致。

    • workspaceSize(uint64_t*,计算输出):返回需要在Device侧申请的workspace大小。

    • executor(aclOpExecutor**,计算输出):返回op执行器,包含了算子计算流程。

  • 返回值:

    aclnnStatus:返回状态码,具体参见undefined

    [object Object]

aclnnMlaPrologV2WeightNz

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMlaPrologV2WeightNzGetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的Stream。
  • 返回值:

    aclnnStatus:返回状态码,具体参见undefined

约束说明

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。

  • shape格式字段含义

    • B:Batch 表示输入样本批量大小,取值范围:0~65536
    • S:Seq-Length 表示输入样本序列长度,取值范围:0~16
    • He:Head-Size 表示隐藏层的大小,取值固定为:7168
    • Hcq:q低秩矩阵维度,取值固定为:1536
    • N:Head-Num 表示多头数,取值范围:1、2、4、8、16、32、64、128
    • Hckv:kv低秩矩阵维度,取值固定为:512
    • D:qk不含位置编码维度,取值固定为:128
    • Dr:qk位置编码维度,取值固定为:64
    • Nkv:kv的head数,取值固定为:1
    • BlockNum:PagedAttention场景下的块数,取值为计算B*Skv/BlockSize的值后再向上取整,其中Skv表示kv的序列长度,该值允许取0
    • BlockSize:PagedAttention场景下的块大小,取值范围:16、128
    • T:BS合轴后的大小,取值范围:0~1048576
  • shape约束

    • 若tokenX的维度采用BS合轴,即(T, He)
      • ropeSin和ropeCos的shape为(T, Dr)
      • cacheIndex的shape为(T,)
      • dequantScaleXOptional的shape为(T, 1)
      • queryOut的shape为(T, N, Hckv)
      • queryRopeOut的shape为(T, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为(1)
    • 若tokenX的维度不采用BS合轴,即(B, S, He)
      • ropeSin和ropeCos的shape为(B, S, Dr)
      • cacheIndex的shape为(B, S)
      • dequantScaleXOptional的shape为(B*S, 1)
      • queryOut的shape为(B, S, N, Hckv)
      • queryRopeOut的shape为(B, S, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为(1)
    • B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
      • 如果B、S、T取值为0,则queryOut、queryRopeOut输出空Tensor,kvCacheRef、krCacheRef不做更新。
      • 如果Skv取值为0,则queryOut、queryRopeOut、dequantScaleQNopeOutOptional正常计算,kvCacheRef、krCacheRef不做更新,即输出空Tensor。
  • 当前aclnnMlaPrologV2WeightNz接口支持以下场景:

    [object Object][object Object]

    在不同量化场景下,参数的dtype和shape组合需要满足如下条件:

    [object Object]

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考undefined

[object Object]