FaUpdateOperation

功能

将各SP域PA算子的输出的中间结果lse、attention out两个局部结果更新成全局结果。

硬件支持情况

硬件型号

支持情况

Atlas 800I A2 推理产品/Atlas A2 训练系列产品

支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

支持

定义

1
2
3
4
5
6
7
8
struct FaUpdateParam {
    enum FaUpdateType {
        DECODE_UPDATE = 0, 
    };
    FaUpdateType faUpdateType = DECODE_UPDATE;
    uint32_t sp = 1;
    uint8_t rsv[64] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

faUpdateType

FaUpdateType

DECODE_UPDATE

DECODE_UPDATE

指定下标需要执行的操作类型,目前只支持取默认值。

sp

uint32_t

1

[1, 8]

序列并行的并行度SP。

rsv[64]

uint8_t

{0}

[0]

预留参数。

输入

参数

维度

数据类型

格式

描述

lse

[sp, batch * seqLen * headNum]

float

ND

各SP域计算的lse。

localout

[sp, batch * seqLen * headNum, head_size]

float

ND

各SP域计算的output。

输出

参数

维度

数据类型

格式

描述

output

[batch * seqLen * headNum, head_size]

float

ND

全局的output。

规格约束

head_size取值范围为[8, 512],且必须是8的倍数