昇腾社区首页
中文
注册

RingMLAOperation(代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

基于传统MultiLatentAttention,并使能ring MLA算子的输出的中间结果lse,attention out两个局部结果更新成全局结果,支持更长的序列长度。

定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
struct RingMLAParam {
    enum CalcType : int {
        CALC_TYPE_DEFAULT = 0,
        CALC_TYPE_FISRT_RING,  
        CALC_TYPE_MAX
    };
    enum KernelType : int {
        KERNELTYPE_DEFAULT = 0,   
        KERNELTYPE_HIGH_PRECISION 
    };
    enum MaskType : int {
        NO_MASK = 0,    
        MASK_TYPE_TRIU,
    };
    CalcType calcType = CalcType::CALC_TYPE_DEFAULT;
    int32_t headNum = 0;
    int32_t kvHeadNum = 0;
    float qkScale = 1;
    KernelType kernelType = KERNELTYPE_HIGH_PRECISION; 
    MaskType maskType = MASK_TYPE_TRIU;
    InputLayout inputLayout = TYPE_BSND;
    uint8_t rsv[64] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

calcType

CalcType

CALC_TYPE_DEFAULT

CALC_TYPE_DEFAULT

CALC_TYPE_FISRT_RING

计算类型。

  • CALC_TYPE_DEFAULT:默认,非首末卡场景,有prevLse,prevOut传入,生成softmaxLse输出 。
  • CALC_TYPE_FISRT_RING:首卡场景,无prevLse, prevOut传入,生成softmaxLse输出。

headNum

int32_t

0

大于0

query头大小, 需大于0。

kvHeadNum

int32_t

0

大于等于0

kv头数量, 该值需要用户根据使用的模型实际情况传入

  • kvHeadNum = 0时,key的kHeadNum,value的vHeadNum与query的headNum一致,均为headNum的数值。
  • kvHeadNum != 0时,key的kHeadNum, value的vHeadNum与kvHeadNum值相同。

qkScale

float

1

-

算子tor值, 在Q*K^T后乘。

kernelType

KernelType

KERNELTYPE_HIGH_PRECISION

KERNELTYPE_HIGH_PRECISION

内核精度类型。KERNELTYPE_HIGH_PRECISION:输入/输出tensor使用float16/bf16,softmax使用float类型。

maskType

MaskType

MASK_TYPE_TRIU

NO_MASK

MASK_TYPE_TRIU

mask类型。

NO_MASK:不使用mask。

MASK_TYPE_TRIU:默认值,上三角mask。

inputLayout

InputLayout

TYPE_BSND

TYPE_BSND

数据排布格式默认为BSND。

rsv[64]

uint8_t

{0}

[0]

预留参数。

输入

参数

维度

数据类型

格式

cpu/npu

描述

使用场景

query

[qNTokens, headNum, 128]

float16/bf16

ND

npu

无位置编码query矩阵。

基础场景

queryRope

[qNTokens, headNum, 64]

float16/bf16

ND

npu

query旋转位置编码分量。

基础场景

key

[kvNTokens, kvHeadNum, 128]

float16/bf16

ND

npu

无位置编码key矩阵。

基础场景

keyRope

[kvNTokens, kvHeadNum, 64]

float16/bf16

ND

npu

key旋转位置编码。

基础场景

value

[kvNTokens, kvHeadNum, 128]

float16/bf16

ND

npu

value矩阵。

基础场景

mask

[512, 512]

float16/bf16

ND

npu

掩码。

基础场景

seqLen

[batch]/[2, batch]

int32/uint32

ND

cpu

序列长度。

  • 若shape为[batch] ,代表每个batch的序列长度,query,key,value相同。
  • 若shape为[2,batch],seqlen[0]代表query的序列长度,seqlen[1]代表key,value的序列长度。

基础场景

prevOut

[qNTokens, headNum, 128]

float16/bf16

ND

npu

前次输出。

非首卡场景

prevLse

[headNum, qNTokens]

float

ND

npu

前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。

非首卡场景

输出

参数

维度

数据类型

格式

cpu/npu

描述

使用场景

output

[qNTokens, headNum, headSizeV]

float16/bf16

ND

npu

输出。

基础场景

softmaxLse

[headNum, qNTokens]

float

ND

npu

softmaxLse输出。

基础场景

功能列表

  • 首卡场景
    • 开启方式:calcType = CALC_TYPE_FISRT_RING
    • 区别:无prevLse,prevOut传入,生成softmaxLse输出。
  • 非首末卡场景
    • 开启方式:calcType = CALC_TYPE_DEFAULT
    • 区别:有prevLse,prevOut传入,生成softmaxLse输出。

约束说明

  • maskType = MASK_TYPE_TRIU时才使用mask。
  • inputLayout仅支持TYPE_BSND。
  • 二维seqLen约束:
    • qSeqLen为seqLen[0]。
    • kvSeqLen为seqLen[1]。
    • 对于每个下标i,qSeqLen[i]不可为0。
    • 对于每个下标i,kvSeqLen[i] >= qSeqLen[i]或者kvSeqLen[i]为0,但注意kvSeqLen[0]和kvSeqLen[batch - 1]不可为0。