昇腾社区首页
中文
注册
开发者
下载

aclnnThnnFusedLstmCellBackward

产品支持情况

[object Object]undefined

功能说明

  • 算子功能:LSTMCell中四个门中matmul后剩余计算的反向传播,计算正向输出四个门激活前的值gates、输入cx、偏置b的梯度。
  • 计算公式:

变量定义

  • 输入梯度δht\delta h_t ([object Object]), δct\delta c_t ([object Object])
  • 前向缓存ifgoi,f,g,o (各门激活值[object Object]),ct1c_{t-1} ([object Object]),ctc_t ([object Object])
  • 输出梯度δaiδafδagδao\delta a_i,\delta a_f,\delta a_g,\delta a_o (存入 [object Object]),δct1\delta c_{t-1} ([object Object])

第一阶段:中间梯度与状态回传

首先计算隐藏状态对细胞状态的贡献,并汇总得到当前时刻细胞的总梯度 grad_ctotal\text{grad\_}c_{total}

gcx=tanh(ct)grad_ctotal=δhto(1gcx2)+δctδct1=grad_ctotalf\begin{aligned} gcx &= \tanh(c_t) \\ \text{grad\_}c_{total} &= \delta h_t \cdot o \cdot (1 - gcx^2) + \delta c_t \\ \delta c_{t-1} &= \text{grad\_}c_{total} \cdot f \end{aligned}

第二阶段:门控分量梯度 (Pre-activation)

根据代码逻辑,各门控在进入激活函数前的梯度 δa\delta a 计算如下:

δao=(δhtgcx)o(1o)δai=(grad_ctotalg)i(1i)δaf=(grad_ctotalct1)f(1f)δag=(grad_ctotali)(1g2)\begin{aligned} \delta a_o &= (\delta h_t \cdot gcx) \cdot o \cdot (1 - o) \\ \delta a_i &= (\text{grad\_}c_{total} \cdot g) \cdot i \cdot (1 - i) \\ \delta a_f &= (\text{grad\_}c_{total} \cdot c_{t-1}) \cdot f \cdot (1 - f) \\ \delta a_g &= (\text{grad\_}c_{total} \cdot i) \cdot (1 - g^2) \end{aligned}

第三阶段:参数梯度 (db)

**1. 偏置梯度 (db):**对 Batch 维度(NN)进行求和:

δb=n=1N[δaiδafδagδao]n\delta b = \sum_{n=1}^{N} \begin{bmatrix} \delta a_i \\ \delta a_f \\ \delta a_g \\ \delta a_o \end{bmatrix}_n

函数原型

每个算子分为,必须先调用“aclnnThnnFusedLstmCellBackwardGetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnThnnFusedLstmCellBackward”接口执行计算。

[object Object]
[object Object]

aclnnThnnFusedLstmCellBackwardGetWorkspaceSize

  • 参数说明:

    [object Object]
  • 返回值: aclnnStatus: 返回状态码,具体参见。 第一段接口完成入参校验,出现以下场景时报错:

    [object Object]

aclnnThnnFusedLstmCellBackward

  • 参数说明:

    [object Object]
  • 返回值:

    aclnnStatus: 返回状态码,具体参见

约束说明

  • 确定性计算:
    • aclnnThnnFusedLstmCellBackward默认确定性实现。
  • 边界值场景说明:
    • 当输入是Inf时,输出为NAN。
    • 当输入是NaN时,输出为NaN。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]