昇腾社区首页
中文
注册
开发者
下载

RmsNormBackwardOperation

产品支持情况

产品

是否支持

Atlas A3 推理系列产品/Atlas A3 训练系列产品

Atlas A2 训练系列产品/Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

rmsnorm的反向计算。

定义

1
2
3
struct RmsNormBackwardParam {
    uint8_t rsv[8] = {0};
};

参数列表

成员名称

类型

默认值

描述

rsv[8]

uint8_t

{0}

预留参数。

输入

参数

维度

数据类型

格式

描述

dy

[dim_0, dim_1, ..., dim_n]

float16/float/bf16

ND

输入梯度。数据类型、维度与x一致。

x

[dim_0, dim_1, ..., dim_n]

float16/float/bf16

ND

正向计算输入。

rstd

[dim_0, dim_1, ..., dim_n]

float

ND

正向计算中间结果。

gamma

[dim_0, dim_1, ..., dim_n]

float16/float/bf16

ND

数据类型与x一致。维度数需要大于0,并小于x的维度数,gamma的维度从最后一维向前,每一维都需要和x保持一致。

输出

参数

维度

数据类型

格式

描述

dx

[dim_0, dim_1, ..., dim_n]

float16/float/bf16

ND

正向输入x的梯度。数据类型、维度与x一致。

dgamma

[dim_0, dim_1, ..., dim_n]

float

ND

正向输入gamma的梯度。维度与gamma一致。

约束说明