Capturing collective communication operations into the graph avoids graph breaks and enlarges the graph-capture scope, which yields greater resource scheduling and fusion gains and enables communication-computation parallelism optimization at the whole-graph level.
Native PyTorch community support for capturing collective communication into the graph is not yet complete and is still being enhanced. To address this, TorchAir adopts the following approaches to strengthen graph-capture capability for these operations:
For the specific graph-capture procedure, see the usage instructions (PyTorch 2.1). The APIs currently supported for graph capture are listed in Table 1; call them as needed for your scenario.
**Table 1** PyTorch collective communication APIs supported for graph capture

| PyTorch collective communication API | PyTorch 2.1 | Notes |
|---|---|---|
| torch.distributed.all_gather | √ | For detailed API descriptions, see the *API Reference*. |
| torch.distributed.all_gather_into_tensor | √ | |
| torch.distributed.all_reduce | √ | |
| torch.distributed.all_to_all | √ | |
| torch.distributed.all_to_all_single | √ | |
| torch.distributed.broadcast | √ | |
| torch.distributed.reduce_scatter_tensor | x | |
| torch_npu.distributed.all_gather_into_tensor_uneven | x | |
| torch_npu.distributed.reduce_scatter_tensor_uneven | x | |
A prerequisite for capturing collective communication operators into the graph is that every operator in the PyTorch script runs correctly in Eager mode.
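As an illustration of the table above, the sketch below shows how one of the other supported APIs, torch.distributed.all_gather_into_tensor, can appear in a module intended for graph compilation. It is not taken from the original document: the class and variable names are assumptions, and it presumes the "hccl" process-group setup, `torchair.patch_for_hcom()` call, and NPU-backend compilation shown in the complete example later in this section.

```python
import torch
import torch_npu
import torchair

class AllGatherModule(torch.nn.Module):
    """Illustrative module: gathers the input tensor from all ranks."""
    def __init__(self, world_size):
        super().__init__()
        self.world_size = world_size

    def forward(self, x):
        # Output buffer holds x from every rank, concatenated along dim 0.
        out = torch.empty([self.world_size * x.shape[0], x.shape[1]],
                          dtype=x.dtype, device=x.device)
        torch.distributed.all_gather_into_tensor(out, x)
        return out

# As in the complete Approach 1 example below: initialize the "hccl" process
# group, call torchair.patch_for_hcom(), then compile with the NPU backend:
#   npu_backend = torchair.get_npu_backend(compiler_config=CompilerConfig())
#   model = torch.compile(AllGatherModule(world_size).npu(), backend=npu_backend)
```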
In distributed scenarios, third-party frameworks (such as DeepSpeed) wrap the native allreduce API, which prevents it from being captured into the graph. TorchAir can additionally patch the framework's allreduce wrapper functions to resolve this graph-capture issue, as shown in the following example.
```python
import os
import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig

# Import and apply the patch so collective communication ops can be captured
torchair.patch_for_hcom()

class AllReduceSingeGroup(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        x = x + y
        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
        return x

def example(rank, world_size):
    torch.npu.set_device(rank)
    torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
    x = torch.ones([2, 2], dtype=torch.int32).to("npu:" + str(rank))
    y = torch.ones([2, 2], dtype=torch.int32).to("npu:" + str(rank))
    # Compile the module with the TorchAir NPU backend
    config = CompilerConfig()
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    model = torch.compile(AllReduceSingeGroup().to("npu:" + str(rank)),
                          backend=npu_backend, dynamic=False)
    # Each rank contributes (1 + 1), so the all-reduced result is 2 * world_size
    out = torch.ones([2, 2], dtype=torch.int32).npu() * 2 * world_size
    ret = model(x, y)
    assert out.equal(ret)
    torch.distributed.destroy_process_group()

def mp():
    world_size = 2
    torch.multiprocessing.spawn(example, args=(world_size, ), nprocs=world_size, join=True)

if __name__ == '__main__':
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29506"
    mp()
```
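If the example above is saved as a standalone script (the file name is arbitrary, e.g. a hypothetical allreduce_example.py), it can be launched directly with `python allreduce_example.py`; `torch.multiprocessing.spawn` starts both ranks itself, so no external launcher such as torchrun is required.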
Approach 2 no longer evolves in Ascend Extension for PyTorch 6.0.RC3 and later versions and will be deprecated in the future; Approach 1 is recommended instead.
Import the patch_for_hcom_allreduce package in the inference script. This package replaces the native PyTorch torch.distributed.all_reduce API with the torch.distributed._functional_collectives.all_reduce operator so that it can be captured into the graph.
```python
# Import the patch package
import torch_npu
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
```
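For completeness, the sketch below shows where the Approach 2 import fits in an inference script. It is illustrative only and not taken from the original document: the module and variable names are assumptions, the process-group setup is assumed to be handled the same way as in the Approach 1 example above, and Approach 1 remains the recommended path.

```python
import torch
import torch_npu
import torchair
from torchair.configs.compiler_config import CompilerConfig
# Approach 2 (deprecated): importing this module replaces torch.distributed.all_reduce
# with torch.distributed._functional_collectives.all_reduce so it can be captured.
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce

class AllReduceModule(torch.nn.Module):
    """Illustrative module whose all_reduce call is captured into the graph."""
    def forward(self, x):
        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
        return x

# Assumes torch.distributed.init_process_group("hccl", ...) has already run.
npu_backend = torchair.get_npu_backend(compiler_config=CompilerConfig())
model = torch.compile(AllReduceModule().npu(), backend=npu_backend, dynamic=False)
```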