aclnnMlaPreprocessV2
产品支持情况
功能说明
接口功能:推理场景,Multi-Head Latent Attention前处理的计算。主要计算过程如下:
- 首先对输入 RmsNormQuant后乘以进行下采样后分为通路1和通路2。
- 通路1做RmsNormQuant后乘以后再分为通路3和通路4。
- 通路3后乘以后输出。
- 通路4后经过旋转位置编码后输出。
- 通路2拆分为通路5和通路6。
- 通路5经过RmsNorm后传入Cache中得到。
- 通路6经过旋转位置编码后传入另一个Cache中得到。
计算流程图
计算公式:
RmsNormQuant公式
Query计算公式,包括W^{DQKV}矩阵乘、W^{UK}矩阵乘、RmsNormQuant和ROPE旋转位置编码处理
Key计算公式,包括RmsNorm和rope,将计算结果存入cache
函数原型
每个算子分为,必须先调用“aclnnMlaPreprocessV2GetWorkspaceSize”接口获取入参并根据流程计算所需workspace大小,再调用“aclnnMlaPreprocessV2”接口执行计算。
[object Object]
[object Object]
aclnnMlaPreprocessV2GetWorkspaceSize
aclnnMlaPreprocessV2
约束说明
- 确定性计算:
- aclnnMlaPreprocessV2默认确定性实现。
- shape格式字段含义及约束
- tokenNum:tokenNum 表示输入样本批量大小,取值范围:0~256
- hiddenSize:hiddenSize 表示隐藏层的大小,取值固定为:2048-10240,为256的倍数
- headNum:表示多头数,取值范围:16、32、64、128
- blockNum:PagedAttention场景下的块数,取值范围:192
- blockSize:PagedAttention场景下的块大小,取值范围:128
- 当wdqkv和wuq的数据类型为bfloat16时,输入input也需要为bfloat16,且hiddenSize只支持6144,cacheMode只支持0和1
调用示例
[object Object]