开发者
下载
[object Object][object Object][object Object]undefined
[object Object]
  • 接口功能:mhc_post基于一系列计算对MHC(Manifold-Constrained Hyper-Connection)架构中上一层输出htouth_{t}^{out}进行Post Mapping,对上一层的输入xjx_j进行ResMapping,然后对二者进行残差连接,得到下一层的输入xl+1x_{l+1}。该算子实现前述过程的反向功能。

  • 计算公式:

    grad_x=Hlres×grad_outputgrad_h_res=xl×grad_outputTgrad\_x = H_{l}^{res} \times grad\_output\\ grad\_h\_res = x_{l} \times {grad\_output}^{T} grad_h_out=(grad_output(Hlpost.unsqueeze(1))).sum(dim=2)grad_h_post=(grad_output(hlout.unsqueeze(2))).sum(dim=1)grad\_h\_out=({grad\_output} * (H_{l}^{post}.unsqueeze(-1))).sum(dim=-2)\\ grad\_h\_post=({grad\_output} * (h_{l}^{out}.unsqueeze(-2))).sum(dim=-1)
[object Object]

算子执行接口为,必须先调用“aclnnMhcPostBackwardGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnMhcPostBackward”接口执行计算。

[object Object]
[object Object]
[object Object]
  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]
[object Object]
  • 参数说明:

    [object Object]
  • 返回值:

    返回aclnnStatus状态码,具体参见

[object Object]

参数说明中维度N的取值目前仅支持4、6和8。

[object Object]

示例代码如下,仅供参考,具体编译和执行过程请参考

[object Object]