RmsNormBackwardOperation
功能
rmsnorm的反向计算。
定义
1 2 3 | struct RmsNormBackwardParam { uint8_t rsv[8] = {0}; }; |
参数列表
成员名称 |
类型 |
默认值 |
描述 |
---|---|---|---|
rsv[8] |
uint8_t |
{0} |
预留参数。 |
输入
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
dy |
[-1,…,-1] |
float16/float/bf16 |
ND |
输入梯度。维度与x相同。数据类型与x一致。 |
x |
[-1,…,-1] |
float16/float/bf16 |
ND |
正向计算输入。 |
rstd |
[-1,…,-1] |
float |
ND |
正向计算中间结果。 |
gamma |
[-1,…,-1] |
float16/float/bf16 |
ND |
数据类型与x一致。维度数需要大于0,并小于x的维度数,gamma的维度从最后一维向前,每一维都需要和x保持一致。 |
输出
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
dx |
[-1,…,-1] |
float16/float/bf16 |
ND |
正向输入x的梯度。维度与x一致。数据类型与x一致。 |
dgamma |
[-1,…,-1] |
float |
ND |
正向输入gamma的梯度。维度与gamma一致。 |
规格约束
当前仅支持