集合通信入图

功能简介

集合通信入图能够避免断图和拥有更大的成图范围,从而获得更大的资源调度与融合收益,并在整图层面进行通信与计算并行优化。

原生PyTorch社区对集合通信入图的支持度并不完善,功能一直在增强中。TorchAir针对该现状,采用如下方案增强这块入图能力:

具体入图操作参见使用方法(PyTorch 2.1版本)。目前支持入图的API情况如表1所示,请根据实际情况按需调用。

表1 集合通信API入图支持情况

PyTorch集合通信API

PyTorch 2.1版本

说明

torch.distributed.all_gather

接口详细的介绍请参考API参考

torch.distributed.all_gather_into_tensor

torch.distributed.all_reduce

torch.distributed.all_to_all

torch.distributed.all_to_all_single

torch.distributed.broadcast

torch.distributed.reduce_scatter_tensor

x

torch_npu.distributed.all_gather_into_tensor_uneven

x

torch_npu.distributed.reduce_scatter_tensor_uneven

x

使用约束

集合通信算子入图的前提是PyTorch脚本中所有算子能正常以Eager模式运行。

使用方法(PyTorch 2.1版本)

分布式场景下,第三方框架(如DeepSpeed)对原生allreduce API的封装使其无法入图,TorchAir可同时Patch框架中的allreduce封装函数,以解决入图问题。