昇腾社区首页
中文
注册

MultiLatentAttentionOperation

功能

MLA场景下的paged attention,使用分页管理的kvcache计算attention score,额外支持分离qnope/qrope、ctkv/krope的输入。

定义

struct MultiLatentAttentionParam {
    int32_t headNum = 0;
    float qkScale = 1.0;
    int32_t kvHeadNum = 0;
    enum MaskType : int {
        UNDEFINED = 0,      
        MASK_TYPE_SPEC,      
        MASK_TYPE_MASK_FREE, 
    };
    MaskType maskType = UNDEFINED;
    enum CalcType : int {
        CALC_TYPE_UNDEFINED = 0,
        CALC_TYPE_SPEC,          
        CALC_TYPE_RING,         
    };
    CalcType calcType = CALC_TYPE_UNDEFINED;
    enum CacheMode : uint8_t {
        KVCACHE = 0,  
        KROPE_CTKV,   
        INT8_NZCACHE, 
        NZCACHE,     
    };
    CacheMode cacheMode = KVCACHE;
    uint8_t rsv[43] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

headNum

int32_t

0

{8,16,32,64,128}

query头数量。

qkScale

float

1.0

(0,1]

Q*K^T后乘以的缩放系数

kvHeadNum

int32_t

0

[1]

kv头数量

maskType

MaskType

UNDEFINED

[0,2]

mask类型。

calcType

CalcType

CALC_TYPE_UNDEFINED

[0,2]

计算类型。

cacheMode

CacheMode

KVCACHE

[0,3]

输入query和kcache的类型。

rsv[43]

uint8_t

0

[0]

预留字段

上表中类型为自定义类型的,其描述如下:

  • maskType:表示mask类型,其具体取值如下。
    • UNDEFINED:无mask。
    • MASK_TYPE_SPEC:并行解码mask,与calcType = CALC_TYPE_SPEC一起使用。
    • MASK_TYPE_MASK_FREE:传入固定shape的mask,与calcType = CALC_TYPE_SPEC一起使用。计算时根据尾部长度从传入mask中获取真实的mask。mask为倒三角形式,第一行全部为-inf,最后一行全部为0。具体形状如下:
      def generate_mask_free(q_len):
              :param q_len: q_len
              :return: constructed mask, e.g. when q_len=2, returned as
              [[-10000.0 -10000.0 -10000.0 ... -10000.0],
               [0        -10000.0 -10000.0 ... -10000.0],
               [0               0 -10000.0 ... -10000.0],
               ...
               [0               0        0 ... -10000.0],
               [0               0        0 ...        0]]
              """
              mask_free = np.full((125 + 2 * q_len, 128), -10000.0)
              mask_free = np.triu(mask_free, 2 - q_len)
              return mask_free
  • calcType:表示计算类型,其具体取值如下。
    • CALC_TYPE_UNDEFINED:默认的decoder场景。
    • CALC_TYPE_SPEC:支持qseqlen大于1,并行解码功能,MTP使用该场景。
    • CALC_TYPE_RING:ringAttention,额外输出Ise。
  • cacheMode:表示输入query和kcache的类型,其具体取值如下。

    与MLAPO大融合预处理算子对应使用。

    • KVCACHE:输入只有一个q和一个合并kvCache,为原PagedAttention下的MLA场景,目前暂未接入。
    • KROPE_CTKV:输入的q拆分为qNope和qRope,输入的kcache拆分为ctKV和kRope,对应之前的默认场景。
    • INT8_NZCACHE:高性能cache,在KROPE_CTKV的基础上:krope和ctkv转为NZ格式输出,ctkv和qnope经过per_head静态对称量化为int8类型。
    • NZCACHE:在KROPE_CTKV的基础上krope和ctkv转为NZ格式输出。

输入

参数

维度

数据类型

格式

cpu/npu

描述

qNope

[num_tokens, num_heads, 512]

float16/bf16/int8

ND

NPU

无位置编码query。

qRope

[num_tokens, num_heads, 64]

float16/bf16

ND

NPU

旋转位置编码query。

ctKV

[num_blocks , block_size, kv_heads, 512]

  • cacheMode为2:

    [blockNum, kv_heads*512/32,block_size, 32]

  • cacheMode为3:

    [blockNum, kv_heads*512/16,block_size, 16]

float16/bf16/int8

ND/NZ

NPU

无位置编码ctkv。

  • cacheMode为2:int8,NZ。
  • cacheMode为3:float16/bf16,NZ。

kRope

[num_blocks , block_size, kv_heads, 64]

cacheMode为2或3:

[blockNum, kv_heads*64 / 16 ,block_size, 16]

float16/bf16

ND/NZ

NPU

旋转位置编码k。

cacheMode为2或3:NZ。

block_tables

[batch, max_num_blocks_per_query]

int32

ND

NPU

每个query的kvcache的block映射表。

contextLens

[batch]

int32

ND

CPU

每个query对应的上下文长度,kseqlen。

mask

  • MASK_TYPE_SPEC:

    [num_tokens(合轴), max_seq_len]

  • MASK_TYPE_MASK_FREE:

    [125 + 2 * qseqlen, 128]

float16/bf16

ND

NPU

注意力掩码,maskType不为0时传入。

qseqlen

[batch]

int32

ND

CPU

calcType为1时传入,每个batch对应的seqLen,取值范围为[1,4]。

qkDescale

[num_heads]

float

ND

NPU

cacheMode为2时传入。

pvDescale

[num_heads]

float

ND

NPU

cacheMode为2时传入。

输出

参数

维度

数据类型

格式

cpu/npu

描述

attenOut

[num_tokens, num_heads, head_size_vo]

float16/bf16

ND

NPU

attention输出。

Ise

[num_tokens, num_heads, 1]

float16/bf16

ND

NPU

Ise输出,只有calcType为CALC_TYPE_RING时需要此tensor。

只支持maskType为UNDEFINED,cacheMode为KROPE_CTKV场景。

规格约束

  • block_size <= 128,建议为128。
  • batch <= 8192
  • cacheMode为INT8_NZCACHE时,不支持num_heads = 128