昇腾社区首页
中文
注册

aclnnRmsNormGrad

支持的产品型号

  • Atlas A2训练系列产品/Atlas 800I A2推理产品。

接口原型

每个算子分为两段式接口,必须先调用“aclnnRmsNormGradGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnRmsNormGrad”接口执行计算。

  • aclnnStatus aclnnRmsNormGradGetWorkspaceSize( const aclTensor *dy, const aclTensor *x, const aclTensor *rstd, const aclTensor *gamma, const aclTensor *dxOut, const aclTensor *dgammaOut, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnRmsNormGrad( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

功能描述

  • 算子功能:aclnnRmsNorm的反向计算。

  • 算子公式:

    • 正向公式:
    RmsNorm(xi)=xiRms(x)gi, where Rms(x)=1ni=1nxi2+eps\operatorname{RmsNorm}(x_i)=\frac{x_i}{\operatorname{Rms}(\mathbf{x})} g_i, \quad \text { where } \operatorname{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+eps}
    • 反向推导:
    dxi=(dyigixiRms(x)Mean(y))1Rms(x), where Mean(y)=1ni=1n(dygixiRms(x))dx_i= (dy_i * g_i - \frac{x_i}{\operatorname{Rms}(\mathbf{x})} * \operatorname{Mean}(\mathbf{y})) * \frac{1} {\operatorname{Rms}(\mathbf{x})}, \quad \text { where } \operatorname{Mean}(\mathbf{y}) = \frac{1}{n}\sum_{i=1}^n (dy * g_i * \frac{x_i}{\operatorname{Rms}(\mathbf{x})}) dgi=xiRms(x)dyidg_i = \frac{x_i}{\operatorname{Rms}(\mathbf{x})} dy_i

aclnnRmsNormGradGetWorkspaceSize

  • 参数说明:

    • dy(aclTensor*,计算输入, 表示反向传回的梯度): 数据类型支持FLOAT32、FLOAT16、BFLOAT16、shape支持2-8维度,数据格式支持ND, 支持非连续输入。
    • x(aclTensor*,计算输入, 正向算子的输入,表示被标准化的数据): 数据类型支持FLOAT32、FLOAT16、BFLOAT16、shape支持2-8维度,数据格式支持ND, shape与dy保持一致, 支持非连续输入。
    • rstd(aclTensor*,计算输入, 正向算子的中间计算结果): 数据类型支持FLOAT32、shape支持2-8维度,数据格式支持ND, shape需要满足gamma_shape = x_shape[0:n], n < x_shape.dims(), n与gamma一致, 支持非连续输入。
    • gamma(aclTensor*,计算输入, 正向算子的输入,计算入参): 数据类型支持FLOAT32、FLOAT16、BFLOAT16、shape支持2-8维度,数据格式支持ND, shape需要满足gamma_shape = x_shape[n:], n < x_shape.dims(), 支持非连续输入。
    • dxOut(aclTensor*,计算输出, 表示输入x的梯度): 数据类型支持FLOAT32、FLOAT16、BFLOAT16、shape支持2-8维度,数据格式支持ND, shape与dy保持一致。
    • dgammaOut(aclTensor*,计算输出, 表示gamma的梯度): 数据类型支持FLOAT32、shape支持2-8维度,数据格式支持ND, shape与gamma保持一致。
    • workspaceSize(uint64_t*,出参):返回需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值:

    • aclnnStatus:返回状态码。

      说明: 第一段接口完成入参校验,出现以下场景时报错:

      • 161001 (ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入,输出或者必选属性,且是空指针,则返回161001。

aclnnRmsNormGrad

  • 参数说明:

    • workspace(void*, 入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t, 入参):在Device侧申请的workspace大小,由第一段接口aclnnRmsNormGradGetWorkspaceSize获取。
    • executor(aclOpExecutor*, 入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream, 入参):指定执行任务的AscendCL stream流。
  • 返回值:

    • aclnnStatus:返回状态码。(具体参见undefined

约束与限制

  • 功能维度
    • 数据类型支持
      • dy、x、gamma支持:FLOAT32、FLOAT16、BFLOAT16。
      • rstd支持:FLOAT32。
    • 数据格式支持:ND。
  • 未支持类型说明
    • DOUBLE:指令不支持DOUBLE。
    • 是否支持空tensor:不支持空进空出。
    • 是否非连续tensor:不支持输入非连续,不支持数据非连续。
  • 边界值场景说明
    • 当输入是inf时,输出为inf。
    • 当输入是nan时,输出为nan。 各平台支持数据类型说明
    • Atlas A2训练系列产品/Atlas 800I A2推理产品
      dy 数据类型 x 数据类型 rstd 数据类型 gamma 数据类型 dx 数据类型 dgamma 数据类型
      float16 float16 float32 float32 float16 float32
      bfloat16 bfloat16 float32 float32 bfloat16 float32
      float16 float16 float32 float16 float16 float32
      float32 float32 float32 float32 float32 float32
      bfloat16 bfloat16 float32 bfloat16 bfloat16 float32

调用示例

[object Object]