RingMLAOperation(代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
基于传统MultiLatentAttention,并使能ring MLA算子的输出的中间结果lse,attention out两个局部结果更新成全局结果,支持更长的序列长度。
定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23  | struct RingMLAParam { enum CalcType : int { CALC_TYPE_DEFAULT = 0, CALC_TYPE_FISRT_RING, CALC_TYPE_MAX }; enum KernelType : int { KERNELTYPE_DEFAULT = 0, KERNELTYPE_HIGH_PRECISION }; enum MaskType : int { NO_MASK = 0, MASK_TYPE_TRIU, }; CalcType calcType = CalcType::CALC_TYPE_DEFAULT; int32_t headNum = 0; int32_t kvHeadNum = 0; float qkScale = 1; KernelType kernelType = KERNELTYPE_HIGH_PRECISION; MaskType maskType = MASK_TYPE_TRIU; InputLayout inputLayout = TYPE_BSND; uint8_t rsv[64] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
calcType  | 
CalcType  | 
CALC_TYPE_DEFAULT  | 
CALC_TYPE_DEFAULT CALC_TYPE_FISRT_RING  | 
是  | 
计算类型。 
  | 
headNum  | 
int32_t  | 
0  | 
大于0  | 
是  | 
query头大小, 需大于0。  | 
kvHeadNum  | 
int32_t  | 
0  | 
大于等于0  | 
是  | 
kv头数量, 该值需要用户根据使用的模型实际情况传入 
  | 
qkScale  | 
float  | 
1  | 
-  | 
是  | 
算子tor值, 在Q*K^T后乘。  | 
kernelType  | 
KernelType  | 
KERNELTYPE_HIGH_PRECISION  | 
KERNELTYPE_HIGH_PRECISION  | 
是  | 
内核精度类型。 
  | 
maskType  | 
MaskType  | 
MASK_TYPE_TRIU  | 
NO_MASK MASK_TYPE_TRIU  | 
是  | 
mask类型。 NO_MASK:不使用mask。 MASK_TYPE_TRIU:默认值,上三角mask。  | 
inputLayout  | 
InputLayout  | 
TYPE_BSND  | 
TYPE_BSND  | 
是  | 
数据排布格式默认为BSND。  | 
rsv[64]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
cpu/npu  | 
描述  | 
使用场景  | 
|---|---|---|---|---|---|---|
query  | 
[qNTokens, headNum, 128]  | 
float16/bf16  | 
ND  | 
npu  | 
无位置编码query矩阵。  | 
基础场景  | 
queryRope  | 
[qNTokens, headNum, 64]  | 
float16/bf16  | 
ND  | 
npu  | 
query旋转位置编码分量。  | 
基础场景  | 
key  | 
[kvNTokens, kvHeadNum, 128]  | 
float16/bf16  | 
ND  | 
npu  | 
无位置编码key矩阵。  | 
基础场景  | 
keyRope  | 
[kvNTokens, kvHeadNum, 64]  | 
float16/bf16  | 
ND  | 
npu  | 
key旋转位置编码。  | 
基础场景  | 
value  | 
[kvNTokens, kvHeadNum, 128]  | 
float16/bf16  | 
ND  | 
npu  | 
value矩阵。  | 
基础场景  | 
mask  | 
[512, 512]  | 
float16/bf16  | 
ND  | 
npu  | 
掩码。  | 
基础场景  | 
seqLen  | 
[batch]/[2, batch]  | 
int32/uint32  | 
ND  | 
cpu  | 
序列长度。 
  | 
基础场景  | 
prevOut  | 
[qNTokens, headNum, 128]  | 
float16/bf16  | 
ND  | 
npu  | 
前次输出。  | 
非首卡场景  | 
prevLse  | 
[headNum, qNTokens]  | 
float  | 
ND  | 
npu  | 
前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。  | 
非首卡场景  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
cpu/npu  | 
描述  | 
使用场景  | 
|---|---|---|---|---|---|---|
output  | 
[qNTokens, headNum, headSizeV]  | 
float16/bf16  | 
ND  | 
npu  | 
输出。  | 
基础场景  | 
softmaxLse  | 
[headNum, qNTokens]  | 
float  | 
ND  | 
npu  | 
softmaxLse输出。  | 
基础场景  | 
功能列表
- 首卡场景
- 开启方式:calcType = CALC_TYPE_FISRT_RING
 - 区别:无prevLse,prevOut传入,生成softmaxLse输出。
 
 - 非首末卡场景
- 开启方式:calcType = CALC_TYPE_DEFAULT
 - 区别:有prevLse,prevOut传入,生成softmaxLse输出。
 
 
规格约束
- maskType = MASK_TYPE_TRIU时才使用mask。
 - inputLayout仅支持TYPE_BSND。
 - 二维seqLen约束:
- qSeqLen为seqLen[0]。
 - kvSeqLen为seqLen[1]。
 - 对于每个下标i,qSeqLen[i]不可为0。
 - 对于每个下标i,kvSeqLen[i] >= qSeqLen[i]或者kvSeqLen[i]为0,但注意kvSeqLen[0]和kvSeqLen[batch - 1]不可为0。