MlaPreprocessOperation
 支持通过C接口直调接入PyTorch,在整网中进行亲和算子替换。
定义
atb::Status AtbMLAPreprocessGetWorkspaceSize(
            const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0,
            const aclTensor *quantOffset0, const aclTensor *wdqkv, const aclTensor *deScale0, const aclTensor *bias0,
            const aclTensor *gamma1, const aclTensor *beta1, const aclTensor *quantScale1, const aclTensor *quantOffset1,
            const aclTensor *wuq, const aclTensor *deScale1, const aclTensor *bias1, const aclTensor *gamma2,
            const aclTensor *cos, const aclTensor *sin, const aclTensor *wuk, const aclTensor *kvCache,
            const aclTensor *kvCacheRope, const aclTensor *slotmapping, const aclTensor *ctkvScale, const aclTensor *qNopeScale,
            uint32_t wdqDim, uint32_t qRopeDim, uint32_t kRopeDim, float epsilon, uint32_t qRotaryCoeff, uint32_t kRotaryCoeff,
            bool transposeWdq, bool transposeWuq, bool transposeWuk, uint8_t cacheMode, uint16_t quantMode, aclTensor *qOut0,
            aclTensor *kvCacheOut0, aclTensor *qOut1, aclTensor *kvCacheOut1, uint64_t *workspaceSize, atb::Operation **op,
            atb::Context *context);
atb::Status AtbMLAPreprocess(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAPreprocessGetWorkspaceSize成员
参数  | 
标量/张量  | 
维度  | 
数据类型  | 
格式  | 
默认值  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|---|---|
input  | 
张量  | 
[tokenNum, 7168]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor。  | 
gamma0  | 
张量  | 
[7168]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
beta0  | 
张量  | 
[7168]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
quantScale0  | 
张量  | 
[1]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。  | 
quantOffset0  | 
张量  | 
[1]  | 
int8  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。仅在quantMode为0时传入。  | 
wdqkv  | 
张量  | 
[1,224,2112,32]  | 
int8  | 
NZ  | 
-  | 
是  | 
输入tensor。  | 
deScale0  | 
张量  | 
[2112]  | 
int64/float  | 
ND  | 
-  | 
是  | 
输入tensor,input为float16时为int64。input为bf16时为float。  | 
bias0  | 
张量  | 
[2112]  | 
int32  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。  | 
gamma1  | 
张量  | 
[1536]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
beta1  | 
张量  | 
[1536]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
quantScale1  | 
张量  | 
[1]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。  | 
quantOffset1  | 
张量  | 
[1]  | 
int8  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。仅在quantMode为0时传入。  | 
wuq  | 
张量  | 
[1,48,headNum*192,32]  | 
int8  | 
NZ  | 
-  | 
是  | 
输入tensor。  | 
deScale1  | 
张量  | 
[headNum*192]  | 
int64/float  | 
ND  | 
-  | 
是  | 
输入tensor,input为float16时为int64,input为bf16时为float。  | 
bias1  | 
张量  | 
[headNum*192]  | 
int32  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。  | 
gamma2  | 
张量  | 
[512]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
cos  | 
张量  | 
[tokenNum,64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
sin  | 
张量  | 
[tokenNum,64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
wuk  | 
张量  | 
  | 
float16/bf16  | 
ND/NZ  | 
-  | 
是  | 
输入tensor,数据类型与input一致。  | 
kvCache  | 
张量  | 
float16/bf16/int8  | 
ND/NZ  | 
-  | 
是  | 
输入tensor。数据类型与input一致。 cacheMode为1时,tensor的shape为拆分情况。 cacheMode为2时,格式为NZ,数据类型为int8。 cacheMode为3时,格式为NZ。  | 
|
kvCacheRope  | 
张量  | 
float16/bf16  | 
ND/NZ  | 
-  | 
是  | 
输入tensor,支持传入空tensor。cacheMode不为0时传入,数据类型与input一致。 cacheMode为2或3时,格式为NZ。  | 
|
slotmapping  | 
张量  | 
[tokenNum]  | 
int32  | 
ND  | 
-  | 
是  | 
输入tensor。  | 
ctkvScale  | 
张量  | 
[1]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。  | 
qNopeScale  | 
张量  | 
[headNum]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。  | 
wdqDim  | 
标量  | 
-  | 
uint32_t  | 
-  | 
0  | 
是  | 
经过matmul后拆分的dim大小。  | 
qRopeDim  | 
标量  | 
-  | 
uint32_t  | 
-  | 
0  | 
是  | 
q传入rope的dim大小。  | 
kRopeDim  | 
标量  | 
-  | 
uint32_t  | 
-  | 
0  | 
是  | 
k传入rope的dim大小。  | 
epsilon  | 
标量  | 
-  | 
float  | 
-  | 
1e-5  | 
是  | 
加在分母上防止除0。  | 
qRotaryCoeff  | 
标量  | 
-  | 
int32_t  | 
-  | 
2  | 
是  | 
q旋转系数。  | 
kRotaryCoeff  | 
标量  | 
-  | 
int32_t  | 
-  | 
2  | 
是  | 
k旋转系数。  | 
transposeWdq  | 
标量  | 
-  | 
bool  | 
-  | 
true  | 
是  | 
wdq是否转置。  | 
transposeWuq  | 
标量  | 
-  | 
bool  | 
-  | 
true  | 
是  | 
wuq是否转置。  | 
transposeWuk  | 
标量  | 
-  | 
bool  | 
-  | 
true  | 
是  | 
wuk是否转置。  | 
cacheMode  | 
标量  | 
-  | 
uint8_t  | 
-  | 
0  | 
是  | 
0:KVCACHE 1:KROPE_CTKV 2:INT8_NZCACHE 3:NZCACHE  | 
quantMode  | 
标量  | 
-  | 
uint16_t  | 
-  | 
0  | 
是  | 
0: PER_TENSOR_QUANT_ASYMM 1: PER_TOKEN_QUANT_SYMM 2: PER_TOKEN_QUANT_ASYMM 3: UNQUANT  | 
qOut0  | 
张量  | 
float16/bf16/int8  | 
ND  | 
-  | 
是  | 
输出tensor,数据类型与input一致。cacheMode为2时数据类型为int8。  | 
|
kvCacheOut0  | 
张量  | 
float16/bf16/int8  | 
ND/NZ  | 
-  | 
是  | 
输出tensor,数据类型与input一致。 cacheMode为2时,数据类型为int8,格式为NZ。 cacheMode为3时,格式为NZ。  | 
|
qOut1  | 
张量  | 
[tokenNum,headNum,64]  | 
float16/bf16  | 
ND  | 
-  | 
否  | 
cacheMode不为0时输出此tensor。数据类型与input一致。  | 
kvCacheOut1  | 
张量  | 
float16/bf16  | 
ND/NZ  | 
-  | 
否  | 
cacheMode不为0时输出此tensor。数据类型与input一致。 cacheMode为2时,数据格式为NZ。 cacheMode为3时,数据格式为NZ。  |