昇腾社区首页
中文
注册

MlaPreprocessOperation

定义

atb::Status AtbMLAPreprocessGetWorkspaceSize(
            const aclTensor *input, const aclTensor *gamma0, const aclTensor *beta0, const aclTensor *quantScale0,
            const aclTensor *quantOffset0, const aclTensor *wdqkv, const aclTensor *deScale0, const aclTensor *bias0,
            const aclTensor *gamma1, const aclTensor *beta1, const aclTensor *quantScale1, const aclTensor *quantOffset1,
            const aclTensor *wuq, const aclTensor *deScale1, const aclTensor *bias1, const aclTensor *gamma2,
            const aclTensor *cos, const aclTensor *sin, const aclTensor *wuk, const aclTensor *kvCache,
            const aclTensor *kvCacheRope, const aclTensor *slotmapping, const aclTensor *ctkvScale, const aclTensor *qNopeScale,
            uint32_t wdqDim, uint32_t qRopeDim, uint32_t kRopeDim, float epsilon, uint32_t qRotaryCoeff, uint32_t kRotaryCoeff,
            bool transposeWdq, bool transposeWuq, bool transposeWuk, uint8_t cacheMode, uint16_t quantMode, aclTensor *qOut0,
            aclTensor *kvCacheOut0, aclTensor *qOut1, aclTensor *kvCacheOut1, uint64_t *workspaceSize, atb::Operation **op,
            atb::Context *context);
atb::Status AtbMLAPreprocess(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);

AtbMLAPreprocessGetWorkspaceSize成员

参数

标量/张量

维度

数据类型

格式

默认值

是否必选

描述

input

张量

[tokenNum, 7168]

float16/bf16

ND

-

输入tensor。

gamma0

张量

[7168]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

beta0

张量

[7168]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

quantScale0

张量

[1]

float16/bf16

ND

-

输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。

quantOffset0

张量

[1]

int8

ND

-

输入tensor,支持传入空tensor。仅在quantMode为0时传入。

wdqkv

张量

[1,224,2112,32]

int8

NZ

-

输入tensor。

deScale0

张量

[2112]

int64/float

ND

-

输入tensor,input为float16时为int64。input为bf16时为float。

bias0

张量

[2112]

int32

ND

-

输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。

gamma1

张量

[1536]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

beta1

张量

[1536]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

quantScale1

张量

[1]

float16/bf16

ND

-

输入tensor,支持传入空tensor。仅在quantMode为0时传入,数据类型与input一致。

quantOffset1

张量

[1]

int8

ND

-

输入tensor,支持传入空tensor。仅在quantMode为0时传入。

wuq

张量

[1,48,headNum*192,32]

int8

NZ

-

输入tensor。

deScale1

张量

[headNum*192]

int64/float

ND

-

输入tensor,input为float16时为int64,input为bf16时为float。

bias1

张量

[headNum*192]

int32

ND

-

输入tensor,支持传入空tensor。quantMode为1、3时传入nullptr。

gamma2

张量

[512]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

cos

张量

[tokenNum,64]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

sin

张量

[tokenNum,64]

float16/bf16

ND

-

输入tensor,数据类型与input一致。

wuk

张量

  • ND:[headNum,128,512]
  • NZ:[headNum,32,128,16]

float16/bf16

ND/NZ

-

输入tensor,数据类型与input一致。

kvCache

张量

  • cacheMode为0:

    [blockNum,blockSize,1,576]

  • cacheMode为1:

    [blockNum,blockSize,1,512]

  • cacheMode为2:

    [blockNum, headNum*512/32,block_size, 32]

  • cacheMode为3:

    [blockNum, headNum*512/16,block_size, 16]

float16/bf16/int8

ND/NZ

-

输入tensor。数据类型与input一致。

cacheMode为1时,tensor的shape为拆分情况。

cacheMode为2时,格式为NZ,数据类型为int8。

cacheMode为3时,格式为NZ。

kvCacheRope

张量

  • cacheMode为1:

    [blockNum,blockSize,1,64]

  • cacheMode为2或3:

    [blockNum, headNum*64 / 16 ,block_size, 16]

float16/bf16

ND/NZ

-

输入tensor,支持传入空tensor。cacheMode不为0时传入,数据类型与input一致。

cacheMode为2或3时,格式为NZ。

slotmapping

张量

[tokenNum]

int32

ND

-

输入tensor。

ctkvScale

张量

[1]

float16/bf16

ND

-

输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。

qNopeScale

张量

[headNum]

float16/bf16

ND

-

输入tensor,支持传入空tensor。cacheMode为2时传入,数据类型与input一致。

wdqDim

标量

-

uint32_t

-

0

经过matmul后拆分的dim大小。

qRopeDim

标量

-

uint32_t

-

0

q传入rope的dim大小。

kRopeDim

标量

-

uint32_t

-

0

k传入rope的dim大小。

epsilon

标量

-

float

-

1e-5

加在分母上防止除0。

qRotaryCoeff

标量

-

int32_t

-

2

q旋转系数。

kRotaryCoeff

标量

-

int32_t

-

2

k旋转系数。

transposeWdq

标量

-

bool

-

true

wdq是否转置。

transposeWuq

标量

-

bool

-

true

wuq是否转置。

transposeWuk

标量

-

bool

-

true

wuk是否转置。

cacheMode

标量

-

uint8_t

-

0

0:KVCACHE

1:KROPE_CTKV

2:INT8_NZCACHE

3:NZCACHE

quantMode

标量

-

uint16_t

-

0

0: PER_TENSOR_QUANT_ASYMM

1: PER_TOKEN_QUANT_SYMM

2: PER_TOKEN_QUANT_ASYMM

3: UNQUANT

qOut0

张量

  • cacheMode为0:

    [tokenNum,headNum,576]

  • cacheMode为1或2:

    [tokenNum,headNum,512]

float16/bf16/int8

ND

-

输出tensor,数据类型与input一致。cacheMode为2时数据类型为int8。

kvCacheOut0

张量

  • cacheMode为0:

    [blockNum,blockSize,1,576]

  • cacheMode为1:

    [blockNum,blockSize,1,512]

  • cacheMode为2:

    [blockNum, headNum*512/32,block_size, 32]

  • cacheMode为3:

    [blockNum, headNum*512/16,block_size, 16]

float16/bf16/int8

ND/NZ

-

输出tensor,数据类型与input一致。

cacheMode为2时,数据类型为int8,格式为NZ。

cacheMode为3时,格式为NZ。

qOut1

张量

[tokenNum,headNum,64]

float16/bf16

ND

-

cacheMode不为0时输出此tensor。数据类型与input一致。

kvCacheOut1

张量

  • cacheMode为1:

    [blockNum,blockSize,1,64]

  • cacheMode为2:

    [blockNum, headNum*64 / 16 ,block_size, 16]

float16/bf16

ND/NZ

-

cacheMode不为0时输出此tensor。数据类型与input一致。

cacheMode为2时,数据格式为NZ。

cacheMode为3时,数据格式为NZ。