昇腾社区首页
中文
注册

MLA

定义

atb::Status AtbMLAGetWorkspaceSize(const aclTensor *qNope, const aclTensor *qRope, const aclTensor *ctKV,
                                   const aclTensor *kRope, const aclTensor *blockTables, const aclTensor *contextLens,
                                   const aclTensor *mask, const aclTensor *qSeqLen, const aclTensor *qkDescale,
                                   const aclTensor *pvDescale, int32_t headNum, float qkScale, int32_t kvHeadNum,
                                   int maskType, int calcType, uint8_t cacheMode, aclTensor *attenOut, aclTensor *ise,
                                   uint64_t *workspaceSize, atb::Operation **op, atb::Context *context);
atb::Status AtbMLA(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);

AtbMLAGetWorkspaceSize成员

参数

标量/张量

维度

数据类型

格式

默认值

是否必选

描述

qNope

张量

[num_tokens, num_heads, 512]

float16/bf16/int8

ND

-

无位置编码query。

qRope

张量

[num_tokens, num_heads, 64]

float16/bf16

ND

-

旋转位置编码query。

ctKV

张量

  • [num_blocks , block_size, kv_heads, 512]
  • cacheMode为2:

    [blockNum, kv_heads*512/32,block_size, 32]

  • cacheMode为3:

    [blockNum, kv_heads*512/16,block_size, 16]

float16/bf16/int8

ND/NZ

-

无位置编码ctkv。

  • cacheMode为2:int8,NZ
  • cacheMode为3:float16/bf16,NZ

kRope

张量

  • [num_blocks , block_size, kv_heads, 64]
  • cacheMode为2或3:

    [blockNum, kv_heads*64 / 16 ,block_size, 16]

float16/bf16

ND/NZ

-

旋转位置编码k。

cacheMode为2或3:NZ

blockTables

张量

[batch, max_num_blocks_per_query]

int32

ND

-

每个query的kvcache的block映射表。

contextLens

张量

[batch]

int32

ND

-

每个query对应的上下文长度,kseqlen。host侧tensor。

mask

张量

  • MASK_TYPE_SPEC:

    [num_tokens(合轴), max_seq_len]

  • MASK_TYPE_MASK_FREE:

    [125 + 2 * qseqlen, 128]

float16/bf16

ND

-

注意力掩码,maskType为0时传入nullptr。

qseqlen

张量

[batch]

int32

ND

-

calcType为1时传入,每个batch对应的seqLen。

qkDescale

张量

[num_heads]

float

ND

-

cacheMode为2时传入。

pvDescale

张量

[num_heads]

float

ND

-

cacheMode为2时传入。

headNum

标量

-

int32_t

-

0

query头数量。

qkScale

标量

-

float

-

1.0

Q*K^T后乘以的缩放系数。

kvHeadNum

标量

-

int32_t

-

0

kv头数量。

maskType

标量

-

int

-

0

mask类型。

  • UNDEFINED:默认值,全0的mask。
  • MASK_TYPE_SPEC: qseqlen > 1时的mask。
  • MASK_TYPE_MASK_FREE:maskfree功能。

calcType

标量

-

int

-

0

计算类型。

  • CALC_TYPE_UNDEFINED = 0:默认值。
  • CALC_TYPE_SPEC:支持传入大于1的qseqlen。
  • CALC_TYPE_RING: ringAttention。
  • CALC_TYPE_PREFILL,:prefill全量。

cacheMode

标量

-

int

-

0

  • KVCACHE = 0:拼接cache。
  • KROPE_CTKV:分离cache。
  • INT8_NZCACHE:高性能分离cache。
  • NZCACHE:非量化NZcache。

attenOut

张量

[num_tokens, num_heads, 512]

float16/bf16

ND

-

attention输出。

lse

张量

[num_tokens, num_heads, 1]

float16/bf16

ND

-

LSE输出,calcType不为CALC_TYPE_RING时可以传入nullptr。

原始接口

请参见MLA输入输出