RingMLAOperation(代码开放)
产品支持情况
硬件型号 |
是否支持 |
---|---|
√ |
|
√ |
|
x |
|
x |
|
x |
功能说明
基于传统MultiLatentAttention,并使能ring MLA算子的输出的中间结果lse,attention out两个局部结果更新成全局结果,支持更长的序列长度。
定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | struct RingMLAParam { enum CalcType : int { CALC_TYPE_DEFAULT = 0, CALC_TYPE_FISRT_RING, CALC_TYPE_MAX }; enum KernelType : int { KERNELTYPE_DEFAULT = 0, KERNELTYPE_HIGH_PRECISION }; enum MaskType : int { NO_MASK = 0, MASK_TYPE_TRIU, }; CalcType calcType = CalcType::CALC_TYPE_DEFAULT; int32_t headNum = 0; int32_t kvHeadNum = 0; float qkScale = 1; KernelType kernelType = KERNELTYPE_HIGH_PRECISION; MaskType maskType = MASK_TYPE_TRIU; InputLayout inputLayout = TYPE_BSND; uint8_t rsv[64] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
calcType |
CalcType |
CALC_TYPE_DEFAULT |
CALC_TYPE_DEFAULT CALC_TYPE_FISRT_RING |
是 |
计算类型。
|
headNum |
int32_t |
0 |
大于0 |
是 |
query头大小, 需大于0。 |
kvHeadNum |
int32_t |
0 |
大于等于0 |
是 |
kv头数量, 该值需要用户根据使用的模型实际情况传入
|
qkScale |
float |
1 |
- |
是 |
算子tor值, 在Q*K^T后乘。 |
kernelType |
KernelType |
KERNELTYPE_HIGH_PRECISION |
KERNELTYPE_HIGH_PRECISION |
是 |
内核精度类型。KERNELTYPE_HIGH_PRECISION:输入/输出tensor使用float16/bf16,softmax使用float类型。 |
maskType |
MaskType |
MASK_TYPE_TRIU |
NO_MASK MASK_TYPE_TRIU |
是 |
mask类型。 NO_MASK:不使用mask。 MASK_TYPE_TRIU:默认值,上三角mask。 |
inputLayout |
InputLayout |
TYPE_BSND |
TYPE_BSND |
是 |
数据排布格式默认为BSND。 |
rsv[64] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
输入
参数 |
维度 |
数据类型 |
格式 |
cpu/npu |
描述 |
使用场景 |
---|---|---|---|---|---|---|
query |
[qNTokens, headNum, 128] |
float16/bf16 |
ND |
npu |
无位置编码query矩阵。 |
基础场景 |
queryRope |
[qNTokens, headNum, 64] |
float16/bf16 |
ND |
npu |
query旋转位置编码分量。 |
基础场景 |
key |
[kvNTokens, kvHeadNum, 128] |
float16/bf16 |
ND |
npu |
无位置编码key矩阵。 |
基础场景 |
keyRope |
[kvNTokens, kvHeadNum, 64] |
float16/bf16 |
ND |
npu |
key旋转位置编码。 |
基础场景 |
value |
[kvNTokens, kvHeadNum, 128] |
float16/bf16 |
ND |
npu |
value矩阵。 |
基础场景 |
mask |
[512, 512] |
float16/bf16 |
ND |
npu |
掩码。 |
基础场景 |
seqLen |
[batch]/[2, batch] |
int32/uint32 |
ND |
cpu |
序列长度。
|
基础场景 |
prevOut |
[qNTokens, headNum, 128] |
float16/bf16 |
ND |
npu |
前次输出。 |
非首卡场景 |
prevLse |
[headNum, qNTokens] |
float |
ND |
npu |
前次QK^T * tor的结果,先取softmax,exp,sum,最后求log。 |
非首卡场景 |
输出
参数 |
维度 |
数据类型 |
格式 |
cpu/npu |
描述 |
使用场景 |
---|---|---|---|---|---|---|
output |
[qNTokens, headNum, headSizeV] |
float16/bf16 |
ND |
npu |
输出。 |
基础场景 |
softmaxLse |
[headNum, qNTokens] |
float |
ND |
npu |
softmaxLse输出。 |
基础场景 |
功能列表
- 首卡场景
- 开启方式:calcType = CALC_TYPE_FISRT_RING
- 区别:无prevLse,prevOut传入,生成softmaxLse输出。
- 非首末卡场景
- 开启方式:calcType = CALC_TYPE_DEFAULT
- 区别:有prevLse,prevOut传入,生成softmaxLse输出。
约束说明
- maskType = MASK_TYPE_TRIU时才使用mask。
- inputLayout仅支持TYPE_BSND。
- 二维seqLen约束:
- qSeqLen为seqLen[0]。
- kvSeqLen为seqLen[1]。
- 对于每个下标i,qSeqLen[i]不可为0。
- 对于每个下标i,kvSeqLen[i] >= qSeqLen[i]或者kvSeqLen[i]为0,但注意kvSeqLen[0]和kvSeqLen[batch - 1]不可为0。