昇腾社区首页
中文
注册
开发者
下载

(beta)torch_npu.contrib.module.ChannelShuffle

产品支持情况

产品 是否支持
[object Object]Atlas A3 训练系列产品[object Object]
[object Object]Atlas A2 训练系列产品[object Object]
[object Object]Atlas 推理系列产品[object Object]
[object Object]Atlas 训练系列产品[object Object]

功能说明

  • API功能:应用NPU兼容的通道shuffle操作。

  • 等价计算逻辑:

    可使用cpu_channel_shuffle等价替换torch_npu.contrib.module.ChannelShuffle,两者计算逻辑一致。

    [object Object]

函数原型

[object Object]

参数说明

计算参数

  • in_channels (int):必选参数。输入张量中的通道总数。
  • groups (int):可选参数。shuffle组数。默认值为2。
  • split_shuffle (bool):可选参数。shuffle后是否执行chunk操作。默认值为True。

计算输入

  • x1 (Tensor):输入张量。 shape为(N,Cin,Lin)(N, C_{in}, L_{in})
  • x2 (Tensor):输入张量。 shape为(N,Cin,Lin)(N, C_{in}, L_{in})

返回值说明

  • out1 (Tensor):输出张量。 shape为(N,Cout,Lout)(N, C_{out}, L_{out})
  • out2 (Tensor):输出张量。 shape为(N,Cout,Lout)(N, C_{out}, L_{out})

约束说明

只实现了groups为2场景,请自行修改其他groups场景。

调用示例

[object Object]