昇腾社区首页
中文
注册
开发者
下载

aclnnMlaPreprocess

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]
[object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object]
[object Object]Atlas 200I/500 A2 推理产品[object Object] ×
[object Object]Atlas 推理系列产品[object Object] ×
[object Object]Atlas 训练系列产品[object Object] ×

功能说明

  • 算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程如下:

    • 首先对输入xx RmsNormQuant后乘以WDQKVW^{DQKV}进行下采样后分为通路1和通路2。
    • 通路1做RmsNormQuant后乘以WUQW^{UQ}后再分为通路3和通路4。
    • 通路3后乘以WukW^{uk}后输出qNq^N
    • 通路4后经过旋转位置编码后输出qRq^R
    • 通路2拆分为通路5和通路6。
    • 通路5经过RmsNorm后传入Cache中得到kNk^N
    • 通路6经过旋转位置编码后传入另一个Cache中得到kRk^R
  • 计算公式

    RmsNormQuant公式

    RMS(x)=1Ni=1Nxi2+ϵ\text{RMS}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^{N} x_i^2 + \epsilon} RmsNorm(x)=γxiRMS(x)\text{RmsNorm}(x) = \gamma \cdot \frac{x_i}{\text{RMS}(x)} RmsNormQuant(x)=(RmsNorm(x)+bias)deqScaleRmsNormQuant(x) = ({RmsNorm}(x) + bias) * deqScale

    Query计算公式,包括W^{DQKV}矩阵乘、W^{UK}矩阵乘、RmsNormQuant和RoPE旋转位置编码处理

    qN=RmsNormQuant(x)WDQKVWUKq^N = RmsNormQuant(x) \cdot W^{DQKV} \cdot W^{UK} qR=ROPE(xQ)q^R = ROPE(x^Q)

    Key计算公式,包括RmsNorm和RoPE,将计算结果存入cache

    kN=Cache(RmsNorm(RmsNormQuant(x)))k^N = Cache({RmsNorm}(RmsNormQuant(x))) kR=Cache(ROPE(RmsNormQuant(x)))k^R = Cache(ROPE(RmsNormQuant(x)))

函数原型

