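  • Function: MhcPreBackward is the backward operator of MhcPre. Through a series of computations, the MhcPre operator produces the projection matrices $H^{res}$ and $H^{post}$ of the mHC (Manifold-Constrained Hyper-Connections) architecture, together with the input matrix $h^{in}$ of the Attention or MLP layer.

  • Formulas: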

    • Gradient of the output combination

      • Forward formula: $H\_in = \sum_{i=1}^{N} x[B,S,i,:] \cdot H\_pre_n[B,S,i]$
      • Backward computation: $\begin{aligned} H\_pre\_grad &= \text{Reduce}\left(H\_in\_grad.\text{unsqueeze}(-2) \odot x, \text{dim}=-1\right) \quad ([B,S,N]) \\ x\_grad\_vec3 &= H\_in\_grad \times H\_pre \quad ([B,S,N,D]) \end{aligned}$
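The output-combination gradient above can be sketched in pure Python for a single (b, s) position. This is only an illustration under assumptions: the helper names (`forward`, `backward`, `loss`) and the random test data are hypothetical and are not part of the aclnn API. The finite-difference check confirms that the analytic `H_pre_grad` matches the numeric derivative.

```python
import random

def forward(x, h_pre):
    """H_in[k] = sum_i x[i][k] * H_pre[i] for one (b, s) position."""
    n, d = len(x), len(x[0])
    return [sum(x[i][k] * h_pre[i] for i in range(n)) for k in range(d)]

def backward(x, h_pre, h_in_grad):
    """Gradients of the weighted sum with respect to H_pre and x."""
    n, d = len(x), len(x[0])
    # H_pre_grad[i]: broadcast H_in_grad over N, multiply by x, reduce over D.
    h_pre_grad = [sum(h_in_grad[k] * x[i][k] for k in range(d)) for i in range(n)]
    # x_grad_vec3[i][k] = H_in_grad[k] * H_pre[i]
    x_grad_vec3 = [[h_in_grad[k] * h_pre[i] for k in range(d)] for i in range(n)]
    return h_pre_grad, x_grad_vec3

random.seed(0)
N, D = 3, 4  # hypothetical small sizes
x = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(N)]
h_pre = [random.uniform(0, 1) for _ in range(N)]
h_in_grad = [random.uniform(-1, 1) for _ in range(D)]

def loss(x_val, h_pre_val):
    # Scalar loss whose gradient with respect to H_in is exactly h_in_grad.
    return sum(g * v for g, v in zip(h_in_grad, forward(x_val, h_pre_val)))

h_pre_grad, x_grad_vec3 = backward(x, h_pre, h_in_grad)

# Central finite difference on H_pre[0].
step = 1e-6
plus = [h_pre[0] + step] + h_pre[1:]
minus = [h_pre[0] - step] + h_pre[1:]
numeric = (loss(x, plus) - loss(x, minus)) / (2 * step)
print(abs(numeric - h_pre_grad[0]) < 1e-6)
```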
    • Sigmoid gate backward (H_pre)

      • Forward formula: $H\_pre = \text{Sigmoid}(\alpha\_pre * H\_pre\_1 + bias\_pre) + hc\_eps$
      • Backward computation: $\begin{aligned} s &= H\_pre - hc\_eps \\ H\_pre\_2\_grad &= H\_pre\_grad \odot s \odot (1 - s) \\ H\_pre\_1\_grad &= H\_pre\_2\_grad \cdot \alpha\_pre \\ \alpha\_pre\_grad &= \sum_{b,s,n}^{B,S,N} \left(H\_pre\_2\_grad \cdot H\_pre\_1\right) \\ bias\_pre\_grad &= \sum_{b,s}^{B,S} H\_pre\_2\_grad \quad ([N]) \end{aligned}$
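As a rough illustration of the H_pre gate backward pass, the sketch below treats each row as one flattened (b, s) position and each column as one of the N streams. All function names and sample values are assumptions chosen for the example. The analytic `alpha_pre_grad` is compared against a central finite difference.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gate_forward(h_pre_1, alpha_pre, bias_pre, hc_eps):
    """H_pre = Sigmoid(alpha_pre * H_pre_1 + bias_pre) + hc_eps."""
    return [[sigmoid(alpha_pre * v + bias_pre[n]) + hc_eps for n, v in enumerate(row)]
            for row in h_pre_1]

def gate_backward(h_pre, h_pre_1, h_pre_grad, alpha_pre, hc_eps):
    rows, n = len(h_pre), len(h_pre[0])
    h_pre_2_grad = []
    for r in range(rows):
        out = []
        for j in range(n):
            s = h_pre[r][j] - hc_eps  # recover the sigmoid output
            out.append(h_pre_grad[r][j] * s * (1.0 - s))
        h_pre_2_grad.append(out)
    h_pre_1_grad = [[g * alpha_pre for g in row] for row in h_pre_2_grad]
    # alpha_pre is a scalar, so its gradient sums over every element.
    alpha_pre_grad = sum(h_pre_2_grad[r][j] * h_pre_1[r][j]
                         for r in range(rows) for j in range(n))
    # bias_pre has shape [N], so its gradient sums over positions only.
    bias_pre_grad = [sum(h_pre_2_grad[r][j] for r in range(rows)) for j in range(n)]
    return h_pre_1_grad, alpha_pre_grad, bias_pre_grad

# Hypothetical sample data: 2 positions, N = 2.
h_pre_1 = [[0.5, -1.0], [2.0, 0.3]]
alpha_pre, bias_pre, hc_eps = 0.7, [0.1, -0.2], 1e-6
h_pre_grad = [[1.0, -0.5], [0.25, 2.0]]

h_pre = gate_forward(h_pre_1, alpha_pre, bias_pre, hc_eps)
h_pre_1_grad, alpha_pre_grad, bias_pre_grad = gate_backward(
    h_pre, h_pre_1, h_pre_grad, alpha_pre, hc_eps)

def loss(alpha):
    out = gate_forward(h_pre_1, alpha, bias_pre, hc_eps)
    return sum(g * v for gr, vr in zip(h_pre_grad, out) for g, v in zip(gr, vr))

step = 1e-6
numeric_alpha_grad = (loss(alpha_pre + step) - loss(alpha_pre - step)) / (2 * step)
print(abs(numeric_alpha_grad - alpha_pre_grad) < 1e-6)
```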
    • Sigmoid gate backward (H_post)

      • Forward formula: $H\_post = \text{Sigmoid}(\alpha\_post * H\_post\_1 + bias\_post) * 2$
      • Backward computation: $\begin{aligned} H\_post\_2\_grad &= H\_post\_grad \odot \left(H\_post \cdot \left(1 - \frac{H\_post}{2}\right)\right) \\ H\_post\_1\_grad &= H\_post\_2\_grad \cdot \alpha\_post \\ \alpha\_post\_grad &= \sum_{b,s,n}^{B,S,N} \left(H\_post\_2\_grad \cdot H\_post\_1\right) \\ bias\_post\_grad &= \sum_{b,s}^{B,S} H\_post\_2\_grad \quad ([N]) \end{aligned}$
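The factor $H\_post \cdot (1 - H\_post/2)$ is the derivative of $2 \cdot \text{Sigmoid}(z)$ written in terms of the output. The scalar sketch below checks this for a single element; the function names and sample values are assumptions chosen for the example.

```python
import math

def h_post_forward(z):
    """H_post = 2 * Sigmoid(z)."""
    return 2.0 / (1.0 + math.exp(-z))

def h_post_local_grad(h_post):
    """d(2 * Sigmoid(z)) / dz expressed through the output H_post."""
    return h_post * (1.0 - h_post / 2.0)

# Hypothetical scalar inputs for one element.
alpha_post, bias_post = 1.3, -0.4
h_post_1 = 0.8
h_post_grad = 0.6  # upstream gradient

h_post = h_post_forward(alpha_post * h_post_1 + bias_post)
h_post_2_grad = h_post_grad * h_post_local_grad(h_post)
h_post_1_grad = h_post_2_grad * alpha_post
alpha_post_grad = h_post_2_grad * h_post_1  # contribution of this one element
bias_post_grad = h_post_2_grad              # contribution of this one element

# Central finite difference with respect to H_post_1.
step = 1e-6
numeric = h_post_grad * (
    h_post_forward(alpha_post * (h_post_1 + step) + bias_post)
    - h_post_forward(alpha_post * (h_post_1 - step) + bias_post)
) / (2 * step)
print(abs(numeric - h_post_1_grad) < 1e-6)
```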
    • Residual connection backward (H_res)

      • Forward formula: $H\_res = \alpha\_res * H\_res\_1 + bias\_res$
      • Backward computation: $\begin{aligned} H\_res\_2\_grad &= H\_res\_grad \cdot \alpha\_res \quad ([B,S,N,N]) \\ \alpha\_res\_grad &= \sum_{b,s,i,j}^{B,S,N,N} \left(H\_res\_grad \cdot H\_res\_2\right) \\ bias\_res\_grad &= \sum_{b,s}^{B,S} H\_res\_grad \quad ([N,N]) \\ H\_res\_1\_grad &= \text{Reshape}(H\_res\_2\_grad) \quad ([B,S,N^2]) \end{aligned}$
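A small sketch of the H_res backward step follows. It assumes that `H_res_2` denotes the [N, N] view of `H_res_1`, which the formulas suggest but do not state, and it represents the [B, S] dimensions as a flat list of positions. The function name and data are hypothetical.

```python
def res_backward(h_res_grad, h_res_2, alpha_res):
    """h_res_grad and h_res_2 are lists over (b, s) positions of N x N matrices."""
    n = len(h_res_grad[0])
    # H_res_2_grad = H_res_grad * alpha_res, shape [B*S, N, N]
    h_res_2_grad = [[[g * alpha_res for g in row] for row in mat] for mat in h_res_grad]
    # alpha_res is a scalar: sum over positions and both matrix axes.
    alpha_res_grad = sum(g * v
                         for gm, vm in zip(h_res_grad, h_res_2)
                         for gr, vr in zip(gm, vm)
                         for g, v in zip(gr, vr))
    # bias_res has shape [N, N]: sum over positions only.
    bias_res_grad = [[sum(mat[i][j] for mat in h_res_grad) for j in range(n)]
                     for i in range(n)]
    # Reshape [N, N] -> [N^2] per position.
    h_res_1_grad = [[g for row in mat for g in row] for mat in h_res_2_grad]
    return h_res_1_grad, alpha_res_grad, bias_res_grad

# Hypothetical data: 2 positions, N = 2.
h_res_grad = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, -1.0], [0.0, 2.0]]]
h_res_2 = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [2.0, 2.0]]]
h_res_1_grad, alpha_res_grad, bias_res_grad = res_backward(h_res_grad, h_res_2, 0.5)
print(h_res_1_grad, alpha_res_grad, bias_res_grad)
```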
    • RMSNorm fusion backward

      • Forward formula: $H\_mix\_tmp = H\_mix * inv\_rms$
      • Backward computation: $\begin{aligned} H\_mix\_tmp\_grad &= \text{Concat}(H\_pre\_1\_grad, H\_post\_1\_grad, H\_res\_1\_grad) \quad ([B,S,2N+N^2]) \\ H\_mix\_grad &= H\_mix\_tmp\_grad \cdot inv\_rms \\ inv\_rms\_grad &= \sum_{\text{last\_dim}} \left(H\_mix\_tmp\_grad \cdot H\_mix\right) \quad ([B,S,1]) \end{aligned}$
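For one (b, s) position, the concatenation and the two products above reduce to a few list operations. The sketch below uses N = 2, so the concatenated gradient has 2N + N² = 8 entries. The function name and the numbers are assumptions for the example.

```python
def rmsnorm_fusion_backward(h_pre_1_grad, h_post_1_grad, h_res_1_grad, h_mix, inv_rms):
    """Backward of H_mix_tmp = H_mix * inv_rms for one (b, s) position."""
    # Concat along the last dimension: N + N + N^2 entries.
    h_mix_tmp_grad = h_pre_1_grad + h_post_1_grad + h_res_1_grad
    h_mix_grad = [g * inv_rms for g in h_mix_tmp_grad]
    # inv_rms is one scalar per position, so reduce over the last dimension.
    inv_rms_grad = sum(g * v for g, v in zip(h_mix_tmp_grad, h_mix))
    return h_mix_grad, inv_rms_grad

# Hypothetical data with N = 2.
h_pre_1_grad = [1.0, 2.0]
h_post_1_grad = [0.5, -0.5]
h_res_1_grad = [1.0, 0.0, -1.0, 2.0]
h_mix = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
h_mix_grad, inv_rms_grad = rmsnorm_fusion_backward(
    h_pre_1_grad, h_post_1_grad, h_res_1_grad, h_mix, inv_rms=0.5)
print(h_mix_grad, inv_rms_grad)
```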
    • Matrix multiplication backward

      • Forward formula: $H\_mix = x\_rs @ phi^T$, where $x\_rs = x * gamma$
      • Backward computation: $\begin{aligned} x\_rs\_grad &= H\_mix\_grad @ phi \quad ([B,S,ND]) \\ X &= \text{Reshape}(x\_rs, [B\cdot S, ND]) \\ G &= \text{Reshape}(H\_mix\_grad, [B\cdot S, 2N+N^2]) \\ phi\_grad &= G^T @ X \quad ([2N+N^2, ND]) \end{aligned}$
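The two matrix products in the backward pass can be illustrated with tiny matrices. In this sketch `M` stands in for B·S and `K` for 2N + N²; these sizes, the helper functions, and the values are all hypothetical and chosen only so the shapes are easy to follow.

```python
def matmul(a, b):
    """Plain nested-list matrix product of a [m, k] and b [k, n]."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

X = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0]]    # x_rs reshaped to [M=2, ND=3]
phi = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]   # [K=2, ND=3]
G = [[1.0, -1.0], [0.5, 2.0]]              # H_mix_grad reshaped to [M, K]

h_mix = matmul(X, transpose(phi))      # forward: [M, K]
x_rs_grad = matmul(G, phi)             # backward: [M, ND]
phi_grad = matmul(transpose(G), X)     # backward: [K, ND]
print(h_mix, x_rs_grad, phi_grad)
```

Note that `x_rs_grad` has the shape of `X` and `phi_grad` has the shape of `phi`, which is a quick sanity check for any implementation of these formulas.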
    • Feature scaling backward

      • Forward formula: $x\_rs = x * gamma$
      • Backward computation: $\begin{aligned} x\_grad\_mm &= x\_rs\_grad * gamma \\ gamma\_grad &= \sum_{b=1}^{B}\sum_{s=1}^{S} (x * x\_rs\_grad)\quad ([N,D]) \end{aligned}$
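The feature scaling step is an element-wise product with `gamma` broadcast over the batch and sequence dimensions. A minimal sketch, assuming each row is one (b, s) position with the [N, D] axes flattened into a single vector; the function name and values are illustrative only.

```python
def scale_backward(x, x_rs_grad, gamma):
    """Backward of x_rs = x * gamma, with gamma shared by all positions."""
    # Gradient flowing back to x: element-wise product with gamma.
    x_grad_mm = [[g * w for g, w in zip(row, gamma)] for row in x_rs_grad]
    # gamma is shared across positions, so its gradient sums over them.
    gamma_grad = [sum(x[r][j] * x_rs_grad[r][j] for r in range(len(x)))
                  for j in range(len(gamma))]
    return x_grad_mm, gamma_grad

# Hypothetical data: 2 positions, N * D = 3.
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
x_rs_grad = [[0.5, 0.0, -1.0], [1.0, 1.0, 0.5]]
gamma = [2.0, 0.5, 1.0]
x_grad_mm, gamma_grad = scale_backward(x, x_rs_grad, gamma)
print(x_grad_mm, gamma_grad)
```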
    • RMS normalization gradient

      • Forward formula: $inv\_rms = \frac{1}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}x_i^2 + eps}}$, where $n = N * D$
      • Backward computation: $\begin{aligned} x\_rs\_grad\_inv &= - \left(\frac{inv\_rms\_grad \cdot {inv\_rms}^3}{N*D}\right) \cdot x\_rs \\ x\_rs\_grad &= x\_grad\_mm + x\_rs\_grad\_inv \\ x\_grad\_vec1 &= \text{Reshape}(x\_rs\_grad, [B,S,N,D]) \\ x\_grad &= x\_grad\_vec3 + x\_grad\_vec1 \end{aligned}$
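The $-inv\_rms^3 / (N \cdot D)$ coefficient follows from differentiating $(\text{mean}(v^2) + eps)^{-1/2}$ with respect to each element. The sketch below checks that term numerically for a generic vector `v`, which stands in for the vector being normalized; treating it generically is an assumption of this example, as are the function names and sample values.

```python
import math

def inv_rms(v, eps):
    """1 / sqrt(mean(v_i^2) + eps) over a vector of length n = N * D."""
    return 1.0 / math.sqrt(sum(t * t for t in v) / len(v) + eps)

def inv_rms_backward(v, inv_rms_grad, eps):
    """Gradient reaching v through inv_rms: -(inv_rms_grad * inv_rms^3 / n) * v."""
    r = inv_rms(v, eps)
    coeff = -(inv_rms_grad * r ** 3) / len(v)
    return [coeff * t for t in v]

# Hypothetical data for one (b, s) position with N * D = 4.
v = [0.5, -1.0, 2.0, 0.25]
eps = 1e-6
upstream = 0.7  # plays the role of inv_rms_grad

analytic = inv_rms_backward(v, upstream, eps)

# Central finite differences, one coordinate at a time.
step = 1e-6
numeric = []
for i in range(len(v)):
    plus = v[:i] + [v[i] + step] + v[i + 1:]
    minus = v[:i] + [v[i] - step] + v[i + 1:]
    numeric.append(upstream * (inv_rms(plus, eps) - inv_rms(minus, eps)) / (2 * step))

max_err = max(abs(a - b) for a, b in zip(analytic, numeric))
print(max_err < 1e-6)
```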
    • Fused addition of grad_x from mhc_post

      $x\_grad = x\_grad + grad\_x\_post$

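Each operator uses a two-phase interface: first call "aclnnMhcPreBackwardGetWorkspaceSize" to obtain the workspace size required by the computation and an executor that encapsulates the operator's computation flow, then call "aclnnMhcPreBackward" to perform the computation.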

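  • Parameters:

  • Return value

    Returns an aclnnStatus status code; for details, see the aclnn return code documentation.

    The first-phase interface validates the input parameters and returns an error code if validation fails.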
  • Parameters:

  • Return value

    Returns an aclnnStatus status code; for details, see the aclnn return code documentation.

  • Deterministic computation:

    • aclnnMhcPreBackward uses a deterministic implementation by default.
  • Specification constraints

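Sample code is provided for reference only; for the build and execution procedure, refer to the build and run documentation.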
