RmsNormBackwardOperation

功能

rmsnorm 的反向计算。

定义

struct RmsNormBackwardParam {};

输入

参数

维度

数据类型

格式

描述

dy

[-1,…,-1]

float16/float/bfloat16

ND

输入梯度。维度与x相同。

x

[-1,…,-1]

float16/float/bfloat16

ND

正向计算输入。

rstd

[-1,…,-1]

float

ND

正向计算中间结果。

gamma

[-1,…,-1]

float16/float/bfloat16

ND

维度数不大于x的维度数。

输出

参数

维度

数据类型

格式

描述

dx

[-1,…,-1]

float16/float/bfloat16

ND

正向输入x的梯度。维度与x一致。

dgamma

[-1,…,-1]

float

ND

正向输入gamma的梯度。维度与gamma一致。