开发者
下载
[object Object][object Object][object Object]undefined
[object Object]
  • 接口功能:对mHC架构中的Hres\mathbf{H}'_{\text{res}}矩阵执行Sinkhorn迭代归一化变换,最终得到双随机矩阵Hres\mathbf{H}_{\text{res}};支持输出迭代过程中的中间归一化结果(norm_out)和求和结果(sum_out),用于反向梯度计算。
[object Object]

Sinkhorn变换共执行num_iters\mathbf{num\_iters}次迭代,迭代过程中生成中间归一化结果norm_out[k]\mathbf{norm\_out}[k]和求和结果sum_out[k]\mathbf{sum\_out}[k],最终输出最后一次迭代的norm_out\mathbf{norm\_out}作为变换结果。

第一次迭代(初始化):

norm_out[0]=softmax(x,dim=1)+ϵ,sum_out[1]=dim=2,keepdim=Truenorm_out[0]+ϵ,norm_out[1]=norm_out[0]sum_out[1],\begin{aligned} \mathbf{norm\_out}[0] &= \text{softmax}(\mathbf{x}, \dim=-1) + \epsilon, \\ \mathbf{sum\_out}[1] &= \sum_{\dim=-2,\text{keepdim}=\text{True}} \mathbf{norm\_out}[0] + \epsilon, \\ \mathbf{norm\_out}[1] &= \frac{\mathbf{norm\_out}[0]}{\mathbf{sum\_out}[1]}, \\ \end{aligned}

ii次迭代(i=1,2,,(num_iters1)i = 1, 2, \dots, \mathbf({num\_iters}-1)):

sum_out[2i]=dim=1,keepdim=Truenorm_out[2i1]+ϵ,norm_out[2i]=norm_out[2i1]sum_out[2i],sum_out[2i+1]=dim=2,keepdim=Truenorm_out[2i]+ϵ,norm_out[2i+1]=norm_out[2i]sum_out[2i+1],\begin{aligned} \mathbf{sum\_out}[2i] &= \sum_{\dim=-1,\text{keepdim}=\text{True}} \mathbf{norm\_out}[2i-1] + \epsilon, \\ \mathbf{norm\_out}[2i] &= \frac{\mathbf{norm\_out}[2i-1]}{\mathbf{sum\_out}[2i]}, \\ \mathbf{sum\_out}[2i+1] &= \sum_{\dim=-2,\text{keepdim}=\text{True}} \mathbf{norm\_out}[2i] + \epsilon, \\ \mathbf{norm\_out}[2i+1] &= \frac{\mathbf{norm\_out}[2i]}{\mathbf{sum\_out}[2i+1]}, \\ \end{aligned}
  • 最终输出
output=norm_out[2×num_iters1]\text{output} = \mathbf{norm\_out}[2 \times \mathbf{num\_iters} - 1]
  • 🔍 符号说明
[object Object]undefined
[object Object]

算子采用两段式接口调用:需先调用[object Object]获取计算所需的Device侧内存大小,再调用[object Object]执行实际计算。

[object Object]
[object Object]
[object Object]
  • 参数说明

    [object Object]undefined
  • 返回值

    返回aclnnStatus状态码,具体参见

    第一段接口完成入参校验,出现以下场景时报错:

    [object Object]undefined
[object Object]
  • 参数说明

    [object Object]undefined
  • 返回值

    返回[object Object]状态码,具体参见

[object Object]
  • 确定性计算

    • aclnnMhcSinkhorn默认采用确定性实现,相同输入多次调用结果一致。
  • 公共约束

    1. 输入约束:
      • 输入Tensor [object Object]为空,报错[object Object]
      • 所有输入/输出Tensor的数据格式仅支持[object Object]
      • 仅支持[object Object]数据类型,不支持其他精度(如FLOAT16/DOUBLE)。
      • outFlag支持0和1;outFlag为0时,仅输出output,normOut和sumOut可传空指针;outFlag为1时,同时输出output、normOut和sumOut,normOut和sumOut不能为空。
      • 输入-inf/inf/nan/,输出nan/nan/nan。
    2. 内存约束:
      • Workspace内存需在Device侧申请,且大小需严格匹配第一段接口返回值;
  • 规格约束

    [object Object]undefined
[object Object]

以下为C++调用示例,需结合AscendCL环境编译运行,具体流程参考

[object Object]