将各SP域PA算子的输出的中间结果lse、attention out两个局部结果更新成全局结果。
硬件型号 |
支持情况 |
---|---|
支持 |
|
支持 |
1 2 3 4 5 6 7 8 | struct FaUpdateParam { enum FaUpdateType { DECODE_UPDATE = 0, }; FaUpdateType faUpdateType = DECODE_UPDATE; uint32_t sp = 1; uint8_t rsv[64] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
faUpdateType |
FaUpdateType |
DECODE_UPDATE |
DECODE_UPDATE |
否 |
指定下标需要执行的操作类型,目前只支持取默认值。 |
sp |
uint32_t |
1 |
[1, 8] |
否 |
序列并行的并行度SP。 |
rsv[64] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
lse |
[sp, batch * seqLen * headNum] |
float |
ND |
各SP域计算的lse。 |
localout |
[sp, batch * seqLen * headNum, head_size] |
float |
ND |
各SP域计算的output。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[batch * seqLen * headNum, head_size] |
float |
ND |
全局的output。 |
head_size取值范围为[8, 512],且必须是8的倍数。