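  • Interface function: MhcPreSinkhornBackward is the backward operator of MhcPreSinkhorn. It computes the corresponding gradients for backpropagation.

  • Formulas

    • Part 1, computing H_res_grad: let the current gradient be $\mathbf{grad}_{\text{curr}} = \mathbf{gradHRes}$ ($\mathbf{gradHRes}$ is an input parameter, the gradient of the corresponding forward output). A total of $\mathbf{num\_iters}$ iterations are performed, where $\mathbf{num\_iters}$ is the number of Sinkhorn iterations (currently only 20 is supported).

      • First $\mathbf{num\_iters} - 1$ iterations: for iteration index $i = \mathbf{num\_iters} - 1, \mathbf{num\_iters} - 2, \dots, 1$, execute the following steps in order. Here $\mathbf{sumOut}$ (an input parameter) holds the intermediate sum results saved by the forward Sinkhorn transform, $\mathbf{normOut}$ (an input parameter) holds the intermediate norm results saved by the forward Sinkhorn transform, and $\mathbf{dotProd}_k$ is the result of the $k$-th dot product (an intermediate variable):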

        $$\begin{aligned}
        \mathbf{dotProd}_{2i+1} &= \sum_{\dim=-2,\ \text{keepdim}=\text{True}} \left( \mathbf{grad}_{\text{curr}} \cdot \mathbf{normOut}_{2i+1} \right), \\
        \mathbf{grad}_{\text{curr}} &\gets \frac{\mathbf{grad}_{\text{curr}} - \mathbf{dotProd}_{2i+1}}{\mathbf{sumOut}_{2i+1}}, \\
        \mathbf{dotProd}_{2i} &= \sum_{\dim=-1,\ \text{keepdim}=\text{True}} \left( \mathbf{grad}_{\text{curr}} \cdot \mathbf{normOut}_{2i} \right), \\
        \mathbf{grad}_{\text{curr}} &\gets \frac{\mathbf{grad}_{\text{curr}} - \mathbf{dotProd}_{2i}}{\mathbf{sumOut}_{2i}}
        \end{aligned}$$
      • Last iteration:

        $$\begin{aligned}
        \mathbf{dotProd}_{1} &= \sum_{\dim=-2,\ \text{keepdim}=\text{True}} \left( \mathbf{grad}_{\text{curr}} \cdot \mathbf{normOut}_{1} \right), \\
        \mathbf{grad}_{\text{curr}} &\gets \frac{\mathbf{grad}_{\text{curr}} - \mathbf{dotProd}_{1}}{\mathbf{sumOut}_{1}}, \\
        \mathbf{dotProd}_{0} &= \sum_{\dim=-1,\ \text{keepdim}=\text{True}} \left( \mathbf{grad}_{\text{curr}} \cdot \mathbf{normOut}_{0} \right), \\
        \mathbf{HResGrad} &\gets \left( \mathbf{grad}_{\text{curr}} - \mathbf{dotProd}_{0} \right) \cdot \mathbf{normOut}_{0}
        \end{aligned}$$
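The Part 1 recursion can be sketched in NumPy as follows. This is a minimal illustration, not the operator's implementation: the function name `sinkhorn_backward` and the storage layout (`norm_out[k]` and `sum_out[k]` as per-step arrays for $k = 0, \dots, 2 \cdot \mathbf{num\_iters} - 1$, each broadcastable against the `[B, S, N, N]` gradient) are assumptions made for the example.

```python
import numpy as np

def sinkhorn_backward(grad_h_res, norm_out, sum_out, num_iters=20):
    """Hypothetical reference for the H_res_grad recursion.

    grad_h_res: [B, S, N, N] gradient of the forward H_res output.
    norm_out, sum_out: sequences indexed by step k (assumed layout);
    sum_out[0] is not used because step 0 multiplies by norm_out[0].
    """
    grad = np.asarray(grad_h_res, dtype=np.float64)
    # Iterations i = num_iters-1, ..., 1: step 2i+1 (dim=-2), then step 2i (dim=-1).
    for i in range(num_iters - 1, 0, -1):
        dot = np.sum(grad * norm_out[2 * i + 1], axis=-2, keepdims=True)
        grad = (grad - dot) / sum_out[2 * i + 1]
        dot = np.sum(grad * norm_out[2 * i], axis=-1, keepdims=True)
        grad = (grad - dot) / sum_out[2 * i]
    # Last iteration: step 1 (dim=-2), then step 0 (dim=-1, multiply by norm_out[0]).
    dot = np.sum(grad * norm_out[1], axis=-2, keepdims=True)
    grad = (grad - dot) / sum_out[1]
    dot = np.sum(grad * norm_out[0], axis=-1, keepdims=True)
    return (grad - dot) * norm_out[0]
```

Each division step has the form of the gradient of a sum-normalisation `y = x / sum(x)`, and the final multiplication has the form of a softmax gradient, which is why the last step uses $\mathbf{normOut}_0$ instead of $\mathbf{sumOut}_0$.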
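    • Input $\mathbf{gradHin} \in [B, S, C]$ ($\mathbf{gradHin}$ is an input parameter, the gradient of the corresponding forward output).

    • Output combination gradient computation

      • Forward formula, where $\mathbf{x}$ is an input parameter (the forward input x), $\mathbf{hPre}$ is an input parameter (the intermediate result h_pre saved by the forward pass), and $N$ is the size of the N dimension of the input Tensor (currently only 4 is supported):

        $$\mathbf{HIn} = \sum_{i=1}^{N} \mathbf{x}[B, S, i, :] \cdot \mathbf{hPre}[B, S, i]$$

      • Backward formula: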

        $$\begin{aligned}
        \mathbf{HPreGrad} &= \text{Reduce}\left(\mathbf{gradHin}.\text{unsqueeze}(-2) \odot \mathbf{x},\ \text{dim}=-1\right) \quad ([B,S,N]) \\
        \mathbf{xGradVec3} &= \mathbf{gradHin} \times \mathbf{hPre} \quad ([B,S,N,C])
        \end{aligned}$$
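The two backward formulas for the output combination reduce to broadcasting plus one reduction. The following NumPy sketch shows the shapes involved; the function name `h_in_backward` is hypothetical, and `Reduce` is assumed to be a sum.

```python
import numpy as np

def h_in_backward(grad_h_in, x, h_pre):
    """grad_h_in: [B, S, C]; x: [B, S, N, C]; h_pre: [B, S, N]."""
    g = grad_h_in[..., None, :]            # unsqueeze(-2) -> [B, S, 1, C]
    h_pre_grad = np.sum(g * x, axis=-1)    # reduce over C -> [B, S, N]
    x_grad_vec3 = g * h_pre[..., None]     # broadcast    -> [B, S, N, C]
    return h_pre_grad, x_grad_vec3
```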
    • Gated activation layer gradient computation

      • Sigmoid gate backward (H_pre)

        • Forward formula:

          $$\mathbf{hPre} = \text{Sigmoid}(\mathbf{alphaPre} \cdot \mathbf{hPre1} + \mathbf{biasPre}) + \mathbf{hcEps}$$

          where $\mathbf{alphaPre}$ is the first element of the input parameter that holds the scaling coefficients (it is the scaling coefficient for H_pre), $\mathbf{biasPre}$ is the first N elements of the input parameter that holds the biases, and $\mathbf{hcEps}$ is an input parameter (a numerical-stability parameter).

        • Backward computation:

          $$\begin{aligned}
          \mathbf{s} &= \mathbf{hPre} - \mathbf{hcEps} \\
          \mathbf{HPre2Grad} &= \mathbf{HPreGrad} \odot \mathbf{s} \odot (1 - \mathbf{s}) \\
          \mathbf{HPre1Grad} &= \mathbf{HPre2Grad} \cdot \mathbf{alphaPre} \\
          \mathbf{alphaPreGrad} &= \sum_{b,s,n}^{B,S,N} \left(\mathbf{HPre2Grad} \cdot \mathbf{hPre1}\right) \\
          \mathbf{biasPreGrad} &= \sum_{b,s}^{B,S} \mathbf{HPre2Grad} \quad ([N])
          \end{aligned}$$
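As an illustration of the H_pre gate backward, the sketch below recovers the Sigmoid output from the saved $\mathbf{hPre}$ by subtracting $\mathbf{hcEps}$ and then applies the standard Sigmoid derivative $s(1-s)$. The function name `h_pre_gate_backward` and the assumption that all tensors are `[B, S, N]` NumPy arrays are for illustration only.

```python
import numpy as np

def h_pre_gate_backward(h_pre_grad, h_pre, h_pre1, alpha_pre, hc_eps):
    """h_pre_grad, h_pre, h_pre1: [B, S, N]; alpha_pre, hc_eps: scalars."""
    s = h_pre - hc_eps                                 # Sigmoid output
    h_pre2_grad = h_pre_grad * s * (1.0 - s)           # through the Sigmoid
    h_pre1_grad = h_pre2_grad * alpha_pre              # gradient w.r.t. hPre1
    alpha_pre_grad = np.sum(h_pre2_grad * h_pre1)      # scalar
    bias_pre_grad = np.sum(h_pre2_grad, axis=(0, 1))   # [N]
    return h_pre1_grad, alpha_pre_grad, bias_pre_grad
```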
      • Sigmoid gate backward (H_post)

        • Forward formula:

          $$\mathbf{hPost} = \text{Sigmoid}(\mathbf{alphaPost} \cdot \mathbf{hPost1} + \mathbf{biasPost}) \times 2$$

          where $\mathbf{alphaPost}$ is the second element of the input parameter that holds the scaling coefficients, and $\mathbf{biasPost}$ is the middle N elements of the input parameter that holds the biases.

        • Backward computation:

          $$\begin{aligned}
          \mathbf{HPost2Grad} &= \mathbf{gradHPost} \odot \left(\mathbf{hPost} \cdot \left(1 - \frac{\mathbf{hPost}}{2}\right)\right) \\
          \mathbf{HPost1Grad} &= \mathbf{HPost2Grad} \cdot \mathbf{alphaPost} \\
          \mathbf{alphaPostGrad} &= \sum_{b,s,n}^{B,S,N} \left(\mathbf{HPost2Grad} \cdot \mathbf{hPost1}\right) \\
          \mathbf{biasPostGrad} &= \sum_{b,s}^{B,S} \mathbf{HPost2Grad} \quad ([N])
          \end{aligned}$$

          where $\mathbf{gradHPost}$ is an input parameter (the gradient of the corresponding forward output).
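The factor $\mathbf{hPost} \cdot (1 - \mathbf{hPost}/2)$ is the derivative of $2 \cdot \text{Sigmoid}(z)$ written in terms of its own output: with $\sigma = \mathbf{hPost}/2$, $2\sigma(1-\sigma) = \mathbf{hPost}(1 - \mathbf{hPost}/2)$. A minimal NumPy sketch, with the hypothetical name `h_post_gate_backward` and all tensors assumed to be `[B, S, N]`:

```python
import numpy as np

def h_post_gate_backward(grad_h_post, h_post, h_post1, alpha_post):
    """grad_h_post, h_post, h_post1: [B, S, N]; alpha_post: scalar."""
    h_post2_grad = grad_h_post * (h_post * (1.0 - h_post / 2.0))
    h_post1_grad = h_post2_grad * alpha_post              # gradient w.r.t. hPost1
    alpha_post_grad = np.sum(h_post2_grad * h_post1)      # scalar
    bias_post_grad = np.sum(h_post2_grad, axis=(0, 1))    # [N]
    return h_post1_grad, alpha_post_grad, bias_post_grad
```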

      • Residual connection backward (H_res)

        • Forward formula:

          $$\mathbf{hRes} = \mathbf{alphaRes} \cdot \mathbf{hRes1} + \mathbf{biasRes}$$

          where $\mathbf{alphaRes}$ is the third element of the input parameter that holds the scaling coefficients, and $\mathbf{biasRes}$ is the last $N^2$ elements of the input parameter that holds the biases.

        • Backward computation:

          $$\begin{aligned}
          \mathbf{HRes2Grad} &= \mathbf{HResGrad} \cdot \mathbf{alphaRes} \quad ([B,S,N,N]) \\
          \mathbf{alphaResGrad} &= \sum_{b,s,i,j}^{B,S,N,N} \left(\mathbf{HResGrad} \cdot \mathbf{hRes2}\right) \\
          \mathbf{biasResGrad} &= \sum_{b,s}^{B,S} \mathbf{HResGrad} \quad ([N,N]) \\
          \mathbf{HRes1Grad} &= \text{Reshape}(\mathbf{HRes2Grad}) \quad ([B,S,N^2])
          \end{aligned}$$
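The residual branch is a plain affine map, so its backward is a scale, two reductions and a reshape. In the sketch below the name `h_res_backward` is hypothetical, and $\mathbf{hRes2}$ is assumed to be the pre-affine input viewed as `[B, S, N, N]`; the formulas do not define it explicitly, so treat that reading as an assumption.

```python
import numpy as np

def h_res_backward(h_res_grad, h_res2, alpha_res):
    """h_res_grad, h_res2: [B, S, N, N]; alpha_res: scalar."""
    b, s, n, _ = h_res_grad.shape
    h_res2_grad = h_res_grad * alpha_res                # [B, S, N, N]
    alpha_res_grad = np.sum(h_res_grad * h_res2)        # scalar
    bias_res_grad = np.sum(h_res_grad, axis=(0, 1))     # [N, N]
    h_res1_grad = h_res2_grad.reshape(b, s, n * n)      # [B, S, N^2]
    return h_res1_grad, alpha_res_grad, bias_res_grad
```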
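    • RMS normalization fusion backward

      • RMSNorm Fusion backward

        • Forward formula:

          $$\mathbf{hMixTmp} = \mathbf{hMix} \cdot \mathbf{invRms}$$

          where $\mathbf{hMix}$ is the output of the linear projection layer and $\mathbf{invRms}$ is an input parameter (the intermediate result inv_rms saved by the forward pass).

        • Backward computation: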

          $$\begin{aligned}
          \mathbf{hMixTmpGrad} &= \text{Concat}(\mathbf{HPre1Grad}, \mathbf{HPost1Grad}, \mathbf{HRes1Grad}) \quad ([B,S,2N+N^2]) \\
          \mathbf{hMixGrad} &= \mathbf{hMixTmpGrad} \cdot \mathbf{invRms} \\
          \mathbf{invRmsGrad} &= \sum_{\text{last\_dim}} \left(\mathbf{hMixTmpGrad} \cdot \mathbf{hMix}\right) \quad ([B,S,1])
          \end{aligned}$$
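The fusion step concatenates the three branch gradients along the last dimension and then applies the product rule to $\mathbf{hMix} \cdot \mathbf{invRms}$. A minimal NumPy sketch follows; the name `rms_fusion_backward` is hypothetical, and $\mathbf{invRms}$ is assumed to have shape `[B, S, 1]` so that it broadcasts.

```python
import numpy as np

def rms_fusion_backward(h_pre1_grad, h_post1_grad, h_res1_grad, h_mix, inv_rms):
    """Branch grads: [B,S,N], [B,S,N], [B,S,N^2]; h_mix: [B,S,2N+N^2]; inv_rms: [B,S,1]."""
    h_mix_tmp_grad = np.concatenate(
        [h_pre1_grad, h_post1_grad, h_res1_grad], axis=-1)      # [B, S, 2N+N^2]
    h_mix_grad = h_mix_tmp_grad * inv_rms                       # [B, S, 2N+N^2]
    inv_rms_grad = np.sum(h_mix_tmp_grad * h_mix,
                          axis=-1, keepdims=True)               # [B, S, 1]
    return h_mix_grad, inv_rms_grad
```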
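    • Linear projection layer gradient computation

      • Matrix multiplication backward

        • Forward formula:

          $$\begin{aligned}
          \mathbf{hMix} &= \mathbf{xRs} \mathbin{@} \mathbf{phi}^T \\
          \mathbf{xRs} &= \mathbf{x} \cdot \mathbf{gamma}
          \end{aligned}$$

          where $\mathbf{phi}$ is an input parameter (the forward parameter phi) and $\mathbf{gamma}$ is the scaling parameter of the RMS normalization (computed from the input $\mathbf{x}$).

        • Backward computation: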

          $$\begin{aligned}
          \mathbf{xRsGrad} &= \mathbf{hMixGrad} \mathbin{@} \mathbf{phi} \quad ([B,S,NC]) \\
          \mathbf{X} &= \text{Reshape}(\mathbf{xRs}, [B \cdot S, NC]) \\
          \mathbf{G} &= \text{Reshape}(\mathbf{hMixGrad}, [B \cdot S, 2N+N^2]) \\
          \mathbf{phiGrad} &= \mathbf{G}^T \mathbin{@} \mathbf{X} \quad ([2N+N^2, NC])
          \end{aligned}$$
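The matrix multiplication backward can be checked with a few lines of NumPy. The sketch below assumes $\mathbf{phi}$ has shape `[2N+N^2, N*C]`, as implied by the shape of $\mathbf{phiGrad}$; the function name `matmul_backward` is hypothetical.

```python
import numpy as np

def matmul_backward(h_mix_grad, x_rs, phi):
    """h_mix_grad: [B,S,M]; x_rs: [B,S,K]; phi: [M,K], with M = 2N+N^2 and K = N*C."""
    x_rs_grad = h_mix_grad @ phi                          # [B, S, K]
    g = h_mix_grad.reshape(-1, h_mix_grad.shape[-1])      # [B*S, M]
    x = x_rs.reshape(-1, x_rs.shape[-1])                  # [B*S, K]
    phi_grad = g.T @ x                                    # [M, K]
    return x_rs_grad, phi_grad
```

Flattening the batch dimensions before the second product turns the sum over $B$ and $S$ into a single matrix multiplication.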
      • Feature scaling backward

        • Forward formula:

          $$\mathbf{xRs} = \mathbf{x} \cdot \mathbf{gamma}$$

        • Backward computation:

          $$\begin{aligned}
          \mathbf{xGradMm} &= \mathbf{xRsGrad} \cdot \mathbf{gamma} \\
          \mathbf{gammaGrad} &= \sum_{b=1}^{B}\sum_{s=1}^{S} (\mathbf{x} \cdot \mathbf{xRsGrad}) \quad ([N,C])
          \end{aligned}$$
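A possible NumPy reading of the feature scaling backward is shown below. The layout is an assumption: $\mathbf{xRsGrad}$ arrives flattened as `[B, S, N*C]` from the matrix multiplication backward, $\mathbf{gamma}$ is `[N, C]`, and the element-wise products are taken after viewing the gradient as `[B, S, N, C]`. The name `gamma_backward` is hypothetical.

```python
import numpy as np

def gamma_backward(x_rs_grad, x, gamma):
    """x_rs_grad: [B, S, N*C]; x: [B, S, N, C]; gamma: [N, C]."""
    g = x_rs_grad.reshape(x.shape)                        # view as [B, S, N, C]
    x_grad_mm = (g * gamma).reshape(x_rs_grad.shape)      # back to [B, S, N*C]
    gamma_grad = np.sum(x * g, axis=(0, 1))               # [N, C]
    return x_grad_mm, gamma_grad
```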
    • RMS normalization gradient computation

      • Forward formula:

        $$\mathbf{invRms} = \frac{1}{\sqrt{\frac{1}{n}\sum_{i=1}^{n}\mathbf{x}_i^2 + \mathbf{eps}}}, \quad \text{where } n = N \times C$$

        where $\mathbf{eps}$ is the numerical-stability parameter of the RMS normalization.

      • Backward computation:

        $$\begin{aligned}
        \mathbf{xRsGradInv} &= - \left(\frac{\mathbf{invRmsGrad} \cdot \mathbf{invRms}^3}{N \times C}\right) \cdot \mathbf{xRs} \\
        \mathbf{xRsGrad} &= \mathbf{xGradMm} + \mathbf{xRsGradInv} \\
        \mathbf{xGradVec1} &= \text{Reshape}(\mathbf{xRsGrad}, [B,S,N,C]) \\
        \mathbf{xGrad} &= \mathbf{xGradVec3} + \mathbf{xGradVec1}
        \end{aligned}$$
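The final accumulation of $\mathbf{xGrad}$ can be sketched as follows. The term with $\mathbf{invRms}^3$ comes from differentiating $(\text{mean}(x^2) + \mathbf{eps})^{-1/2}$, which gives $-\mathbf{invRms}^3 \cdot x_i / n$ per element. The function name `rms_norm_backward` and the flattened `[B, S, N*C]` layout of $\mathbf{xRs}$ and $\mathbf{xGradMm}$ are assumptions made for the example.

```python
import numpy as np

def rms_norm_backward(inv_rms_grad, inv_rms, x_rs, x_grad_mm, x_grad_vec3):
    """inv_rms_grad, inv_rms: [B,S,1]; x_rs, x_grad_mm: [B,S,N*C]; x_grad_vec3: [B,S,N,C]."""
    n_times_c = x_rs.shape[-1]
    # Contribution that flows back through invRms.
    x_rs_grad_inv = -(inv_rms_grad * inv_rms ** 3 / n_times_c) * x_rs
    x_rs_grad = x_grad_mm + x_rs_grad_inv
    x_grad_vec1 = x_rs_grad.reshape(x_grad_vec3.shape)    # [B, S, N, C]
    return x_grad_vec3 + x_grad_vec1
```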
    • Symbol description


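Each operator provides a two-phase interface: first call the first-phase interface to obtain the workspace size required for the computation and an executor that contains the operator's computation flow, then call the second-phase interface to perform the actual computation.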

  • Parameters:

  • Return value:

    aclnnStatus: the returned status code.

    The first-phase interface validates the input parameters and reports an error when validation fails.

  • Parameters:

  • Return value:

    aclnnStatus: the returned status code.

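  • Deterministic computation

    • By default, aclnnMhcPreSinkhornBackward uses a non-deterministic implementation; deterministic computation can be enabled through aclrtCtxSetSysParamOpt.
  • Specification constraints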


The sample code is provided for reference only.