昇腾社区首页
中文
注册

RingMLAOperation

定义

atb::Status AtbRingMLAGetWorkspaceSize(const aclTensor *querySplit1, const aclTensor *querySplit2,
                                       const aclTensor *keySplit1, const aclTensor *keySplit2, const aclTensor *value,
                                       const aclTensor *mask, const aclTensor *seqLen, const aclTensor *prevOut,
                                       const aclTensor *prevLse, int32_t headNum, int32_t kvHeadNum, float qkScale,
                                       int kernelType, int maskType, int inputLayout, int calcType, aclTensor *output,
                                       aclTensor *softmaxLse, uint64_t *workspaceSize, atb::Operation **op,
                                       atb::Context *context);
atb::Status AtbRingMLA(void *workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);

AtbRingMLAGetWorkspaceSize成员

参数

标量/张量

维度

数据类型

格式

默认值

是否必选

描述

querySplit1

张量

[qNTokens, headNum, 128]

float16/bf16

ND

-

输入tensor,无位置编码query矩阵。

querySplit2

张量

[qNTokens, headNum, 64]

float16/bf16

ND

-

输入tensor,query旋转位置编码分量。

keySplit1

张量

[kvNTokens, kvHeadNum, 128]

float16/bf16

ND

-

输入tensor,无位置编码key矩阵。

keySplit2

张量

[kvNTokens, kvHeadNum, 64]

float16/bf16

ND

-

输入tensor,key旋转位置编码。

value

张量

[kvNTokens, kvHeadNum, 128]

float16/bf16

ND

-

输入tensor,value矩阵。

mask

张量

[512, 512]

float16/bf16

ND

-

输入tensor,掩码。

seqLen

张量

[batch]/[2, batch]

int32/uint32

ND

-

输入tensor,序列长度。

  • 若shape为[batch] ,代表每个batch的序列长度,query,cacheK,cacheV相同。
  • 若shape为[2,batch],seqlen[0]代表query的序列长度,seqlen[1]代表cacheK,cacheV的序列长度。

prevOut

张量

[qNTokens, headNum, headSizeV]

float16/bf16

ND

-

输入tensor,前次输出。

prevLse

张量

[headNum, qNTokens]

float16

ND

-

输入tensor,前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。

headNum

标量

-

int32_t

-

0

query头大小。

kvHeadNum

标量

-

int32_t

-

0

kv头数量, 该值需要用户根据使用的模型实际情况传入。

  • kvHeadNum = 0时,keyCache的k_head_num,valueCache的v_head_num与query的num_heads一致,均为num_heads的数值。
  • kvHeadNum != 0时,keyCache的k_head_num, valueCache的v_head_num与kvHeadNum值相同。

qkScale

标量

-

float

-

1.0

算子tor值, 在Q*K^T后乘。

kernelType

标量

-

int

-

1

内核精度类型。

仅支持1:KERNELTYPE_HIGH_PRECISION。

maskType

标量

-

int

-

1

mask类型。

  • 0:NO_MASK
  • 1:MASK_TYPE_TRIU

inputLayout

标量

-

int

-

0

数据排布格式。

仅支持0:TYPE_BSND

calcType

标量

-

int

-

0

计算类型。

  • CALC_TYPE_DEFAULT = 0:默认,非首末卡场景,有prev_lse, prev_o传入,生成softmaxLse输出。目前仅支持默认值。
  • CALC_TYPE_FISRT_RING: 首卡场景,无prev_lse, prev_o传入,生成softmaxLse输出。

output

张量

[qNTokens, headNum, headSizeV]

float16/bf16

ND

-

输出tensor。

softmaxLse

张量

[headNum, qNTokens]

float

ND

-

softmaxLse输出。

原始接口

请参见RingMLAOperation(代码开放)