aclnnMlaPreprocess
产品支持情况
| 产品 | 是否支持 | 
|---|---|
| [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object] | √ | 
| [object Object]Atlas A2 训练系列产品/Atlas A2 推理系列产品[object Object] | √ | 
| [object Object]Atlas 200I/500 A2 推理产品[object Object] | × | 
| [object Object]Atlas 推理系列产品[object Object] | × | 
| [object Object]Atlas 训练系列产品[object Object] | × | 
功能说明
算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程如下:
- 首先对输入 RmsNormQuant后乘以进行下采样后分为通路1和通路2。
 - 通路1做RmsNormQuant后乘以后再分为通路3和通路4。
 - 通路3后乘以后输出。
 - 通路4后经过旋转位置编码后输出。
 - 通路2拆分为通路5和通路6。
 - 通路5经过RmsNorm后传入Cache中得到。
 - 通路6经过旋转位置编码后传入另一个Cache中得到。
 
计算公式:
RmsNormQuant公式
Query计算公式,包括W^{DQKV}矩阵乘、W^{UK}矩阵乘、RmsNormQuant和RoPE旋转位置编码处理
Key计算公式,包括RmsNorm和RoPE,将计算结果存入cache
函数原型
每个算子分为,必须先调用“aclnnMlaPreprocessGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPreprocess”接口执行计算。
aclnnStatus aclnnMlaPreprocessGetWorkspaceSize(const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0, const aclTensor *quantOffset0,const aclTensor *wdqkv, const aclTensor *deScale0, const aclTensor *bias0, const aclTensor *gamma1, const aclTensor *beta1, const aclTensor *quantScale1, const aclTensor *quantOffset1, const aclTensor *wuq, const aclTensor *deScale1, const aclTensor *bias1, const aclTensor *gamma2, const aclTensor *cos, const aclTensor *sin, const aclTensor *wuk, const aclTensor *kvCache, const aclTensor *kvCacheRope, const aclTensor *slotmapping, const aclTensor *ctkvScale, const aclTensor *qNopeScale, int64_t wdqDim, int64_t qRopeDim, int64_t kRopeDim, float epsilon, int64_t qRotaryCoeff, int64_t kRotaryCoeff, bool transposeWdq, bool transposeWuq, bool transposeWuk, int64_t cacheMode, int64_t quantMode, bool doRmsNorm, int64_t wdkvSplitCount, aclTensor *qOut, aclTensor *kvCacheOut, aclTensor *qRopeOut, aclTensor *krCacheOut, uint64_t *workspaceSize, aclOpExecutor **executor)aclnnStatus aclnnMlaPreprocess(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
aclnnMlaPreprocessGetWorkspaceSize
参数说明:
input(aclTensor*,计算输入):Device侧的aclTensor,用于计算Query和Key的x,shape为[tokenNum,hiddenSize],dtype支持FLOAT16和BFLOAT16,支持ND格式。
gamma0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm计算中的γ参数,shape为[hiddenSize],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。
beta0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm计算中的β参数,shape为[hiddenSize],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。
quantScale0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm公式中量化缩放的参数,shape为[1],dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。
quantOffset0(aclTensor*,计算输入):Device侧的aclTensor,首次RmsNorm公式中的量化偏移参数,shape为[1],dtype支持INT8,支持NZ格式。
wdqkv(aclTensor*,计算输入):Device侧的aclTensor,与输入首次做矩阵乘的降维矩阵,shape为[2112,hiddenSize],dtype支持INT8和BFLOAT16,支持ND格式。
deScale0(aclTensor*,计算输入):Device侧的aclTensor,输入首次做矩阵乘的降维矩阵中的系数,shape为[2112],dtype支持INT64和FLOAT,与input的dtype对应,input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。支持ND格式。
bias0(aclTensor*,计算输入):Device侧的aclTensor,输入首次做矩阵乘的降维矩阵中的系数,shape为[2112],dtype支持INT32,支持ND格式。支持传入空tensor,quantMode为1、3时不传入。
gamma1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm计算中的γ参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。
beta1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm计算中的β参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。
quantScale1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm公式中量化缩放的参数,shape为[1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND格式。仅在quantMode为0时传入。
quantOffset1(aclTensor*,计算输入):Device侧的aclTensor,第二次RmsNorm公式中的量化偏移参数,shape为[1]。dtype支持INT8,支持ND格式。仅在quantMode为0时传入。
wuq(aclTensor*,计算输入):Device侧的aclTensor,权重矩阵,shape为[headNum * 192,1536]。dtype支持INT8和BFLOAT16,支持NZ格式。
deScale1(aclTensor*,计算输入):Device侧的aclTensor,参与wuq矩阵乘的系数,shape为[headNum*192,1536]。dtype支持INT64和FLOAT,支持ND格式。input输入dtype为FLOAT16支持INT64,输入BFLOAT16时支持FLOAT。
bias1(aclTensor*,计算输入):Device侧的aclTensor,参与wuq矩阵乘的系数,shape为[[headNum*192]]。dtype支持INT32,支持NZ格式。quantMode为1、3时不传入。
gamma2(aclTensor*,计算输入):Device侧的aclTensor,参与RmsNormAndreshapeAndCache计算的γ参数,shape为[512]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。
cos(aclTensor*,计算输入):Device侧的aclTensor,表示用于计算旋转位置编码的正弦参数矩阵,shape为[tokenNum,64]。dtype支持INT8,支持NZ格式。
sin(aclTensor*,计算输入):Device侧的aclTensor,表示用于计算旋转位置编码的余弦参数矩阵,shape为[tokenNum,64]。dtype支持INT8,支持NZ格式。
wuk(aclTensor*,计算输入):Device侧的aclTensor,表示计算Key的上采样权重,shape为[headNum * 192, 1536]。dtype支持FLOAT16和BFLOAT16,与input保持一致,支持ND/NZ格式。ND格式时的shape为[headNum,128,512],NZ格式时的shape为[headNum,32,128,16]。
kvCache(aclTensor*,计算输入):Device侧的aclTensor,与输出的kvCacheOut为同一tensor,输入格式随cacheMode变化。
- cacheMode为0:shape为[blockNum,blockSize,1,576],dtype与input保持一致,为ND。
 - cacheMode为1:shape为[blockNum,blockSize,1,512],tensor的shape为拆分情况,dtype与input保持一致,为ND。
 - cacheMode为2:shape为[blockNum,headNum*512/32,block_size,32],dtype为int8,为NZ。
 - cacheMode为3:shape为[blockNum,headNum*512/16,block_size,16],dtype与input保持一致,为NZ。
 
kvCacheRope(aclTensor*,计算输入)Device侧的aclTensor,可选参数,支出传入空指针。与输出的krCacheOut为同一tensor,输入格式随cacheMode变化。
slotmapping(aclTensor*,计算输入):Device侧的aclTensor,表示用于存储kv_cache和kr_cache的索引,shape为[tokenNum]。dtype支持INT32,支持ND格式。
ctkvScale(aclTensor*,计算输入):Device侧的aclTensor,输出量化处理中参与计算的系数,仅在cacheMode为2时传入,shape为[1]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。
qNopeScale(aclTensor*,计算输入):Device侧的aclTensor,输出量化处理中参与计算的系数,仅在cacheMode为2时传入,shape为[headNum]。dtype支持BLOAT16和BFLOAT16,与input保持一致,支持ND格式。
wdqDim(int64_t,计算输入):表示经过matmul后拆分的dim大小。预留参数,目前只支持1536。
qRopeDim(int64_t,计算输入):表示q传入RoPE的dim大小。预留参数,目前只支持64。
kRopeDim(int64_t,计算输入):表示k传入RoPE的dim大小。预留参数,目前只支持64。
epsilon(float,计算输入):表示加在分母上防止除0。
qRotaryCoeff(int64_t,计算输入):表示q旋转系数。预留参数,目前只支持2。
kRotaryCoeff(int64_t,计算输入):表示k旋转系数。预留参数,目前只支持2。
transposeWdq(bool,计算输入):表示wdq是否转置。预留参数,目前只支持true。
transposeWuq(bool,计算输入):表示wuq是否转置。预留参数,目前只支持true。
transposeWuk(bool,计算输入):表示wuk是否转置。预留参数,目前只支持true。
cacheMode(int64_t,计算输入):表示指定cache的类型,取值范围[0, 3]。
- 0:kcache和q均经过拼接后输出。
 - 1:输出的kvCacheOut拆分为kvCacheOut和krCacheOut,qOut拆分为qOut和qRopeOut。
 - 2:krope和ctkv转为NZ格式输出,ctkv和qnope经过per_head静态对称量化为int8类型。
 - 3:krope和ctkv转为NZ格式输出。
 
quantMode(int64_t,计算输入):表示指定RmsNorm量化的类型,取值范围[0, 3]。
- 0:per_tensor静态非对称量化,默认量化类型。
 - 1:per_token动态对称量化,未实现。
 - 2:per_token动态非对称量化,未实现。
 - 3:不量化,浮点输出,未实现。
 
doRmsNorm(bool,计算输入):表示是否对input输入进行RmsNormQuant操作,false表示不操作,true表示进行操作。预留参数,目前只支持true。
wdkvSplitCount(int64_t,计算输入):表示指定wdkv拆分的个数,支持[1-3],分别表示不拆分、拆分为2个、拆分为3个降维矩阵。预留参数,目前只支持1。
qOut(aclTensor*,计算输出):计算输出,表示Query的输出tensor,对应计算流图中右侧经过NOPE和矩阵乘后的输出,shape和dtype随cacheMode变化。
kvCacheOut(aclTensor*,计算输出):计算输出,表示Key经过ReshapeAndCache后的输出,shape和dtype随cacheMode变化。
qRopeOut(aclTensor*,计算输出):计算输出,表示Query经过旋转编程后的输出,shape和dtype随cacheMode变化。
krCacheOut(aclTensor*,计算输出):表示Key经过RoPE和ReshapeAndCache后的输出,shape和随cacheMode变化,
workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。
executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
返回值:
[object Object]
aclnnMlaPreprocess
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
 - workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMlaPreprocessGetWorkspaceSize获取。
 - executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
 - stream(aclrtStream,入参):指定执行任务的Stream。
 
返回值:
约束说明
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
 - shape格式字段含义及约束
- tokenNum:tokenNum 表示输入样本批量大小,取值范围:0~256
 - hiddenSize:hiddenSize 表示隐藏层的大小,取值固定为:2048-10240,为256的倍数
 - headNum:表示多头数,取值范围:16、32、64、128
 - blockNum:PagedAttention场景下的块数,取值范围:192
 - blockSize:PagedAttention场景下的块大小,取值范围:128
 - 当wdqkv和wuq的数据类型为BFLOAT16时,输入input也需要为BFLOAT16,且hiddenSize只支持6144