MLA Prefill
定义
atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *qRope, const aclTensor *k, const aclTensor *kRope, const aclTensor *v, const aclTensor *qSeqLen, const aclTensor *kvSeqLen, const aclTensor *mask, int32_t headNum, float qkScale, int32_t kvHeadNum, int maskType, uint8_t cacheMode, aclTensor *attenOut, uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); atb::Status AtbMLAPreFill(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAPreFillGetWorkspaceSize成员
参数 |
标量/张量 |
维度 |
数据类型 |
格式 |
默认值 |
是否必选 |
描述 |
---|---|---|---|---|---|---|---|
Q |
张量 |
[nTokens, qhead * embeddimV] [nTokens, qhead, embeddimV] |
float16/bf16 |
ND |
- |
是 |
query。 |
QRope |
张量 |
[nTokens, qhead * 64] [nTokens, qhead, 64] |
float16/bf16 |
ND |
- |
是 |
旋转位置编码query。 |
K |
张量 |
[batch, max_seq, kv_head * embeddimV] [batch*seq, kv_head, embeddimV] |
float16/bf16 |
ND |
- |
是 |
key。 |
kRope |
张量 |
[batch, max_seq, kv_head * 64] [batch*seq, kv_head, 64] |
float16/bf16 |
ND |
- |
是 |
旋转位置编码key。 |
V |
张量 |
[batch, max_seq, kv_head * embeddimV] [batch*seq, kv_head, embeddimV] |
float16/bf16 |
ND |
- |
是 |
value。 |
qSeqLen |
张量 |
[batch] |
int32 |
ND |
- |
是 |
输入tensor。 |
kvSeqLen |
张量 |
[batch] |
int32 |
ND |
- |
是 |
输入tensor。 |
Mask |
张量 |
None [512,512] |
float16/bf16 |
ND |
- |
否 |
注意力掩码,UNDEFINED 或 MASK_TYPE_CAUSAL_MASK 时可以置为nullptr,MASK_TYPE_MASK_FREE时输入tensor。 |
headNum |
标量 |
- |
int32_t |
- |
0 |
是 |
query头数量。 |
qkScale |
标量 |
- |
float |
- |
1.0 |
是 |
Q*K^T后乘以的缩放系数。 |
kvHeadNum |
标量 |
- |
int32_t |
- |
0 |
是 |
kv头数量。 |
maskType |
标量 |
- |
int |
- |
0 |
否 |
mask类型。 UNDEFINED:默认值,全0的mask; MASK_TYPE_SPEC:qseqlen > 1时的mask。 MASK_TYPE_MASK_FREE:maskfree功能。 |
calcType |
标量 |
- |
int |
- |
0 |
否 |
计算类型。
|
cacheMode |
标量 |
- |
int |
- |
0 |
是 |
|
attenOut |
张量 |
[nTokens, qhead * embeddimV] [ntokens, kvhead, embedv] |
float16/bf16 |
ND |
- |
是 |
attention输出。 |