In tensor-parallel (TP) sharding scenarios, this operator fuses the matrix multiplication (mm) with the all_reduce; inside the fused operator, computation and communication are pipelined to overlap.
npu_mm_all_reduce_base(Tensor self, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, int comm_turn=0) -> Tensor
Returns a Tensor whose data type matches self; dimension 0 of the output shape matches dimension 0 of self, and dimension 1 matches dimension 1 of x2.
Input self can be 2-D or 3-D, and x2 must be 2-D, with shapes (s, m, k)/(m, k) and (k, n) respectively; the axes must satisfy the mm operator's input requirements, and the k axes of self and x2 must be equal. bias currently supports only one dimension, and its size must equal the last dimension of the output (see the sketch below).
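A minimal sketch of these shape constraints, assuming an NPU environment and an already-obtained HCCL communicator name hcom_info (how to obtain it is shown in the full example at the end of this section); the concrete sizes m, k, n are illustrative only:

import torch
import torch_npu

# Assumption: hcom_info holds a valid HCCL communicator name for the current process group.
m, k, n = 128, 512, 64
self_ = torch.randn(m, k, dtype=torch.float16).npu()  # 2-D left matrix (m, k)
x2 = torch.randn(k, n, dtype=torch.float16).npu()     # 2-D right matrix (k, n); k axes must match
bias = torch.randn(n, dtype=torch.float16).npu()      # 1-D bias, size equals output's last dim
out = torch_npu.npu_mm_all_reduce_base(self_, x2, hcom_info, reduce_op='sum', bias=bias)
# out: dtype matches self_, shape (m, n)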
Atlas A2 training series products
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp

def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
    torch_npu.npu.set_device(rank)
    init_method = 'tcp://' + master_ip + ':' + master_port
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
    from torch.distributed.distributed_c10d import _get_default_group
    default_pg = _get_default_group()
    # The accessor for the HCCL communicator name differs across PyTorch versions.
    if torch.__version__ > '2.0.1':
        hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcom_info = default_pg.get_hccl_comm_name(rank)
    input_ = torch.randn(x1_shape, dtype=dtype).npu()
    weight = torch.randn(x2_shape, dtype=dtype).npu()
    # Fused matmul + all_reduce over the default process group.
    output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')

if __name__ == '__main__':
    world_size = 8
    master_ip = '127.0.0.1'
    master_port = '50001'
    x1_shape = [128, 512]
    x2_shape = [512, 64]
    dtype = torch.float16
    mp.spawn(run_mm_all_reduce_base,
             args=(world_size, master_ip, master_port, x1_shape, x2_shape, dtype),
             nprocs=world_size)
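When launched, the script spawns 8 processes, one per NPU device. Each process computes its local (128, 512) x (512, 64) matmul, and the fused operator sums (reduce_op='sum') the (128, 64) partial results across the HCCL communicator, so every rank ends up with the same all-reduced output.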