torch_npu.distributed.reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=dist.ReduceOp.SUM, group=None, async_op=False) -> torch.distributed.distributed_c10d.Work
Building on the native torch.distributed.reduce_scatter_tensor interface, torch_npu.distributed.reduce_scatter_tensor_uneven additionally supports zero-copy and uneven (non-equal-length) splitting.
The shape of input is the concatenation of the output shapes on all ranks.
The shape of output has no special constraints.
This interface can only be used in single-node scenarios.
The elements of input_split_sizes must sum to the size of input along dim 0, and the number of elements must equal the size of group.
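To make these constraints concrete, the hypothetical helper below (not part of the torch_npu API) checks a candidate input tensor and input_split_sizes before the collective is issued; it only assumes an already initialized process group.

import torch
import torch.distributed as dist

def check_uneven_split(input_tensor: torch.Tensor, input_split_sizes, group=None):
    # Hypothetical validation helper, not part of torch_npu.
    world_size = dist.get_world_size(group)
    # One split size per rank in the group.
    assert len(input_split_sizes) == world_size, "need one split size per rank in group"
    # Split sizes must sum to the length of input along dim 0.
    assert sum(input_split_sizes) == input_tensor.shape[0], "splits must sum to input.shape[0]"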
Create the following file test.py and save it.
import os
import torch
import torch_npu
import torch.distributed as dist

dist.init_process_group(backend="hccl")
rank = int(os.getenv('LOCAL_RANK'))
torch.npu.set_device(rank)

# Two ranks: rank 0 receives 2 elements of the reduced result, rank 1 receives 3.
input_split_sizes = [2, 3]
input_tensor = torch.ones(sum(input_split_sizes), dtype=torch.int32).npu()
output_tensor = torch.zeros(input_split_sizes[rank], dtype=torch.int32).npu()

torch_npu.distributed.reduce_scatter_tensor_uneven(
    output_tensor,
    input_tensor,
    input_split_sizes=input_split_sizes,
    async_op=False
)
Run the following command.
torchrun --nproc-per-node=2 test.py
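The example above passes async_op=False. Since the signature shows the call returning a torch.distributed Work handle, an asynchronous call would presumably be awaited like native collectives; the following is a minimal sketch under that assumption, reusing the tensors defined in test.py.

# Sketch only: asynchronous variant, assuming the returned Work handle
# supports wait() like handles from native torch.distributed collectives.
work = torch_npu.distributed.reduce_scatter_tensor_uneven(
    output_tensor,
    input_tensor,
    input_split_sizes=input_split_sizes,
    async_op=True,
)
work.wait()  # block until the reduce-scatter result is available on this rank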