首先对ResIn进行Gather索引操作,然后与X相加,最后进行RmsNorm计算。
硬件型号 |
支持情况 |
---|---|
支持 |
1 2 3 4 | struct GatherPreRmsNormParam { float epsilon = 1e-5; uint8_t rsv[28] = {0}; }; |
成员名称 |
类型 |
默认值 |
取值范围 |
是否必选 |
描述 |
---|---|---|---|---|---|
epsilon |
float |
1e-5 |
(2e-38, 1e-5] |
是 |
归一化时加在分母上防止除零。 |
rsv[28] |
uint8_t |
{0} |
[0] |
否 |
预留参数。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[dim_m, dim_n] |
float16/bf16 |
ND |
2维,末轴维度需要32字节对齐(16的倍数)且不大于7680。 |
resIn |
[dim_res, dim_n] |
float16/bf16 |
ND |
2维,dim_m、dim_res可能不相等,与x的数据类型一样。 |
indices |
[dim_m] |
int32 |
ND |
1维,元素取值范围为[0, dim_res),其中dim_res为ResIn的第一个维度值。 |
gamma |
[1, dim_n]或[dim_n] |
float16/bf16 |
ND |
1维或2维,与x的数据类型一样。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
y |
[dim_m, dim_n] |
float |
ND |
数据类型、维度数和维度值与x一样。 |
ResOut |
[dim_m, dim_n] |
float16/bf16 |
ND |
维度与X一致。 |