MlaPreprocessOperation
定义
atb::Status AtbMLAPreprocessGetWorkspaceSize( const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0, const aclTensor *quantOffset0, const aclTensor *wdqkv, const aclTensor *deScale0, const aclTensor *bias0, const aclTensor *gamma1, const aclTensor *beta1, const aclTensor *quantScale1, const aclTensor *quantOffset1, const aclTensor *wuq, const aclTensor *deScale1, const aclTensor *bias1, const aclTensor *gamma2, const aclTensor *cos, const aclTensor *sin, const aclTensor *wuk, const aclTensor *kvCache, const aclTensor *kvCacheRope, const aclTensor *slotmapping, const aclTensor *ctkvScale, const aclTensor *qNopeScale, uint32_t wdqDim, uint32_t qRopeDim, uint32_t kRopeDim, float epsilon, uint32_t qRotaryCoeff, uint32_t kRotaryCoeff, bool transposeWdq, bool transposeWuq, bool transposeWuk, uint8_t cacheMode, uint16_t quantMode, aclTensor *qOut0, aclTensor *kvCacheOut0, aclTensor *qOut1, aclTensor *kvCacheOut1, uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); atb::Status AtbMLAPreprocess(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAPreprocessGetWorkspaceSize成员
参数 |
标量/张量 |
维度 |
数据类型 |
格式 |
默认值 |
是否必选 |
描述 |
---|---|---|---|---|---|---|---|
input |
张量 |
[tokenNum, 7168] |
float16/bf16 |
ND |
- |
是 |
输入tensor。 |
gamma0 |
张量 |
[7168] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
beta0 |
张量 |
[7168] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
quantScale0 |
张量 |
[1] |
float16/bf16 |
ND |
- |
是 |
输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。 |
quantOffset0 |
张量 |
[1] |
int8 |
ND |
- |
是 |
输入tensor,支持传入空tensor。仅在quantMode为0时传入。 |
wdqkv |
张量 |
[1,224,2112,32] |
int8 |
NZ |
- |
是 |
输入tensor。 |
deScale0 |
张量 |
[2112] |
int64/float |
ND |
- |
是 |
输入tensor,input为float16时为int64。input为bf16时为float。 |
bias0 |
张量 |
[2112] |
int32 |
ND |
- |
是 |
输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。 |
gamma1 |
张量 |
[1536] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
beta1 |
张量 |
[1536] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
quantScale1 |
张量 |
[1] |
float16/bf16 |
ND |
- |
是 |
输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。 |
quantOffset1 |
张量 |
[1] |
int8 |
ND |
- |
是 |
输入tensor,支持传入空tensor。仅在quantMode为0时传入。 |
wuq |
张量 |
[1,48,headNum*192,32] |
int8 |
NZ |
- |
是 |
输入tensor。 |
deScale1 |
张量 |
[headNum*192] |
int64/float |
ND |
- |
是 |
输入tensor,input为float16时为int64,input为bf16时为float。 |
bias1 |
张量 |
[headNum*192] |
int32 |
ND |
- |
是 |
输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。 |
gamma2 |
张量 |
[512] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
cos |
张量 |
[tokenNum,64] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
sin |
张量 |
[tokenNum,64] |
float16/bf16 |
ND |
- |
是 |
输入tensor,数据类型与input一致。 |
wuk |
张量 |
|
float16/bf16 |
ND/NZ |
- |
是 |
输入tensor,数据类型与input一致。 |
kvCache |
张量 |
float16/bf16/int8 |
ND/NZ |
- |
是 |
输入tensor。数据类型与input一致。 cacheMode为1时,tensor的shape为拆分情况。 cacheMode为2时,格式为NZ,数据类型为int8。 cacheMode为3时,格式为NZ。 |
|
kvCacheRope |
张量 |
float16/bf16 |
ND/NZ |
- |
是 |
输入tensor,支持传入空tensor。cacheMode不为0时传入,数据类型与input一致。 cacheMode为2或3时,格式为NZ。 |
|
slotmapping |
张量 |
[tokenNum] |
int32 |
ND |
- |
是 |
输入tensor。 |
ctkvScale |
张量 |
[1] |
float16/bf16 |
ND |
- |
是 |
输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。 |
qNopeScale |
张量 |
[headNum] |
float16/bf16 |
ND |
- |
是 |
输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。 |
wdqDim |
标量 |
- |
uint32_t |
- |
0 |
是 |
经过matmul后拆分的dim大小。 |
qRopeDim |
标量 |
- |
uint32_t |
- |
0 |
是 |
q传入rope的dim大小。 |
kRopeDim |
标量 |
- |
uint32_t |
- |
0 |
是 |
k传入rope的dim大小。 |
epsilon |
标量 |
- |
float |
- |
1e-5 |
是 |
加在分母上防止除0。 |
qRotaryCoeff |
标量 |
- |
int32_t |
- |
2 |
是 |
q旋转系数。 |
kRotaryCoeff |
标量 |
- |
int32_t |
- |
2 |
是 |
k旋转系数。 |
transposeWdq |
标量 |
- |
bool |
- |
true |
是 |
wdq是否转置。 |
transposeWuq |
标量 |
- |
bool |
- |
true |
是 |
wuq是否转置。 |
transposeWuk |
标量 |
- |
bool |
- |
true |
是 |
wuk是否转置。 |
cacheMode |
标量 |
- |
uint8_t |
- |
0 |
是 |
0:KVCACHE 1:KROPE_CTKV 2:INT8_NZCACHE 3:NZCACHE |
quantMode |
标量 |
- |
uint16_t |
- |
0 |
是 |
0: PER_TENSOR_QUANT_ASYMM 1: PER_TOKEN_QUANT_SYMM 2: PER_TOKEN_QUANT_ASYMM 3: UNQUANT |
qOut0 |
张量 |
float16/bf16/int8 |
ND |
- |
是 |
输出tensor,数据类型与input一致。cacheMode为2时数据类型为int8。 |
|
kvCacheOut0 |
张量 |
float16/bf16/int8 |
ND/NZ |
- |
是 |
输出tensor,数据类型与input一致。 cacheMode为2时,数据类型为int8,格式为NZ。 cacheMode为3时,格式为NZ。 |
|
qOut1 |
张量 |
[tokenNum,headNum,64] |
float16/bf16 |
ND |
- |
否 |
cacheMode不为0时输出此tensor。数据类型与input一致。 |
kvCacheOut1 |
张量 |
float16/bf16 |
ND/NZ |
- |
否 |
cacheMode不为0时输出此tensor。数据类型与input一致。 cacheMode为2时,数据格式为NZ。 cacheMode为3时,数据格式为NZ。 |