aclnnMlaProlog
产品支持情况
产品 | 是否支持 |
---|---|
[object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object] | √ |
[object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object] | √ |
[object Object]Atlas 200I/500 A2 推理产品[object Object] | × |
[object Object]Atlas 推理系列产品[object Object] | × |
[object Object]Atlas 训练系列产品[object Object] | × |
功能说明
算子功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程分为四路,首先对输入乘以进行下采样和RmsNorm后分为两路,第一路乘以和经过两次上采样后得到;第二路乘以后经过旋转位置编码(ROPE)得到;第三路是输入乘以进行下采样和RmsNorm后传入Cache中得到;第四路是输入乘以后经过旋转位置编码后传入另一个Cache中得到。
计算公式:
RmsNorm公式
Query计算公式,包括下采样,RmsNorm和两次上采样
对Query的进行ROPE旋转位置编码
Key计算公式,包括下采样和RmsNorm,将计算结果存入cache
对Key进行ROPE旋转位置编码,并将结果存入cache
函数原型
每个算子分为undefined,必须先调用“aclnnMlaPrologGetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaProlog”接口执行计算。
aclnnStatus aclnnMlaPrologGetWorkspaceSize(const aclTensor *tokenX, const aclTensor *weightDq, const aclTensor *weightUqQr, const aclTensor *weightUk, const aclTensor *weightDkvKr, const aclTensor *rmsnormGammaCq, const aclTensor *rmsnormGammaCkv, const aclTensor *ropeSin, const aclTensor *ropeCos, const aclTensor *cacheIndex, aclTensor *kvCacheRef, aclTensor *krCacheRef, const aclTensor *dequantScaleXOptional, const aclTensor *dequantScaleWDqOptional, const aclTensor *dequantScaleWUqQrOptional, const aclTensor *dequantScaleWDkvKrOptional, const aclTensor *quantScaleCkvOptional, const aclTensor *quantScaleCkrOptional, const aclTensor *smoothScalesCqOptional, double rmsnormEpsilonCq, double rmsnormEpsilonCkv, char *cacheModeOptional, const aclTensor *queryOut, const aclTensor *queryRopeOut, uint64_t *workspaceSize, aclOpExecutor **executor)
aclnnStatus aclnnMlaProlog(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
aclnnMlaPrologGetWorkspaceSize
[object Object]
参数说明:
- tokenX(aclTensor*,计算输入):表示输入的tensor,用于计算Query和Key的x,Device侧的aclTensor。shape支持2维和3维,格式为(T,He)和(B,S,He),dtype支持BFLOAT16,undefined支持ND格式。
- weightDq(aclTensor*,计算输入):表示用于计算Query的下采样权重矩阵,对应公式中的,Device侧的aclTensor。其shape支持2维,格式为(He,Hcq),dtype支持BFLOAT16,undefined支持FRACTAL_NZ格式。
- weightUqQr(aclTensor*,计算输入):表示用于计算Query的上采样权重矩阵和Query的位置编码权重矩阵,对应公式中的和,Device侧的aclTensor。其shape支持2维,格式为(Hcq,N*(D+Dr)),dtype支持BFLOAT16和INT8,undefined支持FRACTAL_NZ格式。
- 当weightUqQr为INT8类型时,weightUqQr是一个per-tensor的量化后的输入,表示当前为部分量化场景。
- 此时若kvCacheRef、krCacheRef为BFLOAT16类型,对应输出为非量化输出,此时dequantScaleWUqQrOptional字段必须传入,smoothScalesCqOptional字段可选传入。
- 此时若kvCacheRef、krCacheRef为INT8类型,对应输出为量化输出,此时dequantScaleWUqQrOptional、quantScaleCkvOptional、quantScaleCkrOptional字段必须传入,smoothScalesCqOptional字段可选传入。
- 当weightUqQr为BFLOAT16类型时,表示当前为非量化场景。
- 此时dequantScaleWUqQrOptional、quantScaleCkvOptional、quantScaleCkrOptional、smoothScalesCqOptional字段不能传入(即为nullptr)。
- 当weightUqQr为INT8类型时,weightUqQr是一个per-tensor的量化后的输入,表示当前为部分量化场景。
- weightUk(aclTensor*,计算输入):表示用于计算Key的上采样权重,对应公式中的,Device侧的aclTensor。其shape支持3维,格式为(N,D,Hckv),dtype支持BFLOAT16,undefined支持ND格式。
- weightDkvKr(aclTensor*,计算输入):表示用于计算Key的下采样权重矩阵和Key的位置编码权重矩阵,对应公式中的和,Device侧的aclTensor。其shape支持2维,格式为(He,Hckv+Dr),dtype支持BFLOAT16,undefined支持FRACTAL_NZ格式。
- rmsnormGammaCq(aclTensor*,计算输入):表示计算的RmsNorm公式中的参数,Device侧的aclTensor。其shape支持1维,格式为(Hcq),dtype支持BFLOAT16,undefined支持ND格式。
- rmsnormGammaCkv(aclTensor*,计算输入):表示计算的RmsNorm公式中的参数,Device侧的aclTensor。其shape支持1维,格式为(Hckv),dtype支持BFLOAT16,undefined支持ND格式。
- ropeSin(aclTensor*,计算输入):表示用于计算旋转位置编码的正弦参数矩阵,Device侧的aclTensor。其shape支持2维和3维,格式为(T,Dr)和(B,S,Dr),dtype支持BFLOAT16,undefined支持ND格式。
- ropeCos(aclTensor*,计算输入):表示用于计算旋转位置编码的余弦参数矩阵,Device侧的aclTensor。其shape支持2维和3维,格式为(T,Dr)和(B,S,Dr),dtype支持BFLOAT16,undefined支持ND格式。
- cacheIndex(aclTensor*,计算输入):表示用于存储kvCache和krCache的索引,Device侧的aclTensor。其shape支持1维和2维,格式为(T)和(B,S),dtype支持INT64,undefined支持ND格式。
- cacheIndex的取值范围为[0,BlockNum*BlockSize),当前不会对cacheIndex传入值的合法性进行校验,需用户自行保证。
- kvCacheRef(aclTensor*,计算输入):表示用于cache索引的aclTensor。其shape支持4维,格式为(BlockNum,BlockSize,Nkv,Hckv),dtype支持BFLOAT16和INT8,undefined支持ND格式。计算结果原地更新,更新后结果对应公式中的。
- krCacheRef(aclTensor*,计算输入):表示用于key位置编码的cache,Device侧的aclTensor。其shape支持4维,格式为(BlockNum,BlockSize,Nkv,Dr),dtype支持BFLOAT16和INT8,undefined支持ND格式。计算结果原地更新,更新后结果对应公式中的。
- dequantScaleXOptional(aclTensor*,计算输入):预留参数,当前版本暂未使用,传入空指针。
- dequantScaleWDqOptional(aclTensor*,计算输入):预留参数,当前版本暂未使用,传入空指针。
- dequantScaleWUqQrOptional(aclTensor*,计算输入):用于对MatmulQcQr矩阵乘后进行反量化操作时的参数,量化参数为per-channel,Device侧的aclTensor。其shape支持2维,格式为(1,N*(D+Dr)),dtype支持FLOAT,undefined支持ND格式。
- dequantScaleWDkvKrOptional(aclTensor*,计算输入):预留参数,当前版本暂未使用,传入空指针。
- quantScaleCkvOptional(aclTensor*,计算输入):用于对输出到KVCache中的数据做量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Hckv),dtype支持FLOAT,undefined支持ND格式。
- quantScaleCkrOptional(aclTensor*,计算输入):用于对输出到KRCache中的数据做量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Dr),dtype支持FLOAT,undefined支持ND格式。
- smoothScalesCqOptional(aclTensor*,计算输入):用于对RmsNormCq输出做动态量化操作时的参数,Device侧的aclTensor。其shape支持2维,格式为(1,Hcq),dtype支持FLOAT,undefined支持ND格式。
- rmsnormEpsilonCq(double,计算输入):表示计算的RmsNorm公式中的参数,用户不特意指定时建议传入1e-05。
- rmsnormEpsilonCkv(double,计算输入):表示计算的RmsNorm公式中的参数,用户不特意指定时建议传入1e-05。
- cacheModeOptional(char*,计算输入):表示kvCache的模式,支持"PA_BSND","PA_NZ",用户不特意指定时建议传入"PA_BSND"。
- queryOut(aclTensor*,计算输出):表示Query的输出tensor,对应公式中的,Device侧的aclTensor。shape支持3维和4维,格式为(T,N,Hckv)和(B,S,N,Hckv),dtype支持BFLOAT16,undefined支持ND格式。
- queryRopeOut(aclTensor*,计算输出):表示Query位置编码的输出tensor,对应公式中的,Device侧的aclTensor。shape支持3维和4维,格式为(T,N,Dr)和(B,S,N,Dr),dtype支持BFLOAT16,undefined支持ND格式。
- workspaceSize(uint64_t*,计算输出):返回需要在Device侧申请的workspace大小。
- executor(aclOpExecutor**,计算输出):返回op执行器,包含了算子计算流程。
返回值:
aclnnStatus:返回状态码,具体参见undefined。
[object Object]
aclnnMlaProlog
参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存地址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnMlaPrologGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的Stream。
返回值:
aclnnStatus:返回状态码,具体参见undefined。
约束说明
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- B、S、T、Skv值允许一个或多个取0,即Shape与B、S、T、Skv值相关的入参允许传入空Tensor,其余入参不支持传入空Tensor。
- 如果B、S、T取值为0,则queryOut、queryRopeOut输出空Tensor,kvCacheRef、krCacheRef不做更新。
- 如果Skv取值为0,则queryOut、queryRopeOut正常计算,kvCacheRef、krCacheRef不做更新,即输出空Tensor。
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考undefined。
[object Object]