每个算子分为,必须先调用“aclnnMlaPreprocessGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPreprocess”接口执行计算。

  • aclnnStatus aclnnMlaPreprocessGetWorkspaceSize(const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0, const aclTensor *quantOffset0,const aclTensor *wdqkv, const aclTensor *deScale0, const aclTensor *bias0, const aclTensor *gamma1, const aclTensor *beta1, const aclTensor *quantScale1, const aclTensor *quantOffset1, const aclTensor *wuq, const aclTensor *deScale1, const aclTensor *bias1, const aclTensor *gamma2, const aclTensor *cos, const aclTensor *sin, const aclTensor *wuk, const aclTensor *kvCache, const aclTensor *kvCacheRope, const aclTensor *slotmapping, const aclTensor *ctkvScale, const aclTensor *qNopeScale, int64_t wdqDim, int64_t qRopeDim, int64_t kRopeDim, float epsilon, int64_t qRotaryCoeff, int64_t kRotaryCoeff, bool transposeWdq, bool transposeWuq, bool transposeWuk, int64_t cacheMode, int64_t quantMode, bool doRmsNorm, int64_t wdkvSplitCount, aclTensor *qOut, aclTensor *kvCacheOut, aclTensor *qRopeOut, aclTensor *krCacheOut, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnMlaPreprocess(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

aclnnMlaPreprocessGetWorkspaceSize

  • 参数说明:

    • input(aclTensor*,计算输入):Device侧的aclTensor,用于计算Query和Key的x,shape为[tokenNum,hiddenSize],dtype支持FLOAT16和BFLOAT16,支持ND格式。

    • gamma0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm计算中的γ参数,shape为[hiddenSize],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • beta0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm计算中的β参数,shape为[hiddenSize],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • quantScale0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm公式中量化缩放的参数,shape为[1],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • quantOffset0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm公式中的量化偏移参数,shape为[1],dtype支持INT8,支持NZ格式。

    • wdqkv(aclTensor*,计算输入):Device侧的aclTensor,与输入首次做矩阵乘的降维矩阵,shape为[2112,hiddenSize],dtype支持INT8和BFLOAT16,支持ND格式。

    • deScale0(aclTensor*,计算输入):Device侧的aclTensor,输入首次做矩阵乘的降维矩阵中的系数,shape为[2112],dtype支持INT64和FLOAT,与input的dtype对应,input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。支持ND格式。

    • bias0(aclTensor*,计算输入):Device侧的aclTensor,输入首次做矩阵乘的降维矩阵中的系数,shape为[2112],dtype支持INT32,支持ND格式。支持传入空tensor,quantMode为1、3时不传入。

    • gamma1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm计算中的γ参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • beta1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm计算中的β参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • quantScale1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm公式中量化缩放的参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。仅在quantMode为0时传入。

    • quantOffset1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm公式中的量化偏移参数,shape为[1]。dtype支持INT8,支持ND格式。仅在quantMode为0时传入。

    • wuq(aclTensor*,计算输入):Device侧的aclTensor,权重矩阵,shape为[headNum * 192,1536]。dtype支持INT8和BFLOAT16,支持NZ格式。

    • deScale1(aclTensor*,计算输入):Device侧的aclTensor,参与wuq矩阵乘的系数,shape为[headNum*192,1536]。dtype支持INT64和FLOAT,支持ND格式。input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。

    • bias1(aclTensor*,计算输入):Device侧的aclTensor,参与wuq矩阵乘的系数,shape为[[headNum*192]]。dtype支持INT32,支持NZ格式。quantMode为1、3时不传入。

    • gamma2(aclTensor*,计算输入):Device侧的aclTensor,参与RmsNormAndreshapeAndCache计算的γ参数,shape为[512]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • cos(aclTensor*,计算输入):Device侧的aclTensor,表示用于计算旋转位置编码的正弦参数矩阵,shape为[tokenNum,64]。dtype支持INT8,支持NZ格式。

    • sin(aclTensor*,计算输入):Device侧的aclTensor,表示用于计算旋转位置编码的余弦参数矩阵,shape为[tokenNum,64]。dtype支持INT8,支持NZ格式。

    • wuk(aclTensor*,计算输入):Device侧的aclTensor,表示计算Key的上采样权重,shape为[headNum * 192, 1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND/NZ格式。ND格式时的shape为[headNum,128,512],NZ格式时的shape为[headNum,32,128,16]。

    • kvCache(aclTensor*,计算输入):Device侧的aclTensor,与输出的kvCacheOut为同一tensor,输入格式随cacheMode变化。

      • cacheMode为0:shape为[blockNum,blockSize,1,576],dtype与input保持一致,为ND。
      • cacheMode为1:shape为[blockNum,blockSize,1,512],tensor的shape为拆分情况,dtype与input保持一致,为ND。
      • cacheMode为2:shape为[blockNum,headNum*512/32,block_size,32],dtype为int8,为NZ。
      • cacheMode为3:shape为[blockNum,headNum*512/16,block_size,16],dtype与input保持一致,为NZ。
    • kvCacheRope(aclTensor*,计算输入)Device侧的aclTensor,可选参数,支出传入空指针。与输出的krCacheOut为同一tensor,输入格式随cacheMode变化。

      • cacheMode为0:不传入。
      • cacheMode为1:shape为[blockNum,blockSize,1,64],dtype与input保持一致,为ND。
      • cacheMode为2或3:shape为[blockNum, headNum*64 / 16 ,block_size, 16],dtype与input保持一致,为NZ。
    • slotmapping(aclTensor*,计算输入):Device侧的aclTensor,表示用于存储kv_cache和kr_cache的索引,shape为[tokenNum]。dtype支持INT32,支持ND格式。

    • ctkvScale(aclTensor*,计算输入):Device侧的aclTensor,输出量化处理中参与计算的系数,仅在cacheMode为2时传入,shape为[1]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • qNopeScale(aclTensor*,计算输入):Device侧的aclTensor,输出量化处理中参与计算的系数,仅在cacheMode为2时传入,shape为[headNum]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。

    • wdqDim(int64_t,计算输入):表示经过matmul后拆分的dim大小。预留参数,目前只支持1536。

    • qRopeDim(int64_t,计算输入):表示q传入RoPE的dim大小。预留参数,目前只支持64。

    • kRopeDim(int64_t,计算输入):表示k传入RoPE的dim大小。预留参数,目前只支持64。

    • epsilon(float,计算输入):表示加在分母上防止除0。

    • qRotaryCoeff(int64_t,计算输入):表示q旋转系数。预留参数,目前只支持2。

    • kRotaryCoeff(int64_t,计算输入):表示k旋转系数。预留参数,目前只支持2。

    • transposeWdq(bool,计算输入):表示wdq是否转置。预留参数,目前只支持true。

    • transposeWuq(bool,计算输入):表示wuq是否转置。预留参数,目前只支持true。

    • transposeWuk(bool,计算输入):表示wuk是否转置。预留参数,目前只支持true。

    • cacheMode(int64_t,计算输入):表示指定cache的类型,取值范围[0, 3]。

      • 0:kcache和q均经过拼接后输出。
      • 1:输出的kvCacheOut拆分为kvCacheOut和krCacheOut,qOut拆分为qOut和qRopeOut。
      • 2:krope和ctkv转为NZ格式输出,ctkv和qnope经过per_head静态对称量化为int8类型。
      • 3:krope和ctkv转为NZ格式输出。
    • quantMode(int64_t,计算输入):表示指定RmsNorm量化的类型,取值范围[0, 3]。

      • 0:per_tensor静态非对称量化,默认量化类型。
      • 1:per_token动态对称量化,未实现。
      • 2:per_token动态非对称量化,未实现。
      • 3:不量化,浮点输出,未实现。
    • doRmsNorm(bool,计算输入):表示是否对input输入进行RmsNormQuant操作,false表示不操作,true表示进行操作。预留参数,目前只支持true。

    • wdkvSplitCount(int64_t,计算输入):表示指定wdkv拆分的个数,支持[1-3],分别表示不拆分、拆分为2个、拆分为3个降维矩阵。预留参数,目前只支持1。

    • qOut(aclTensor*,计算输出):计算输出,表示Query的输出tensor,对应计算流图中右侧经过NOPE和矩阵乘后的输出,shape和dtype随cacheMode变化。

      • cacheMode为0:shape为[tokenNum, headNum, 576],dtype与input一致,为ND。
      • cacheMode为1或3:shape为[tokenNum, headNum, 512],dtype与input一致,为ND。
      • cacheMode为2:shape为[tokenNum, headNum, 512],dtype为INT8,为ND格式。
    • kvCacheOut(aclTensor*,计算输出):计算输出,表示Key经过ReshapeAndCache后的输出,shape和dtype随cacheMode变化。

      • cacheMode为0:shape为[blockNum, blockSize, 1, 576], dtype与input一致,为ND。
      • cacheMode为1:shape为[blockNum, blockSize, 1, 512], dtype与input一致,为ND。
      • cacheMode为2:shape为[blockNum, headNum*512/32, block_size, 32],dtype为INT8,为NZ。
      • cacheMode为3:shape为[blockNum, headNum*512/16, block_size, 16],dtype与input一致,为NZ。
    • qRopeOut(aclTensor*,计算输出):计算输出,表示Query经过旋转编程后的输出,shape和dtype随cacheMode变化。

      • cacheMode为0:不输出。
      • cacheMode为1或3:shape为[tokenNum, headNum, 64],dtype与input一致,为ND。
      • cacheMode为2:shape为[tokenNum, headNum, 64],dtype与input一致,为ND。
    • krCacheOut(aclTensor*,计算输出):表示Key经过RoPE和ReshapeAndCache后的输出,shape和随cacheMode变化,

      • cacheMode为0:不输出。
      • cacheMode为1:shape为[blockNum, blockSize, 1, 64],dtype与input一致,为ND。
      • cacheMode为2或3:shape为[blockNum, headNum*64 / 16 ,block_size, 16],dtype与input一致,为NZ。
    • workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。

    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。

  • 返回值:

    aclnnStatus:返回状态码,具体参见

    [object Object]

aclnnMlaPreprocess

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMlaPreprocessGetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的Stream。
  • 返回值:

    aclnnStatus:返回状态码,具体参见

约束说明

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
  • shape格式字段含义及约束
    • tokenNum:tokenNum 表示输入样本批量大小,取值范围:0~256
    • hiddenSize:hiddenSize 表示隐藏层的大小,取值固定为:2048-10240,为256的倍数
    • headNum:表示多头数,取值范围:16、32、64、128
    • blockNum:PagedAttention场景下的块数,取值范围:192
    • blockSize:PagedAttention场景下的块大小,取值范围:128
    • 当wdqkv和wuq的数据类型为BFLOAT16时,输入input也需要为BFLOAT16,且hiddenSize只支持6144

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]