FaUpdateOperation(代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
将各SP域PA算子的输出的中间结果lse、attention out两个局部结果更新成全局结果。
定义
1 2 3 4 5 6 7 8  | struct FaUpdateParam { enum FaUpdateType { DECODE_UPDATE = 0, }; FaUpdateType faUpdateType = DECODE_UPDATE; uint32_t sp = 1; uint8_t rsv[64] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
faUpdateType  | 
FaUpdateType  | 
DECODE_UPDATE  | 
DECODE_UPDATE  | 
否  | 
指定下标需要执行的操作类型,目前只支持取默认值。  | 
sp  | 
uint32_t  | 
1  | 
[1, 16]  | 
否  | 
序列并行的并行度SP。  | 
rsv[64]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
lse  | 
[sp, batch * seqLen * headNum]  | 
float  | 
ND  | 
各SP域计算的lse。  | 
localout  | 
[sp, batch * seqLen * headNum, head_size]  | 
float  | 
ND  | 
各SP域计算的output。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
output  | 
[batch * seqLen * headNum, head_size]  | 
float  | 
ND  | 
全局的output。  | 
规格约束
head_size取值范围为[8, 512],且必须是8的倍数。