昇腾社区首页
中文
注册

FaUpdateOperation(代码开放)

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

将各SP域PA算子的输出的中间结果lse、attention out两个局部结果更新成全局结果。

定义

1
2
3
4
5
6
7
8
struct FaUpdateParam {
    enum FaUpdateType {
        DECODE_UPDATE = 0, 
    };
    FaUpdateType faUpdateType = DECODE_UPDATE;
    uint32_t sp = 1;
    uint8_t rsv[64] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

faUpdateType

FaUpdateType

DECODE_UPDATE

DECODE_UPDATE

指定下标需要执行的操作类型,目前只支持取默认值。

sp

uint32_t

1

[1, 16]

序列并行的并行度SP。

rsv[64]

uint8_t

{0}

[0]

预留参数。

输入

参数

维度

数据类型

格式

描述

lse

[sp, batch * seqLen * headNum]

float

ND

各SP域计算的lse。

localout

[sp, batch * seqLen * headNum, head_size]

float

ND

各SP域计算的output。

输出

参数

维度

数据类型

格式

描述

output

[batch * seqLen * headNum, head_size]

float

ND

全局的output。

约束说明

head_size取值范围为[8, 512],且必须是8的倍数