昇腾社区首页
中文
注册

集合通信入图

功能简介

集合通信入图能够避免断图和拥有更大的成图范围,从而获得更大的资源调度与融合收益,并在整图层面进行通信与计算并行优化。

原生PyTorch社区对集合通信入图的支持度并不完善,功能一直在增强中。TorchAir针对该现状,采用如下方案增强这块入图能力:

  • PyTorch 2.1版本:以Monkey Patch方式接入Ascend IR计算图,补齐原生集合通信不支持入图的现状。

具体入图操作参见使用方法(PyTorch 2.1版本)。目前支持入图的API情况如表1所示,请根据实际情况按需调用。

表1 集合通信API入图支持情况

PyTorch集合通信API

PyTorch 2.1版本

说明

torch.distributed.all_gather

接口详细的介绍请参考API参考

torch.distributed.all_gather_into_tensor

torch.distributed.all_reduce

torch.distributed.all_to_all

torch.distributed.all_to_all_single

torch.distributed.broadcast

torch.distributed.reduce_scatter_tensor

x

torch_npu.distributed.all_gather_into_tensor_uneven

x

torch_npu.distributed.reduce_scatter_tensor_uneven

x

使用约束

集合通信算子入图的前提是PyTorch脚本中所有算子能正常以Eager模式运行。

使用方法(PyTorch 2.1版本)

分布式场景下,第三方框架(如DeepSpeed)对原生allreduce API的封装使其无法入图,TorchAir可同时Patch框架中的allreduce封装函数,以解决入图问题。

  • 方案1:patch_for_hcom补丁包
    假设推理脚本test.py定义如下,在调用torch.compile前显式调用torchair.patch_for_hcom()即可实现集合通信入图。
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    import os
    import torch
    import torch_npu
    import torchair
    from torchair.configs.compiler_config import CompilerConfig
    
    # 导入patch包
    torchair.patch_for_hcom()
    class AllReduceSingeGroup(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, x, y):
            x = x + y
            torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
            return x
    
    def example(rank, world_size):
        torch.npu.set_device(rank)
        torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
        x = torch.ones([2, 2], dtype=torch.int32).to("npu:"+str(rank))
        y = torch.ones([2, 2], dtype=torch.int32).to("npu:"+str(rank))
        config = CompilerConfig()
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        model = torch.compile(AllReduceSingeGroup().to("npu:"+str(rank)), backend=npu_backend, dynamic=False)
        out = torch.ones([2, 2], dtype=torch.int32).npu() * 2 * world_size
        ret = model(x, y)
        assert out.equal(ret)
        torch.distributed.destroy_process_group()
    
    def mp():
        world_size = 2
        torch.multiprocessing.spawn(example, args=(world_size, ), nprocs=world_size, join=True)
    
    if __name__ == '__main__':
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29506"    
        mp()
    
  • 方案2(废弃):patch_for_hcom_allreduce补丁包

    方案2在Ascend Extension for PyTorch6.0.RC3版本及后续版本不再演进,未来会废弃,建议使用方案1

    推理脚本中导入patch_for_hcom_allreduce包,该包将PyTorch原生torch.distributed.allreduce API替换成torch.distributed._functional_collectives.all_reduce算子,以实现入图。

    1
    2
    3
    # 导入patch包
    import torch_npu 
    import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce