功能描述
TP切分场景下,实现mm和all_reduce的融合,融合算子内部实现计算和通信流水并行。
接口原型
npu_mm_all_reduce_base(Tensor x1, Tensor x2, str hcom, *, str reduce_op='sum', Tensor? bias=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? x3=None, Tensor? dequant_scale=None, int comm_turn=0, int antiquant_group_size=0) -> Tensor
参数说明
- x1:Device侧的Tensor类型,支持float16、bfloat16、int8,支持ND,输入shape支持2维或者3维。
- x2:Device侧的Tensor类型,支持float16、bfloat16、int8,支持ND,数据类型需要和x1保持一致,输入shape维度第0维和x1的最后一维保持一致。
- hcom:Host侧的String类型,通信域handle名,通过get_hccl_comm_name接口获取。
- *:代表其之前的变量是位置相关,按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- reduce_op:Host侧的String类型,reduce操作类型,当前版本仅支持'sum',默认值:'sum'。
- bias:Device侧的Tensor类型,可选输入,支持int32、float16、bfloat16,支持ND格式。bias当前仅支持一维,且维度大小与output/x2的最后一维大小相同。
- antiquant_scale:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,支持float16、bfloat16,支持ND格式。数据类型需要和x1保持一致。antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n]。其中n为x2最后一维的大小。伪量化场景当前版本未支持。
- antiquant_offset:Device侧的Tensor类型,可选输入,伪量化场景对x2进行去量化的系数,支持float16、bfloat16,支持ND格式。数据类型需要和antiquant_scale保持一致。shape与antiquant_scale保持一致。伪量化场景当前版本未支持。
- x3:Device侧的Tensor类型,可选输入,matmul计算后的偏移。支持float16、bfloat16。支持ND格式。数据类型需要和x1保持一致。shape与output的shape相同。
- dequant_scale:Device侧的Tensor类型,可选输入,matmul计算后的去量化系数。支持int64、float16、bfloat16,支持ND格式。shape在per-tensor场景为[1],per-channel场景为[n]/[1,n]。当前版本仅支持int64。
- comm_turn:Host侧的int类型,表示rank间通信切分粒度,默认值:0,表示默认的切分方式。当前版本仅支持输入0。
- antiquant_group_size:Host侧的int类型,量化per-group场景使用,对输入x2进行反量化计算的groupSize输入,默认值:0。预留参数,暂未使用。
输出说明
Tensor类型,数据类型非量化场景和x1保持一致,全量化场景为float16或者bfloat16。当前版本仅支持float16。shape第0维度和x1的0维保持一致,若x1为2维,shape第1维度和x2的1维保持一致,若x1为3维,shape第1维度和x1的1维保持一致,shape第2维度和x2的1维保持一致。
约束说明
- 输入x1可为2维或者3维、x2必须是2维,分别为(s, m, k)/(m, k), (k, n),轴满足mm算子入参要求,k轴相等。bias当前仅支持一维,且维度大小与output的最后一维大小相同。x3的shape与output的shape相同。antiquant_scale当前per-tensor场景shape为[1],per-channel场景支持shape为[1,n]或者[n]。antiquant_offset的shape与antiquant_scale一致。dequant_scale的shape在per-tensor场景为[1],per-channel场景为[n]/[1,n]。
- x1、x2不能为空tensor。
- 非量化场景,x1、x2、bias、x3的数据类型保持一致,可为float16或者bfloat16,antiquant_scale、antiquant_offset、dequant_scale为空tensor。
- 伪量化场景,x1、bias、x3、antiquant_scale、antiquant_offset的数据类型保持一致,可为float16或者bfloat16,x2的数据类型为int8,dequant_scale为空tensor。
- 全量化场景,x1、x2的数据类型为int8,dequant_scale的数据类型为int64,bias数据类型为int32,antiquant_scale、antiquant_offset、x3为空tensor。
- 昇腾310P3 AI处理器只支持2卡。
- Atlas A2 训练系列产品支持2、4、8卡。
支持的PyTorch版本
- PyTorch 2.1
- PyTorch 2.0
- PyTorch 1.11.0
支持的型号
- Atlas A2 训练系列产品
- 昇腾310P3 AI处理器
调用示例
import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp
def run_mm_all_reduce_base(rank, world_size, master_ip, master_port, x1_shape, x2_shape, dtype):
torch_npu.npu.set_device(rank)
init_method = 'tcp://' + master_ip + ':' + master_port
dist.init_process_group(backend="hccl", rank=rank, world_size=world_size, init_method=init_method)
from torch.distributed.distributed_c10d import _get_default_group
default_pg = _get_default_group()
if torch.__version__ > '2.0.1':
hcom_info = default_pg._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
else:
hcom_info = default_pg.get_hccl_comm_name(rank)
input_ = torch.randn(x1_shape, dtype=dtype).npu()
weight = torch.randn(x2_shape, dtype=dtype).npu()
output = torch_npu.npu_mm_all_reduce_base(input_, weight, hcom_info, reduce_op='sum')
worksize = 8
master_ip = '127.0.0.1'
master_port = '50001'
x1_shape = [128, 512]
x2_shape = [512, 64]
dtype = torch.float16
mp.spawn(run_mm_all_reduce_base, args=(worksize, master_ip, master_port, x1_shape, x2_shape, dtype), nprocs=worksize)