MLA
 支持通过C接口直调接入PyTorch,在整网中进行亲和算子替换。
定义
atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRope, const aclTensor *ctKV,
                                   const aclTensor *kRope, const aclTensor *blockTables, const aclTensor *contextLens,
                                   const aclTensor *mask, const aclTensor *qSeqLen, const aclTensor *qkDescale,
                                   const aclTensor *pvDescale, int32_t headNum, float qkScale, int32_t kvHeadNum,
                                   int maskType, int calcType, uint8_t cacheMode, aclTensor *attenOut, aclTensor *ise,
                                   uint64_t *workspaceSize, atb::Operation **op, atb::Context *context);
atb::Status AtbMLA(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbMLAGetWorkspaceSize成员
参数  | 
标量/张量  | 
维度  | 
数据类型  | 
格式  | 
默认值  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|---|---|
qNope  | 
张量  | 
[num_tokens, num_heads, 512]  | 
float16/bf16/int8  | 
ND  | 
-  | 
是  | 
无位置编码query。  | 
qRope  | 
张量  | 
[num_tokens, num_heads, 64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
旋转位置编码query。  | 
ctKV  | 
张量  | 
float16/bf16/int8  | 
ND/NZ  | 
-  | 
是  | 
无位置编码ctkv。 
  | 
|
kRope  | 
张量  | 
float16/bf16  | 
ND/NZ  | 
-  | 
是  | 
旋转位置编码k。 cacheMode为2或3:NZ  | 
|
blockTables  | 
张量  | 
[batch, max_num_blocks_per_query]  | 
int32  | 
ND  | 
-  | 
是  | 
每个query的kvcache的block映射表。  | 
contextLens  | 
张量  | 
[batch]  | 
int32  | 
ND  | 
-  | 
是  | 
每个query对应的上下文长度,kseqlen。host侧tensor。  | 
mask  | 
张量  | 
float16/bf16  | 
ND  | 
-  | 
否  | 
注意力掩码,maskType为0时传入nullptr。  | 
|
qseqlen  | 
张量  | 
[batch]  | 
int32  | 
ND  | 
-  | 
否  | 
calcType为1时传入,每个batch对应的seqLen。  | 
qkDescale  | 
张量  | 
[num_heads]  | 
float  | 
ND  | 
-  | 
否  | 
cacheMode为2时传入。  | 
pvDescale  | 
张量  | 
[num_heads]  | 
float  | 
ND  | 
-  | 
否  | 
cacheMode为2时传入。  | 
headNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
query头数量。  | 
qkScale  | 
标量  | 
-  | 
float  | 
-  | 
1.0  | 
是  | 
Q*K^T后乘以的缩放系数。  | 
kvHeadNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
kv头数量。  | 
maskType  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
否  | 
mask类型。 
  | 
calcType  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
否  | 
计算类型。 
  | 
cacheMode  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
是  | 
  | 
attenOut  | 
张量  | 
[num_tokens, num_heads, 512]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
attention输出。  | 
lse  | 
张量  | 
[num_tokens, num_heads, 1]  | 
float16/bf16  | 
ND  | 
-  | 
否  | 
LSE输出,calcType不为CALC_TYPE_RING时可以传入nullptr。  | 
原始接口
请参见MLA输入输出。