MLA
定义
atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRope, const aclTensor *ctKV, const aclTensor *kRope, const aclTensor *blockTables, const aclTensor *contextLens, const aclTensor *mask, const aclTensor *qSeqLen, const aclTensor *qkDescale, const aclTensor *pvDescale, int32_t headNum, float qkScale, int32_t kvHeadNum, int maskType, int calcType, uint8_t cacheMode, aclTensor *attenOut, aclTensor *ise, uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); atb::Status AtbMLA(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAGetWorkspaceSize成员
参数 |
标量/张量 |
维度 |
数据类型 |
格式 |
默认值 |
是否必选 |
描述 |
---|---|---|---|---|---|---|---|
qNope |
张量 |
[num_tokens, num_heads, 512] |
float16/bf16/int8 |
ND |
- |
是 |
无位置编码query。 |
qRope |
张量 |
[num_tokens, num_heads, 64] |
float16/bf16 |
ND |
- |
是 |
旋转位置编码query。 |
ctKV |
张量 |
float16/bf16/int8 |
ND/NZ |
- |
是 |
无位置编码ctkv。
|
|
kRope |
张量 |
float16/bf16 |
ND/NZ |
- |
是 |
旋转位置编码k。 cacheMode为2或3:NZ |
|
blockTables |
张量 |
[batch, max_num_blocks_per_query] |
int32 |
ND |
- |
是 |
每个query的kvcache的block映射表。 |
contextLens |
张量 |
[batch] |
int32 |
ND |
- |
是 |
每个query对应的上下文长度,kseqlen。host侧tensor。 |
mask |
张量 |
float16/bf16 |
ND |
- |
否 |
注意力掩码,maskType为0时传入nullptr。 |
|
qseqlen |
张量 |
[batch] |
int32 |
ND |
- |
否 |
calcType为1时传入,每个batch对应的seqLen。 |
qkDescale |
张量 |
[num_heads] |
float |
ND |
- |
否 |
cacheMode为2时传入。 |
pvDescale |
张量 |
[num_heads] |
float |
ND |
- |
否 |
cacheMode为2时传入。 |
headNum |
标量 |
- |
int32_t |
- |
0 |
是 |
query头数量。 |
qkScale |
标量 |
- |
float |
- |
1.0 |
是 |
Q*K^T后乘以的缩放系数。 |
kvHeadNum |
标量 |
- |
int32_t |
- |
0 |
是 |
kv头数量。 |
maskType |
标量 |
- |
int |
- |
0 |
否 |
mask类型。
|
calcType |
标量 |
- |
int |
- |
0 |
否 |
计算类型。
|
cacheMode |
标量 |
- |
int |
- |
0 |
是 |
|
attenOut |
张量 |
[num_tokens, num_heads, 512] |
float16/bf16 |
ND |
- |
是 |
attention输出。 |
lse |
张量 |
[num_tokens, num_heads, 1] |
float16/bf16 |
ND |
- |
否 |
LSE输出,calcType不为CALC_TYPE_RING时可以传入nullptr。 |
原始接口
请参见MLA输入输出。