(beta)torch_npu.contrib.module.ChannelShuffle
产品支持情况
| 产品 | 是否支持 |
|---|---|
| [object Object]Atlas A3 训练系列产品[object Object] | √ |
| [object Object]Atlas A2 训练系列产品[object Object] | √ |
| [object Object]Atlas 推理系列产品[object Object] | √ |
| [object Object]Atlas 训练系列产品[object Object] | √ |
功能说明
API功能:应用NPU兼容的通道shuffle操作。
等价计算逻辑:
可使用
cpu_channel_shuffle等价替换torch_npu.contrib.module.ChannelShuffle,两者计算逻辑一致。[object Object]
函数原型
[object Object]
参数说明
计算参数
- in_channels (
int):必选参数。输入张量中的通道总数。 - groups (
int):可选参数。shuffle组数。默认值为2。 - split_shuffle (
bool):可选参数。shuffle后是否执行chunk操作。默认值为True。
计算输入
- x1 (
Tensor):输入张量。 shape为。 - x2 (
Tensor):输入张量。 shape为。
返回值说明
- out1 (
Tensor):输出张量。 shape为。 - out2 (
Tensor):输出张量。 shape为。
约束说明
只实现了groups为2场景,请自行修改其他groups场景。
调用示例
[object Object]