torch_npu.npu_moe_distribute_dispatch
功能描述
- 算子功能:需与torch_npu.npu_moe_distribute_combine配套使用,完成MoE的并行部署下的token dispatch与combine。对Token数据先进行量化(可选),再进行EP(Expert Parallelism)域的alltoallv通信,再进行TP(Tensor Parallelism)域的allgatherv通信(可选)。
- 计算公式:
接口原型
torch_npu.npu_moe_distribute_dispatch(Tensor x, Tensor expert_ids, str group_ep, int ep_world_size, int ep_rank_id, int moe_expert_num, *, Tensor? scales=None, Tensor? x_active_mask=None, Tensor? expert_scales=None, str group_tp="", int tp_world_size=0, int tp_rank_id=0, int expert_shard_type=0, int shared_expert_num=1, int shared_expert_rank_num=0, int quant_mode=0, int global_bs=0, int expert_token_nums_type=1) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
参数说明

参数里Shape使用的变量如下:
- A:表示本卡接收的最大token数量,取值范围如下
- 对于共享专家,要满足A=BS*ep_world_size*shared_expert_num/shared_expert_rank_num。
- 对于MoE专家,当global_bs为0时,要满足A>=BS*ep_world_size*min(local_expert_num, K);当global_bs非0时,要满足A>=global_bs* min(local_expert_num, K)。
- H:表示hidden size隐藏层大小。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :取值范围(0,7168],且保证是32的整数倍。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :取值为7168。
- BS:表示待发送的token数量。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :取值范围为0<BS≤256。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :取值范围为0<BS≤512。
- K:表示选取topK个专家,取值范围为0<K≤8同时满足0<K≤moe_expert_num。
- server_num:表示服务器的节点数,取值只支持2、4、8。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :仅该场景的shape使用了该变量。
- local_expert_num:表示本卡专家数量。
- 对于共享专家卡,local_expert_num=1
- 对于MoE专家卡,local_expert_num=moe_expert_num/(ep_world_size-shared_expert_rank_num),当local_expert_num>1时,不支持TP域通信。
- x:Tensor类型,表示计算使用的token数据,需根据expert_ids来发送给其他卡。要求为2D的Tensor,shape为(BS, H),表示有BS个token,数据类型支持bfloat16、float16,数据格式为ND,支持非连续的Tensor。
- expert_ids:Tensor类型,表示每个token的topK个专家索引,决定每个token要发给哪些专家。要求为2D的Tensor,shape为(BS, K),数据类型支持int32,数据格式为ND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine的expert_ids输入,张量里value取值范围为[0, moe_expert_num),且同一行中的K个value不能重复。
- group_ep:string类型,EP通信域名称,专家并行的通信域。字符串长度范围为[1,128),不能和group_tp相同。
- ep_world_size:int类型,EP通信域size。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :取值支持16、32、64。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :取值支持8、16、32、64、128、144、256、288。
- ep_rank_id:int类型,EP通信域本卡ID,取值范围[0, ep_world_size),同一个EP通信域中各卡的ep_rank_id不重复。
- moe_expert_num:int类型,MoE专家数量,取值范围[1, 512],并且满足moe_expert_num%(ep_world_size-shared_expert_rank_num)=0。
- scales:Tensor类型,可选参数,表示每个专家的权重,非量化场景不传,动态量化场景可传可不传。若传值要求为2D的Tensor,shape为(shared_expert_num+moe_expert_num, H),数据类型支持float,数据格式为ND,不支持非连续的Tensor。
- x_active_mask:Tensor类型,预留参数,暂未使用,使用默认值即可。
- expert_scales:Tensor类型,可选参数,表示每个token的topK个专家权重。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :要求为2D的Tensor,shape为(BS, K),数据类型支持float,数据格式为ND,支持非连续的Tensor。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :暂不支持该参数,使用默认值即可。
- group_tp:string类型,可选参数,TP通信域名称,数据并行的通信域。若有TP域通信需要传参,若无TP域通信,使用默认值即可。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :不支持TP域通信,使用默认值即可。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :当有TP域通信时,字符串长度范围为[1, 128),不能和group_ep相同。
- tp_world_size:int类型,可选参数,TP通信域size。有TP域通信才需要传参。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :不支持TP域通信,使用默认值0即可。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :当有TP域通信时,取值范围[0, 2],0和1表示无TP域通信,2表示有TP域通信。
- tp_rank_id:int类型,可选参数,TP通信域本卡ID。有TP域通信才需要传参。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :不支持TP域通信,使用默认值即可。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :当有TP域通信时,取值范围[0, 1],默认为0,同一个TP通信域中各卡的tp_rank_id不重复。无TP域通信时,传0即可。
- expert_shard_type:int类型,表示共享专家卡排布类型。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :暂不支持该参数,使用默认值即可。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :当前仅支持0,表示共享专家卡排在MoE专家卡前面。
- shared_expert_num:int类型,表示共享专家数量,一个共享专家可以复制部署到多个卡上。预留参数,暂未使用,使用默认值即可。
- shared_expert_rank_num:int类型,可选参数,表示共享专家卡数量。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :不支持共享专家,传0即可。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :取值范围[0, ep_world_size-1)。取0表示无共享专家,不取0需满足ep_world_size%shared_expert_rank_num=0。
- quant_mode:int类型,可选参数,表示量化模式。支持取值:0表示非量化(默认),2表示动态量化。当quant_mode=2,dynamic_scales不为None;当quant_mode=0,dynamic_scales为None。
- global_bs:int类型,可选参数,表示EP域全局的batch size大小。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :当每个rank的BS不同时,传入256*ep_world_size;当每个rank的BS相同时,支持取值0或BS*ep_world_size。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :支持取0或BS*ep_world_size。
- expert_token_nums_type:int类型,可选参数,表示输出expert_token_nums的值类型,取值范围[0, 1],0表示每个专家收到token数量的前缀和,1表示每个专家收到的token数量(默认)。
输出说明
- expand_x:Tensor类型,表示本卡收到的token数据,要求为2D的Tensor,shape为(max(tp_world_size, 1) *A, H),A表示在EP通信域可能收到的最大token数,数据类型支持bfloat16、float16、int8。量化时类型为int8,非量化时与x数据类型保持一致。数据格式为ND,支持非连续的Tensor。
- dynamic_scales:Tensor类型,表示计算得到的动态量化参数。当quant_mode非0时才有该输出,要求为1D的Tensor,shape为(A,),数据类型支持float,数据格式支持ND,支持非连续的Tensor。
- expand_idx :Tensor类型,表示给同一专家发送的token个数,要求是一个1D的Tensor,shape为(BS*K, )。数据类型支持int32,数据格式为ND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine的expand_idx输入。
- expert_token_nums:Tensor类型,本卡每个专家实际收到的token数量,要求为1D的Tensor,shape为(local_expert_num,),数据类型int64,数据格式支持ND,支持非连续的Tensor。
- ep_recv_counts:Tensor类型,表示EP通信域各卡收到的token数量,要求为1D的Tensor,数据类型int32,数据格式支持ND,支持非连续的Tensor。对应torch_npu.npu_moe_distribute_combine的ep_send_counts输入。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :要求shape为(moe_expert_num+2*global_bs*K*server_num, )。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :要求shape为(ep_world_size*max(tp_world_size, 1)*local_expert_num, )。
- tp_recv_counts:Tensor类型,表示TP通信域各卡收到的token数量。对应torch_npu.npu_moe_distribute_combine的tp_send_counts输入。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :不支持TP通信域,暂无该输出,Atlas A3 训练系列产品/Atlas A3 推理系列产品 :支持TP通信域,要求是一个1D Tensor,shape为(tp_world_size, ),数据类型支持int32,数据格式为ND,支持非连续的Tensor。
- expand_scales:Tensor类型,表示expert_scales与x一起进行alltoallv之后的输出。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :要求是一个1D的Tensor,shape为(A, ),数据类型支持float,数据格式要求为ND,支持非连续的Tensor。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :暂不支持该输出,返回None。
约束说明
- 该接口支持推理场景下使用。
- 该接口支持静态图模式(PyTorch 2.1版本)。
- 调用接口过程中使用的group_ep、ep_world_size、moe_expert_num、group_tp、tp_world_size、expert_shard_type、shared_expert_num、shared_expert_rank_num、global_bs参数取值所有卡保持一致,且和torch_npu.npu_moe_distribute_combine对应参数也保持一致。
Atlas A3 训练系列产品/Atlas A3 推理系列产品 :该场景下单卡包含双DIE(简称为“晶粒”或“裸片”),因此参数说明里的“本卡”均表示单DIE。- 调用本接口前需检查HCCL_BUFFSIZE环境变量取值是否合理:
CANN环境变量HCCL_BUFFSIZE:表示单个通信域占用内存大小,单位MB,不配置时默认为200MB。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :要求>=2*(BS*ep_world_size*min(local_expert_num, K)*H*sizeof(unit16)+2MB)。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :要求大于等于2且满足1024*1024*(HCCL_BUFFSIZE-2)/2>=(BS*2*(H+128)*(ep_world_size*local_expert_num+K+1)。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :当K=8且BS<=128时,配置环境变量HCCL_INTRA_PCIE_ENABLE=1和HCCL_INTRA_ROCE_ENABLE=0可以减少跨机通信数据量,提升算子性能。此时要求HCCL_BUFFSIZE>=moe_expert_num*BS*(H*sizeof(dtypeX)+4*K*sizeof(uint32))+4MB+100MB。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
Atlas A3 训练系列产品/Atlas A3 推理系列产品
调用示例
- 单算子模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
import os import torch import random import torch_npu import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die sharedExpertRankNum = 2 # 共享专家数 moeExpertNum = 14 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - sharedExpertRankNum local_moe_expert_num = moeExpertNum // moe_rank_num globalBS = bs * ep_world_size is_shared = (sharedExpertRankNum > 0) is_quant = (quant_mode > 0) def gen_unique_topk_array(low, high, bs, k): array = [] for i in range(bs): top_idx = list(np.arange(low, high, dtype=np.int32)) random.shuffle(top_idx) array.append(top_idx[0:k]) return np.array(array) def get_new_group(rank): for i in range(tp_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]] ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): # 如果tp_world_size = 2,ep_world_size = 8,则为[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]] tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}') ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32) expert_ids = torch.from_numpy(expert_ids).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales = torch_npu.npu_moe_distribute_dispatch( x=x, expert_ids=expert_ids, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=sharedExpertRankNum, moe_expert_num=moeExpertNum, scales=scales, quant_mode=quant_mode, global_bs=globalBS) if is_quant: expand_x = expand_x.to(input_dtype) x = torch_npu.npu_moe_distribute_combine(expand_x=expand_x, expert_ids=expert_ids, expand_idx=expand_idx, ep_send_counts=ep_recv_counts, tp_send_counts=tp_recv_counts, expert_scales=expert_scales, group_ep=ep_hcomm_info, group_tp=tp_hcomm_info, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=rank // tp_world_size, tp_rank_id=rank % tp_world_size, expert_shard_type=0, shared_expert_rank_num=sharedExpertRankNum, moe_expert_num=moeExpertNum, global_bs=globalBS) print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={sharedExpertRankNum}") print(f"moe_expert_num={moeExpertNum}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if sharedExpertRankNum > ep_world_size: print("sharedExpertRankNum 不能大于 ep_world_size") exit(0) if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0: print("ep_world_size 必须是 sharedExpertRankNum的整数倍") exit(0) if moeExpertNum % moe_rank_num != 0: print("moeExpertNum 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.")
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
# 仅支持静态图 import os import torch import random import torch_npu import torchair import numpy as np from torch.multiprocessing import Process import torch.distributed as dist from torch.distributed import ReduceOp # 控制模式 quant_mode = 2 # 2为动态量化 is_dispatch_scales = True # 动态量化可选择是否传scales input_dtype = torch.bfloat16 # 输出dtype server_num = 1 server_index = 0 port = 50001 master_ip = '127.0.0.1' dev_num = 16 world_size = server_num * dev_num rank_per_dev = int(world_size / server_num) # 每个host有几个die sharedExpertRankNum = 2 # 共享专家数 moeExpertNum = 14 # moe专家数 bs = 8 # token数量 h = 7168 # 每个token的长度 k = 8 random_seed = 0 tp_world_size = 1 ep_world_size = int(world_size / tp_world_size) moe_rank_num = ep_world_size - sharedExpertRankNum local_moe_expert_num = moeExpertNum // moe_rank_num globalBS = bs * ep_world_size is_shared = (sharedExpertRankNum > 0) is_quant = (quant_mode > 0) class MOE_DISTRIBUTE_GRAPH_Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, expert_ids, group_ep, group_tp, ep_world_size, tp_world_size, ep_rank_id, tp_rank_id, expert_shard_type, shared_expert_rank_num, moe_expert_num, scales, quant_mode, global_bs, expert_scales): output_dispatch_npu = torch_npu.npu_moe_distribute_dispatch(x=x, expert_ids=expert_ids, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, scales=scales, quant_mode=quant_mode, global_bs=global_bs) expand_x_npu, _, expand_idx_npu, _, ep_recv_counts_npu, tp_recv_counts_npu, expand_scales = output_dispatch_npu if expand_x_npu.dtype == torch.int8: expand_x_npu = expand_x_npu.to(input_dtype) output_combine_npu = torch_npu.npu_moe_distribute_combine(expand_x=expand_x_npu, expert_ids=expert_ids, expand_idx=expand_idx_npu, ep_send_counts=ep_recv_counts_npu, tp_send_counts=tp_recv_counts_npu, expert_scales=expert_scales, group_ep=group_ep, group_tp=group_tp, ep_world_size=ep_world_size, tp_world_size=tp_world_size, ep_rank_id=ep_rank_id, tp_rank_id=tp_rank_id, expert_shard_type=expert_shard_type, shared_expert_rank_num=shared_expert_rank_num, moe_expert_num=moe_expert_num, global_bs=global_bs) x = output_combine_npu x_combine_res = output_combine_npu return [x_combine_res, output_combine_npu] def gen_unique_topk_array(low, high, bs, k): array = [] for i in range(bs): top_idx = list(np.arange(low, high, dtype=np.int32)) random.shuffle(top_idx) array.append(top_idx[0:k]) return np.array(array) def get_new_group(rank): for i in range(tp_world_size): ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)] ep_group = dist.new_group(backend="hccl", ranks=ep_ranks) if rank in ep_ranks: ep_group_t = ep_group print(f"rank:{rank} ep_ranks:{ep_ranks}") for i in range(ep_world_size): tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)] tp_group = dist.new_group(backend="hccl", ranks=tp_ranks) if rank in tp_ranks: tp_group_t = tp_group print(f"rank:{rank} tp_ranks:{tp_ranks}") return ep_group_t, tp_group_t def get_hcomm_info(rank, comm_group): if torch.__version__ > '2.0.1': hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank) else: hcomm_info = comm_group.get_hccl_comm_name(rank) return hcomm_info def run_npu_process(rank): torch_npu.npu.set_device(rank) rank = rank + 16 * server_index dist.init_process_group(backend='hccl', rank=rank, world_size=world_size, init_method=f'tcp://{master_ip}:{port}') ep_group, tp_group = get_new_group(rank) ep_hcomm_info = get_hcomm_info(rank, ep_group) tp_hcomm_info = get_hcomm_info(rank, tp_group) # 创建输入tensor x = torch.randn(bs, h, dtype=input_dtype).npu() expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32) expert_ids = torch.from_numpy(expert_ids).npu() expert_scales = torch.randn(bs, k, dtype=torch.float32).npu() scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h) if is_dispatch_scales: scales = torch.randn(scales_shape, dtype=torch.float32).npu() else: scales = None model = MOE_DISTRIBUTE_GRAPH_Model() model = model.npu() npu_backend = torchair.get_npu_backend() model = torch.compile(model, backend=npu_backend, dynamic=False) output = model.forward(x, expert_ids, ep_hcomm_info, tp_hcomm_info, ep_world_size, tp_world_size, rank // tp_world_size,rank % tp_world_size, 0, sharedExpertRankNum, moeExpertNum, scales, quant_mode, globalBS, expert_scales) torch.npu.synchronize() print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n') if __name__ == "__main__": print(f"bs={bs}") print(f"global_bs={globalBS}") print(f"shared_expert_rank_num={sharedExpertRankNum}") print(f"moe_expert_num={moeExpertNum}") print(f"k={k}") print(f"quant_mode={quant_mode}", flush=True) print(f"local_moe_expert_num={local_moe_expert_num}", flush=True) print(f"tp_world_size={tp_world_size}", flush=True) print(f"ep_world_size={ep_world_size}", flush=True) if tp_world_size != 1 and local_moe_expert_num > 1: print("unSupported tp = 2 and local moe > 1") exit(0) if sharedExpertRankNum > ep_world_size: print("sharedExpertRankNum 不能大于 ep_world_size") exit(0) if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0: print("ep_world_size 必须是 sharedExpertRankNum的整数倍") exit(0) if moeExpertNum % moe_rank_num != 0: print("moeExpertNum 必须是 moe_rank_num 的整数倍") exit(0) p_list = [] for rank in range(rank_per_dev): p = Process(target=run_npu_process, args=(rank,)) p_list.append(p) for p in p_list: p.start() for p in p_list: p.join() print("run npu success.")
父主题: torch_npu