接口功能:MhcPreSinkhornBackward是MhcPreSinkhorn的反向算子。计算对应的梯度反向传播。
计算公式:
第一部分计算H_res_grad:设当前梯度为 (
[object Object]为输入参数,即正向输出[object Object]的梯度),共进行 (Sinkhorn迭代次数,对应[object Object],当前仅支持20)次迭代。前 次迭代:对迭代编号 ,依次执行,其中 (对应输入参数
[object Object])为Sinkhorn变换正向计算保存的中间sum结果,(对应输入参数[object Object])为Sinkhorn变换正向计算保存的中间norm结果, 为第 次点积计算结果(中间变量):最后一次迭代:
输入:(
[object Object]为输入参数,即输出[object Object]的梯度)输出组合梯度计算:
- 正向计算公式:其中 为输入参数(前向输入x), 为输入参数(前向保存的中间结果h_pre), 为输入Tensor中N维度的大小(当前仅支持4)。
- 反向计算公式:
门控激活层梯度计算:
Sigmoid门控反向(H_pre):
正向公式:
其中 为输入参数
[object Object]的第一个元素(对应H_pre的缩放系数), 为输入参数[object Object]的前N个元素, 为输入参数[object Object](数值稳定性参数)。反向计算:
Sigmoid门控反向(H_post):
正向公式:
其中 为输入参数
[object Object]的第二个元素, 为输入参数[object Object]的中间N个元素。反向计算:
其中 为输入参数(输出
[object Object]的梯度)。
残差连接反向(H_res):
正向公式:
其中 为输入参数
[object Object]的第三个元素, 为输入参数[object Object]的后 个元素。反向计算:
RMS归一化融合反向:
RMSNorm Fusion反向:
正向公式:
其中 为线性投影层输出, 为输入参数(前向保存的中间结果inv_rms)。
反向计算:
线性投影层梯度计算:
矩阵乘法反向:
正向公式:
其中 为输入参数(前向参数phi), 为RMS归一化的缩放参数(由输入 计算得到)。
反向计算:
特征缩放反向:
正向公式:
反向计算:
RMS归一化梯度计算:
正向公式:
其中 为RMS归一化的数值稳定性参数。
反向计算:
符号说明:
[object Object]undefined
每个算子分为,必须先调用[object Object]接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用[object Object]执行实际计算。
确定性计算:
- aclnnMhcPreSinkhornBackward默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。
规格约束
[object Object]undefined