昇腾社区首页
中文
注册

MLA Prefill

定义

atb::Status AtbMLAPreFillGetWorkspaceSize(const aclTensor *q, const aclTensor *qRope, const aclTensor *k,
                                          const aclTensor *kRope, const aclTensor *v, const aclTensor *qSeqLen, const aclTensor *kvSeqLen,
                                          const aclTensor *mask, int32_t headNum, float qkScale, int32_t kvHeadNum,
                                          int maskType, uint8_t cacheMode, aclTensor *attenOut,
                                          uint64_t *workspaceSize, atb::Operation **op, atb::Context *context);
atb::Status AtbMLAPreFill(void* workspace, uint64_t workspaceSize, atb::Operation *op, atb::Context *context);

AtbMLAPreFillGetWorkspaceSize成员

参数

标量/张量

维度

数据类型

格式

默认值

是否必选

描述

Q

张量

[nTokens, qhead * embeddimV]

[nTokens, qhead, embeddimV]

float16/bf16

ND

-

query

QRope

张量

[nTokens, qhead * 64]

[nTokens, qhead, 64]

float16/bf16

ND

-

旋转位置编码query

K

张量

[batch, max_seq, kv_head * embeddimV]

[batch*seq, kv_head, embeddimV]

float16/bf16

ND

-

key

kRope

张量

[batch, max_seq, kv_head * 64]

[batch*seq, kv_head, 64]

float16/bf16

ND

-

旋转位置编码key

V

张量

[batch, max_seq, kv_head * embeddimV]

[batch*seq, kv_head, embeddimV]

float16/bf16

ND

-

value

qSeqLen

张量

[batch]

int32

ND

-

输入tensor

kvSeqLen

张量

[batch]

int32

ND

-

输入tensor

Mask

张量

None

[512,512]

float16/bf16

ND

-

注意力掩码,UNDEFINED 或 MASK_TYPE_CAUSAL_MASK 时可以置为nullptr,MASK_TYPE_MASK_FREE时输入tensor。

headNum

标量

-

int32_t

-

0

query头数量。

qkScale

标量

-

float

-

1.0

Q*K^T后乘以的缩放系数。

kvHeadNum

标量

-

int32_t

-

0

kv头数量。

maskType

标量

-

int

-

0

mask类型。

UNDEFINED:默认值,全0的mask;

MASK_TYPE_SPEC:qseqlen > 1时的mask。

MASK_TYPE_MASK_FREE:maskfree功能。

calcType

标量

-

int

-

0

计算类型。

  • CALC_TYPE_UNDEFINED = 0:默认值。
  • CALC_TYPE_SPEC:支持传入大于1的qseqlen。
  • CALC_TYPE_RING:ringAttention。
  • CALC_TYPE_PREFILL:prefill 全量。

cacheMode

标量

-

int

-

0

  • KVCACHE = 0:拼接cache。
  • KROPE_CTKV:分离cache。
  • INT8_NZCACHE:高性能分离cache。
  • NZCACHE:非量化NZcache。

attenOut

张量

[nTokens, qhead * embeddimV]

[ntokens, kvhead, embedv]

float16/bf16

ND

-

attention输出