RingMLAOperation
定义
atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTensor *querySplit2, const aclTensor *keySplit1, const aclTensor *keySplit2, const aclTensor *value, const aclTensor *mask, const aclTensor *seqLen, const aclTensor *prevOut, const aclTensor *prevLse, int32_t headNum, int32_t kvHeadNum, float qkScale, int kernelType, int maskType, int inputLayout, int calcType, aclTensor *output, aclTensor *softmaxLse, uint64_t *workspaceSize, atb::Operation **op, atb::Context *context); atb::Status AtbRingMLA(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);
AtbRingMLAGetWorkspaceSize成员
参数 |
标量/张量 |
维度 |
数据类型 |
格式 |
默认值 |
是否必选 |
描述 |
---|---|---|---|---|---|---|---|
querySplit1 |
张量 |
[qNTokens, headNum, 128] |
float16/bf16 |
ND |
- |
是 |
输入tensor,无位置编码query矩阵。 |
querySplit2 |
张量 |
[qNTokens, headNum, 64] |
float16/bf16 |
ND |
- |
是 |
输入tensor,query旋转位置编码分量。 |
keySplit1 |
张量 |
[kvNTokens, kvHeadNum, 128] |
float16/bf16 |
ND |
- |
是 |
输入tensor,无位置编码key矩阵。 |
keySplit2 |
张量 |
[kvNTokens, kvHeadNum, 64] |
float16/bf16 |
ND |
- |
是 |
输入tensor,key旋转位置编码。 |
value |
张量 |
[kvNTokens, kvHeadNum, 128] |
float16/bf16 |
ND |
- |
是 |
输入tensor,value矩阵。 |
mask |
张量 |
[512, 512] |
float16/bf16 |
ND |
- |
是 |
输入tensor,掩码。 |
seqLen |
张量 |
[batch]/[2, batch] |
int32/uint32 |
ND |
- |
是 |
输入tensor,序列长度。
|
prevOut |
张量 |
[qNTokens, headNum, headSizeV] |
float16/bf16 |
ND |
- |
否 |
输入tensor,前次输出。 |
prevLse |
张量 |
[headNum, qNTokens] |
float16 |
ND |
- |
否 |
输入tensor,前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。 |
headNum |
标量 |
- |
int32_t |
- |
0 |
是 |
query头大小。 |
kvHeadNum |
标量 |
- |
int32_t |
- |
0 |
是 |
kv头数量, 该值需要用户根据使用的模型实际情况传入。
|
qkScale |
标量 |
- |
float |
- |
1.0 |
是 |
算子tor值, 在Q*K^T后乘。 |
kernelType |
标量 |
- |
int |
- |
1 |
是 |
内核精度类型。 仅支持1:KERNELTYPE_HIGH_PRECISION。 |
maskType |
标量 |
- |
int |
- |
1 |
是 |
mask类型。
|
inputLayout |
标量 |
- |
int |
- |
0 |
是 |
数据排布格式。 仅支持0:TYPE_BSND |
calcType |
标量 |
- |
int |
- |
0 |
是 |
计算类型。
|
output |
张量 |
[qNTokens, headNum, headSizeV] |
float16/bf16 |
ND |
- |
是 |
输出tensor。 |
softmaxLse |
张量 |
[headNum, qNTokens] |
float |
ND |
- |
是 |
softmaxLse输出。 |