昇腾社区首页
中文
注册

torch_npu.npu_mm_reduce_scatter_base

功能描述

TP切分场景下,实现matmul和reduce_scatter的融合,融合算子内部实现计算和通信流水并行。

使用该接口时,请确保驱动固件包和CANN包都为配套的8.0.RC2版本或者配套的更高版本,否则将会引发报错,比如BUS ERROR等。

接口原型

npu_mm_reduce_scatter_base(Tensor input, Tensor x2, str hcom, int world_size, *, str reduce_op='sum', Tensor? bias=None, int comm_turn=0) -> Tensor

参数说明

  • input:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,输入shape支持2维。
  • x2:Device侧的Tensor类型,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND,数据类型需要和input保持一致,输入shape维度和input保持一致。
  • hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
  • world_size:Host侧的int类型,通信域内的rank总数,仅支持为2、4、8。
  • *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
  • reduce_op:Host侧的String类型,reduce操作类型,当前仅支持'sum',默认值:'sum'。
  • bias:Device侧的Tensor类型,可选输入,数据类型支持FLOAT16、BFLOAT16,数据格式支持ND格式。数据类型需要和input保持一致。bias仅支持一维,且维度大小与output的第1维大小相同。当前版本暂不支持bias输入为非0的场景。
  • comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。

输出说明

Tensor类型,数据类型和input保持一致,shape维度和input保持一致。

约束说明

  • 该接口仅在训练场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • 输入input、x2必须是2维,分别为(m, k),(k, n),轴满足matmul算子入参要求,k轴相等,且k轴取值范围为[256, 65535),m轴约束如下:

    m轴需要整除world_size,支持2、4、8卡,且仅支持hccs链路all mesh组网。

  • input不支持输入转置后的tensor,x2转置后输入,需要满足shape的第一维大小与input的最后一维相同,满足matmul的计算条件。
  • 一个模型中的通算融合算子(AllGatherMatmul、MatmulReduceScatter、MatmulAllReduce),仅支持相同通信域。

支持的型号

Atlas A2 训练系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
    
        input_ = torch.randn(x1_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output = torch_npu.npu_mm_reduce_scatter_base(input_, weight, hcomm_info, world_size)
    
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
    
        mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    import torch
    import torch_npu
    import torch.distributed as dist
    import torch.multiprocessing as mp
    class MM_REDUCESCATTER_GRAPH_Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, input, weight, hcomm_info, world_size, reduce_op):
            output = torch_npu.npu_mm_reduce_scatter_base(input, weight, hcomm_info, world_size,
                                                          reduce_op=reduce_op)
            return output
    def define_model(model, graph_type):
        import torchair
        if graph_type == 1:  # 传统入图模式,静态shape+在线编译场景
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=False)
        elif graph_type == 2:  # ACLNN入图模式,动态shape+二进制
            npu_backend = torchair.get_npu_backend(compiler_config=None)
            model = torch.compile(model, backend=npu_backend, dynamic=True)
        else:
            print("Error type")
        return model
    def get_graph(input, weight, hcomm_info, world_size):
        model = MM_REDUCESCATTER_GRAPH_Model()
        model = define_model(model, 2)
        model_output = model(input, weight, hcomm_info, world_size, reduce_op="sum")
        return model_output
    def run_mm_reduce_scatter_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
        torch_npu.npu.set_device(rank)
        init_method = 'tcp://' + master_ip + ':' + master_port
        dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
        from torch.distributed.distributed_c10d import _get_default_group
        default_pg = _get_default_group()
        if torch.__version__ > '2.0.1':
            hcomm_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
        else:
            hcomm_info = default_pg.get_hccl_comm_name(rank)
        input = torch.randn(x1_shape, dtype=dtype).npu()
        weight = torch.randn(x2_shape, dtype=dtype).npu()
        output = get_graph(input, weight, hcomm_info, world_size)
        print("output:", output)
    if __name__ == "__main__":
        worksize = 8
        master_ip = '127.0.0.1'
        master_port = '50001'
        x1_shape = [128, 512]
        x2_shape = [512, 64]
        dtype = torch.float16
        mp.spawn(run_mm_reduce_scatter_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)