torch_npu.npu_mm_all_reduce_base

Function Description

In tensor-parallel (TP) sharding scenarios, this operator fuses the matrix multiplication (mm) and the all_reduce into a single operator; internally, the fused operator pipelines computation and communication in parallel.
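
Semantically, the fused operator produces the same result as a local matmul followed by an explicit all_reduce over the communication domain. A minimal sketch of that unfused equivalent is shown below; the helper name mm_then_all_reduce is illustrative only, and it assumes the HCCL process group has already been initialized as in the calling example at the end of this page.

import torch
import torch.distributed as dist

# Illustrative sketch of the unfused equivalent (not part of the interface):
# each rank multiplies its local shards, then the partial results are summed
# across all ranks in the communication domain.
def mm_then_all_reduce(x1, x2):
    partial = torch.matmul(x1, x2)
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    return partial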

Function Prototype

npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor

Parameter Description

Output Description

A Tensor. In the non-quantized and pseudo-quantized scenarios its data type matches x1; in the fully quantized scenario its data type is float16 or bfloat16. Dimension 0 of the output shape matches dimension 0 of x1. If x1 is 2-dimensional, dimension 1 of the output matches dimension 1 of x2. If x1 is 3-dimensional, dimension 1 of the output matches dimension 1 of x1, and dimension 2 of the output matches dimension 1 of x2.
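
For example, if x1 has shape [128, 512] and x2 has shape [512, 64] (the shapes used in the calling example below), the output shape is [128, 64]; if x1 instead has shape [4, 128, 512] (an illustrative 3-dimensional case), the output shape is [4, 128, 64].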

Constraints

Supported PyTorch Versions

Supported Products

Atlas A2 training series products

Calling Example

import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    # Bind this process to its NPU device and join the HCCL process group.
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    # The way to obtain the HCCL communication domain name depends on the PyTorch version.
    if torch.__version__ > '2.0.1':
        hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcom_info = default_pg.get_hccl_comm_name(rank)

    # Each rank computes its local matmul shard; the fused operator reduces the
    # partial results across the communication domain.
    input_ = torch.randn(x1_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')
    print("output: ", output)

if __name__ == "__main__":
    world_size = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16

    # Spawn one process per rank; mp.spawn passes the rank as the first argument.
    mp.spawn(run_mm_all_reduce_base, args=(world_size, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=world_size)
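
The example above assumes 8 NPU devices on a single host: the nprocs value passed to mp.spawn must match the world_size used to initialize the process group, and each rank binds to its own device via set_device(rank).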