昇腾社区首页
中文
注册
开发者
下载

aclnnMlaPrologV3WeightNz

产品支持情况

[object Object]undefined

功能说明

  • 功能更新:(相对于aclnnMlaPrologV2weightNz的差异)

    • 新增Query与Key的尺度矫正因子,分别对应qcQrScale(αq\alpha_q)与kcScale(αkv\alpha_{kv})。
    • 新增可选输入参数(例如actualSeqLenOptional、kNopeClipAlphaOptional、queryNormFlag、weightQuantMode、kvCacheQuantMode、queryQuantMode、ckvkrRepoMode、quantScaleRepoMode、tileSize、queryNormOptional和dequantScaleQNormOptional等),将cache_mode由必选改为可选。
    • 调整cacheIndex参数的名称与位置,对应当前的cacheIndexOptional。
  • 接口功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为五路:

    • 首先对输入xx乘以WDQW^{DQ}进行下采样和RmsNorm后分为两路,第一路乘以WUQW^{UQ}WUKW^{UK}经过两次上采样后,再乘以Query尺度矫正因子αq\alpha_q得到qNq^N;第二路乘以WQRW^{QR}后经过旋转位置编码(ROPE)得到qRq^R
    • 第三路是输入xx乘以WDKVW^{DKV}进行下采样和RmsNorm后,乘以Key尺度矫正因子αkv\alpha_{kv}传入Cache中得到kCk^C
    • 第四路是输入xx乘以WKRW^{KR}后经过旋转位置编码后传入另一个Cache中得到kRk^R
    • 第五路是输出qNq^N经过DynamicQuant后得到的量化参数。
    • 权重参数WeightDq、WeightUqQr和WeightDkvKr需要以NZ格式传入
  • 计算公式

    RmsNorm公式

    RmsNorm(x)=γxiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)} RMS(x)=1Ni=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon}

    Query的计算公式,包括下采样、RmsNorm和两次上采样

    cQ=αqRmsNorm(xWDQ)c^Q = \alpha_q\cdot\mathrm{RmsNorm}(x \cdot W^{DQ}) qC=cQWUQq^C = c^Q \cdot W^{UQ} qN=qCWUKq^N = q^C \cdot W^{UK}

    对Query进行ROPE旋转位置编码

    qR=ROPE(cQWQR)q^R = \mathrm{ROPE}(c^Q \cdot W^{QR})

    Key的计算公式,包括下采样和RmsNorm,将计算结果存入cache

    cKV=αkvRmsNorm(xWDKV)c^{KV} = \alpha_{kv}\cdot\mathrm{RmsNorm}(x \cdot W^{DKV}) kC=Cache(cKV)k^C = \mathrm{Cache}(c^{KV})

    对Key进行ROPE旋转位置编码,并将结果存入cache

    kR=Cache(ROPE(xWKR))k^R = \mathrm{Cache}(\mathrm{ROPE}(x \cdot W^{KR}))

    Dequant Scale Query Nope 计算公式

    dequantScaleQNope=RowMax(abs(qN))/127\mathrm{dequantScaleQNope} = {\mathrm{RowMax}(\mathrm{abs}(q^{N})) / 127} qN=round(qN/dequantScaleQNope)q^{N} = {\mathrm{round}(q^{N} / \mathrm{dequantScaleQNope})}

函数原型

每个算子分为,必须先调用“aclnnMlaPrologV3WeightNzGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPrologV3WeightNz”接口执行计算。

[object Object]
[object Object]

aclnnMlaPrologV3WeightNzGetWorkspaceSize

  • 参数说明

    [object Object]undefined
  • 返回值

    aclnnStatus:返回状态码,具体参见。[object Object] 第一段接口完成入参校验,出现以下场景时报错:

    [object Object]undefined

aclnnMlaPrologV3WeightNz

  • 参数说明

    [object Object]undefined
  • 返回值 aclnnStatus:返回状态码,具体参见

约束说明

  • 确定性计算:

    • aclnnMlaPrologV3WeightNz默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。
  • shape 格式字段含义说明

    [object Object]undefined
  • shape约束

    • 若tokenX的维度采用BS合轴,即(T, He)
      • ropeSin和ropeCos的shape为(T, Dr)
      • cacheIndex的shape为(T)
      • dequantScaleXOptional的shape为(T, 1)
      • queryOut的shape为(T, N, Hckv)
      • queryRopeOut的shape为(T, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(T, N, 1),其他场景下为(1)
    • 若tokenX的维度不采用BS合轴,即(B, S, He)
      • ropeSin和ropeCos的shape为(B, S, Dr)
      • cacheIndex的shape为(B, S)
      • dequantScaleXOptional的shape为(B*S, 1)
      • queryOut的shape为(B, S, N, Hckv)
      • queryRopeOut的shape为(B, S, N, Dr)
      • 全量化场景下,dequantScaleQNopeOutOptional的shape为(B*S, N, 1),其他场景下为(1)
    • B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
      • 如果B、S、T取值为0,则queryOut、queryRopeOut输出空Tensor,kvCacheRef、krCacheRef不做更新。
      • 如果Skv取值为0,则queryOut、queryRopeOut、dequantScaleQNopeOutOptional正常计算,kvCacheRef、krCacheRef不做更新,即输出空Tensor。
  • 特殊约束

    • per-tile量化模式下,ckvkrRepoMode和quantScaleRepoMode必须同时为1。
    • per-tile量化模式下,CacheMode只支持PA_BSND, BSND和TND。
    • 当ckvkrRepoMode值为1时,krCache必须为空Tensor(即shape的乘积为0)。
  • aclnnMlaPrologV3WeightNz接口支持场景:

    [object Object]
  • 在不同量化场景下,参数的dtype组合需要满足如下条件:

    [object Object]

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]