RingMLAOperation
 支持通过C接口直调接入PyTorch,在整网中进行亲和算子替换。
定义
atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTensor *querySplit2,
                                       const aclTensor *keySplit1, const aclTensor *keySplit2, const aclTensor *value,
                                       const aclTensor *mask, const aclTensor *seqLen, const aclTensor *prevOut,
                                       const aclTensor *prevLse, int32_t headNum, int32_t kvHeadNum, float qkScale,
                                       int kernelType, int maskType, int inputLayout, int calcType, aclTensor *output,
                                       aclTensor *softmaxLse, uint64_t *workspaceSize, atb::Operation **op,
                                       atb::Context *context);
atb::Status AtbRingMLA(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbRingMLAGetWorkspaceSize成员
参数  | 
标量/张量  | 
维度  | 
数据类型  | 
格式  | 
默认值  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|---|---|
querySplit1  | 
张量  | 
[qNTokens, headNum, 128]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,无位置编码query矩阵。  | 
querySplit2  | 
张量  | 
[qNTokens, headNum, 64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,query旋转位置编码分量。  | 
keySplit1  | 
张量  | 
[kvNTokens, kvHeadNum, 128]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,无位置编码key矩阵。  | 
keySplit2  | 
张量  | 
[kvNTokens, kvHeadNum, 64]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,key旋转位置编码。  | 
value  | 
张量  | 
[kvNTokens, kvHeadNum, 128]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,value矩阵。  | 
mask  | 
张量  | 
[512, 512]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输入tensor,掩码。  | 
seqLen  | 
张量  | 
[batch]/[2, batch]  | 
int32/uint32  | 
ND  | 
-  | 
是  | 
输入tensor,序列长度。 
  | 
prevOut  | 
张量  | 
[qNTokens, headNum, headSizeV]  | 
float16/bf16  | 
ND  | 
-  | 
否  | 
输入tensor,前次输出。  | 
prevLse  | 
张量  | 
[headNum, qNTokens]  | 
float16  | 
ND  | 
-  | 
否  | 
输入tensor,前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。  | 
headNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
query头大小。  | 
kvHeadNum  | 
标量  | 
-  | 
int32_t  | 
-  | 
0  | 
是  | 
kv头数量, 该值需要用户根据使用的模型实际情况传入。 
  | 
qkScale  | 
标量  | 
-  | 
float  | 
-  | 
1.0  | 
是  | 
算子tor值, 在Q*K^T后乘。  | 
kernelType  | 
标量  | 
-  | 
int  | 
-  | 
1  | 
是  | 
内核精度类型。 仅支持1:KERNELTYPE_HIGH_PRECISION。  | 
maskType  | 
标量  | 
-  | 
int  | 
-  | 
1  | 
是  | 
mask类型。 
  | 
inputLayout  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
是  | 
数据排布格式。 仅支持0:TYPE_BSND  | 
calcType  | 
标量  | 
-  | 
int  | 
-  | 
0  | 
是  | 
计算类型。 
  | 
output  | 
张量  | 
[qNTokens, headNum, headSizeV]  | 
float16/bf16  | 
ND  | 
-  | 
是  | 
输出tensor。  | 
softmaxLse  | 
张量  | 
[headNum, qNTokens]  | 
float  | 
ND  | 
-  | 
是  | 
softmaxLse输出。  |