MLA Prefill
 支持通过C接口直调接入PyTorch,在整网中进行亲和算子替换。
定义
atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *qRope, const aclTensor *k,
                                          const aclTensor *kRope, const aclTensor *v, const aclTensor *qSeqLen, const aclTensor *kvSeqLen,
                                          const aclTensor *mask, int32_t headNum, float qkScale, int32_t kvHeadNum,
                                          int maskType, uint8_t cacheMode, aclTensor *attenOut,
                                          uint64_t *workspaceSize, atb::Operation **op, atb::Context *context);
atb::Status AtbMLAPreFill(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAPreFillGetWorkspaceSize成员
参数  | 
标量/张量  | 
维度  | 
数据类型  | 
格式  | 
默认值  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|---|---|
Q  | 
张量  | 
[nTokens, qhead * embeddimV] [nTokens, qhead, embeddimV]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
query。  | 
QRope  | 
张量  | 
[nTokens, qhead * 64] [nTokens, qhead, 64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
旋转位置编码query。  | 
K  | 
张量  | 
[batch, max_seq, kv_head * embeddimV] [batch*seq, kv_head, embeddimV]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
key。  | 
kRope  | 
张量  | 
[batch, max_seq, kv_head * 64] [batch*seq, kv_head, 64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
旋转位置编码key。  | 
V  | 
张量  | 
[batch, max_seq, kv_head * embeddimV] [batch*seq, kv_head, embeddimV]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
value。  | 
qSeqLen  | 
张量  | 
[batch]  | 
int32  | 
ND  | 
-  | 
是  | 
输入tensor。  | 
kvSeqLen  | 
张量  | 
[batch]  | 
int32  | 
ND  | 
-  | 
是  | 
输入tensor。  | 
Mask  | 
张量  | 
None [512,512]  | 
float16/bf16  | 
ND  | 
-  | 
否  | 
注意力掩码,UNDEFINED 或 MASK_TYPE_CAUSAL_MASK 时可以置为nullptr,MASK_TYPE_MASK_FREE时输入tensor。  | 
headNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
query头数量。  | 
qkScale  | 
标量  | 
-  | 
float  | 
-  | 
1.0  | 
是  | 
Q*K^T后乘以的缩放系数。  | 
kvHeadNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
kv头数量。  | 
maskType  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
否  | 
mask类型。 UNDEFINED:默认值,全0的mask; MASK_TYPE_SPEC:qseqlen > 1时的mask。 MASK_TYPE_MASK_FREE:maskfree功能。  | 
calcType  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
否  | 
计算类型。 
  | 
cacheMode  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
是  | 
  | 
attenOut  | 
张量  | 
[nTokens, qhead * embeddimV] [ntokens, kvhead, embedv]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
attention输出。  |