FaUpdateOperation(代码开放)
产品支持情况
硬件型号 |
是否支持 |
---|---|
√ |
|
√ |
|
x |
|
x |
|
x |
功能说明
将各SP域PA算子的输出的中间结果lse、attention out两个局部结果更新成全局结果。
定义
1 2 3 4 5 6 7 8 | struct FaUpdateParam { enum FaUpdateType { DECODE_UPDATE = 0, }; FaUpdateType faUpdateType = DECODE_UPDATE; uint32_t sp = 1; uint8_t rsv[64] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
faUpdateType |
FaUpdateType |
DECODE_UPDATE |
DECODE_UPDATE |
否 |
指定下标需要执行的操作类型,目前只支持取默认值。 |
sp |
uint32_t |
1 |
[1, 16] |
否 |
序列并行的并行度SP。 |
rsv[64] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
lse |
[sp, batch * seqLen * headNum] |
float |
ND |
各SP域计算的lse。 |
localout |
[sp, batch * seqLen * headNum, head_size] |
float |
ND |
各SP域计算的output。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[batch * seqLen * headNum, head_size] |
float |
ND |
全局的output。 |
约束说明
head_size取值范围为[8, 512],且必须是8的倍数。