GatherPreRmsNormOperation(代码开放)
产品支持情况
硬件型号  | 
是否支持  | 
|---|---|
√  | 
|
√  | 
|
x  | 
|
x  | 
|
x  | 
功能
首先对ResIn进行Gather索引操作,然后与X相加,最后进行RmsNorm计算。
计算公式
计算图

定义
1 2 3 4  | struct GatherPreRmsNormParam { float epsilon = 1e-5; uint8_t rsv[28] = {0}; };  | 
参数列表
成员名称  | 
类型  | 
默认值  | 
取值范围  | 
是否必选  | 
描述  | 
|---|---|---|---|---|---|
epsilon  | 
float  | 
1e-5  | 
(2e-38, 1e-5]  | 
是  | 
归一化时加在分母上防止除0。  | 
rsv[28]  | 
uint8_t  | 
{0}  | 
[0]  | 
否  | 
预留参数。  | 
输入
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
x  | 
[dim_m, dim_n]  | 
float16/bf16  | 
ND  | 
2维,末轴维度需要32字节对齐(16的倍数)且不大于7680。  | 
resIn  | 
[dim_res, dim_n]  | 
float16/bf16  | 
ND  | 
2维,dim_m、dim_res可能不相等,与x的数据类型一样。  | 
indices  | 
[dim_m]  | 
int32  | 
ND  | 
1维,元素取值范围为[0, dim_res),其中dim_res为ResIn的第一个维度值。  | 
gamma  | 
[1, dim_n]或[dim_n]  | 
float16/bf16  | 
ND  | 
1维或2维,与x的数据类型一样。  | 
输出
参数  | 
维度  | 
数据类型  | 
格式  | 
描述  | 
|---|---|---|---|---|
y  | 
[dim_m, dim_n]  | 
float  | 
ND  | 
数据类型、维度数和维度值与x一样。  | 
ResOut  | 
[dim_m, dim_n]  | 
float16/bf16  | 
ND  | 
维度与x一致。  | 


