GatherPreRmsNormOperation

功能

首先对ResIn进行Gather索引操作,然后与X相加,最后进行RmsNorm计算。

计算公式

计算图

硬件支持情况

硬件型号

支持情况

Atlas 800I A2 推理产品

支持

定义

1
2
3
4
struct GatherPreRmsNormParam {
    float epsilon = 1e-5;
    uint8_t rsv[28] = {0};
};

参数列表

成员名称

类型

默认值

取值范围

是否必选

描述

epsilon

float

1e-5

(2e-38, 1e-5]

归一化时加在分母上防止除零。

rsv[28]

uint8_t

{0}

[0]

预留参数。

输入

参数

维度

数据类型

格式

描述

x

[dim_m, dim_n]

float16/bf16

ND

2维,末轴维度需要32字节对齐(16的倍数)且不大于7680。

resIn

[dim_res, dim_n]

float16/bf16

ND

2维,dim_mdim_res可能不相等,与x的数据类型一样。

indices

[dim_m]

int32

ND

1维,元素取值范围为[0, dim_res),其中dim_res为ResIn的第一个维度值。

gamma

[1, dim_n]或[dim_n]

float16/bf16

ND

1维或2维,与x的数据类型一样。

输出

参数

维度

数据类型

格式

描述

y

[dim_m, dim_n]

float

ND

数据类型、维度数和维度值与x一样。

ResOut

[dim_m, dim_n]

float16/bf16

ND

维度与X一致