torch_npu.npu_moe_distribute_dispatch(Tensor x, Tensor expert_ids, str group_ep, int ep_world_size, int ep_rank_id, int moe_expert_num, *, Tensor? scales=None, Tensor? x_active_mask=None, Tensor? expert_scales=None, str group_tp="", int tp_world_size=0, int tp_rank_id=0, int expert_shard_type=0, int shared_expert_num=1, int shared_expert_rank_num=0, int quant_mode=0, int global_bs=0, int expert_token_nums_type=1) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
The variables used in the parameter shape descriptions are as follows:
CANN environment variable HCCL_BUFFSIZE: specifies the memory, in MB, occupied by a single communication domain. If not configured, it defaults to 200 MB.
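If the default 200 MB is not large enough for the configured global_bs, h and world size, the variable can be set before HCCL initializes. Below is a minimal sketch, assuming HCCL_BUFFSIZE is read from the process environment when the communication domain is created; the 512 MB value is only an illustrative assumption, not a recommended setting.

import os

# Assumption for illustration: HCCL reads HCCL_BUFFSIZE (in MB) from the
# environment of the process that creates the communication domain, so set it
# before dist.init_process_group(). 512 is an arbitrary example value.
os.environ.setdefault("HCCL_BUFFSIZE", "512")

Exporting the variable in the launching shell before starting the script has the same effect, since child processes inherit the environment.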
Single-operator (eager) mode call example:

import os
import torch
import random
import torch_npu
import numpy as np
from torch.multiprocessing import Process
import torch.distributed as dist
from torch.distributed import ReduceOp

# Control flags
quant_mode = 2                        # 2 means dynamic quantization
is_dispatch_scales = True             # with dynamic quantization, passing scales is optional
input_dtype = torch.bfloat16          # output dtype
server_num = 1
server_index = 0
port = 50001
master_ip = '127.0.0.1'
dev_num = 16
world_size = server_num * dev_num
rank_per_dev = int(world_size / server_num)   # number of dies per host
sharedExpertRankNum = 2               # number of shared-expert ranks
moeExpertNum = 14                     # number of MoE experts
bs = 8                                # number of tokens
h = 7168                              # hidden size of each token
k = 8
random_seed = 0
tp_world_size = 1
ep_world_size = int(world_size / tp_world_size)
moe_rank_num = ep_world_size - sharedExpertRankNum
local_moe_expert_num = moeExpertNum // moe_rank_num
globalBS = bs * ep_world_size
is_shared = (sharedExpertRankNum > 0)
is_quant = (quant_mode > 0)


def gen_unique_topk_array(low, high, bs, k):
    array = []
    for i in range(bs):
        top_idx = list(np.arange(low, high, dtype=np.int32))
        random.shuffle(top_idx)
        array.append(top_idx[0:k])
    return np.array(array)


def get_new_group(rank):
    for i in range(tp_world_size):
        # With tp_world_size = 2 and ep_world_size = 8, the EP groups are
        # [[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]
        ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)]
        ep_group = dist.new_group(backend="hccl", ranks=ep_ranks)
        if rank in ep_ranks:
            ep_group_t = ep_group
            print(f"rank:{rank} ep_ranks:{ep_ranks}")
    for i in range(ep_world_size):
        # With tp_world_size = 2 and ep_world_size = 8, the TP groups are
        # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
        tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)]
        tp_group = dist.new_group(backend="hccl", ranks=tp_ranks)
        if rank in tp_ranks:
            tp_group_t = tp_group
            print(f"rank:{rank} tp_ranks:{tp_ranks}")
    return ep_group_t, tp_group_t


def get_hcomm_info(rank, comm_group):
    if torch.__version__ > '2.0.1':
        hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcomm_info = comm_group.get_hccl_comm_name(rank)
    return hcomm_info


def run_npu_process(rank):
    torch_npu.npu.set_device(rank)
    rank = rank + 16 * server_index
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size,
                            init_method=f'tcp://{master_ip}:{port}')
    ep_group, tp_group = get_new_group(rank)
    ep_hcomm_info = get_hcomm_info(rank, ep_group)
    tp_hcomm_info = get_hcomm_info(rank, tp_group)

    # Create input tensors
    x = torch.randn(bs, h, dtype=input_dtype).npu()
    expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32)
    expert_ids = torch.from_numpy(expert_ids).npu()
    expert_scales = torch.randn(bs, k, dtype=torch.float32).npu()
    scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h)
    if is_dispatch_scales:
        scales = torch.randn(scales_shape, dtype=torch.float32).npu()
    else:
        scales = None

    expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales = \
        torch_npu.npu_moe_distribute_dispatch(x=x,
                                              expert_ids=expert_ids,
                                              group_ep=ep_hcomm_info,
                                              group_tp=tp_hcomm_info,
                                              ep_world_size=ep_world_size,
                                              tp_world_size=tp_world_size,
                                              ep_rank_id=rank // tp_world_size,
                                              tp_rank_id=rank % tp_world_size,
                                              expert_shard_type=0,
                                              shared_expert_rank_num=sharedExpertRankNum,
                                              moe_expert_num=moeExpertNum,
                                              scales=scales,
                                              quant_mode=quant_mode,
                                              global_bs=globalBS)
    if is_quant:
        expand_x = expand_x.to(input_dtype)
    x = torch_npu.npu_moe_distribute_combine(expand_x=expand_x,
                                             expert_ids=expert_ids,
                                             expand_idx=expand_idx,
                                             ep_send_counts=ep_recv_counts,
                                             tp_send_counts=tp_recv_counts,
                                             expert_scales=expert_scales,
                                             group_ep=ep_hcomm_info,
                                             group_tp=tp_hcomm_info,
                                             ep_world_size=ep_world_size,
                                             tp_world_size=tp_world_size,
                                             ep_rank_id=rank // tp_world_size,
                                             tp_rank_id=rank % tp_world_size,
                                             expert_shard_type=0,
                                             shared_expert_rank_num=sharedExpertRankNum,
                                             moe_expert_num=moeExpertNum,
                                             global_bs=globalBS)
    print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n')


if __name__ == "__main__":
    print(f"bs={bs}")
    print(f"global_bs={globalBS}")
    print(f"shared_expert_rank_num={sharedExpertRankNum}")
    print(f"moe_expert_num={moeExpertNum}")
    print(f"k={k}")
    print(f"quant_mode={quant_mode}", flush=True)
    print(f"local_moe_expert_num={local_moe_expert_num}", flush=True)
    print(f"tp_world_size={tp_world_size}", flush=True)
    print(f"ep_world_size={ep_world_size}", flush=True)

    if tp_world_size != 1 and local_moe_expert_num > 1:
        print("Unsupported: tp_world_size != 1 together with local_moe_expert_num > 1")
        exit(0)
    if sharedExpertRankNum > ep_world_size:
        print("sharedExpertRankNum must not exceed ep_world_size")
        exit(0)
    if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0:
        print("ep_world_size must be an integer multiple of sharedExpertRankNum")
        exit(0)
    if moeExpertNum % moe_rank_num != 0:
        print("moeExpertNum must be an integer multiple of moe_rank_num")
        exit(0)

    p_list = []
    for rank in range(rank_per_dev):
        p = Process(target=run_npu_process, args=(rank,))
        p_list.append(p)
    for p in p_list:
        p.start()
    for p in p_list:
        p.join()
    print("run npu success.")
Graph mode call example (static graph only):

# Static graph (graph mode) only
import os
import torch
import random
import torch_npu
import torchair
import numpy as np
from torch.multiprocessing import Process
import torch.distributed as dist
from torch.distributed import ReduceOp

# Control flags
quant_mode = 2                        # 2 means dynamic quantization
is_dispatch_scales = True             # with dynamic quantization, passing scales is optional
input_dtype = torch.bfloat16          # output dtype
server_num = 1
server_index = 0
port = 50001
master_ip = '127.0.0.1'
dev_num = 16
world_size = server_num * dev_num
rank_per_dev = int(world_size / server_num)   # number of dies per host
sharedExpertRankNum = 2               # number of shared-expert ranks
moeExpertNum = 14                     # number of MoE experts
bs = 8                                # number of tokens
h = 7168                              # hidden size of each token
k = 8
random_seed = 0
tp_world_size = 1
ep_world_size = int(world_size / tp_world_size)
moe_rank_num = ep_world_size - sharedExpertRankNum
local_moe_expert_num = moeExpertNum // moe_rank_num
globalBS = bs * ep_world_size
is_shared = (sharedExpertRankNum > 0)
is_quant = (quant_mode > 0)


class MOE_DISTRIBUTE_GRAPH_Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, expert_ids, group_ep, group_tp, ep_world_size, tp_world_size,
                ep_rank_id, tp_rank_id, expert_shard_type, shared_expert_rank_num,
                moe_expert_num, scales, quant_mode, global_bs, expert_scales):
        output_dispatch_npu = torch_npu.npu_moe_distribute_dispatch(x=x,
                                                                    expert_ids=expert_ids,
                                                                    group_ep=group_ep,
                                                                    group_tp=group_tp,
                                                                    ep_world_size=ep_world_size,
                                                                    tp_world_size=tp_world_size,
                                                                    ep_rank_id=ep_rank_id,
                                                                    tp_rank_id=tp_rank_id,
                                                                    expert_shard_type=expert_shard_type,
                                                                    shared_expert_rank_num=shared_expert_rank_num,
                                                                    moe_expert_num=moe_expert_num,
                                                                    scales=scales,
                                                                    quant_mode=quant_mode,
                                                                    global_bs=global_bs)
        expand_x_npu, _, expand_idx_npu, _, ep_recv_counts_npu, tp_recv_counts_npu, expand_scales = output_dispatch_npu
        if expand_x_npu.dtype == torch.int8:
            expand_x_npu = expand_x_npu.to(input_dtype)
        output_combine_npu = torch_npu.npu_moe_distribute_combine(expand_x=expand_x_npu,
                                                                  expert_ids=expert_ids,
                                                                  expand_idx=expand_idx_npu,
                                                                  ep_send_counts=ep_recv_counts_npu,
                                                                  tp_send_counts=tp_recv_counts_npu,
                                                                  expert_scales=expert_scales,
                                                                  group_ep=group_ep,
                                                                  group_tp=group_tp,
                                                                  ep_world_size=ep_world_size,
                                                                  tp_world_size=tp_world_size,
                                                                  ep_rank_id=ep_rank_id,
                                                                  tp_rank_id=tp_rank_id,
                                                                  expert_shard_type=expert_shard_type,
                                                                  shared_expert_rank_num=shared_expert_rank_num,
                                                                  moe_expert_num=moe_expert_num,
                                                                  global_bs=global_bs)
        x = output_combine_npu
        x_combine_res = output_combine_npu
        return [x_combine_res, output_combine_npu]


def gen_unique_topk_array(low, high, bs, k):
    array = []
    for i in range(bs):
        top_idx = list(np.arange(low, high, dtype=np.int32))
        random.shuffle(top_idx)
        array.append(top_idx[0:k])
    return np.array(array)


def get_new_group(rank):
    for i in range(tp_world_size):
        ep_ranks = [x * tp_world_size + i for x in range(ep_world_size)]
        ep_group = dist.new_group(backend="hccl", ranks=ep_ranks)
        if rank in ep_ranks:
            ep_group_t = ep_group
            print(f"rank:{rank} ep_ranks:{ep_ranks}")
    for i in range(ep_world_size):
        tp_ranks = [x + tp_world_size * i for x in range(tp_world_size)]
        tp_group = dist.new_group(backend="hccl", ranks=tp_ranks)
        if rank in tp_ranks:
            tp_group_t = tp_group
            print(f"rank:{rank} tp_ranks:{tp_ranks}")
    return ep_group_t, tp_group_t


def get_hcomm_info(rank, comm_group):
    if torch.__version__ > '2.0.1':
        hcomm_info = comm_group._get_backend(torch.device("npu")).get_hccl_comm_name(rank)
    else:
        hcomm_info = comm_group.get_hccl_comm_name(rank)
    return hcomm_info


def run_npu_process(rank):
    torch_npu.npu.set_device(rank)
    rank = rank + 16 * server_index
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size,
                            init_method=f'tcp://{master_ip}:{port}')
    ep_group, tp_group = get_new_group(rank)
    ep_hcomm_info = get_hcomm_info(rank, ep_group)
    tp_hcomm_info = get_hcomm_info(rank, tp_group)

    # Create input tensors
    x = torch.randn(bs, h, dtype=input_dtype).npu()
    expert_ids = gen_unique_topk_array(0, moeExpertNum, bs, k).astype(np.int32)
    expert_ids = torch.from_numpy(expert_ids).npu()
    expert_scales = torch.randn(bs, k, dtype=torch.float32).npu()
    scales_shape = (1 + moeExpertNum, h) if sharedExpertRankNum else (moeExpertNum, h)
    if is_dispatch_scales:
        scales = torch.randn(scales_shape, dtype=torch.float32).npu()
    else:
        scales = None

    model = MOE_DISTRIBUTE_GRAPH_Model()
    model = model.npu()
    npu_backend = torchair.get_npu_backend()
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    output = model.forward(x, expert_ids, ep_hcomm_info, tp_hcomm_info, ep_world_size,
                           tp_world_size, rank // tp_world_size, rank % tp_world_size, 0,
                           sharedExpertRankNum, moeExpertNum, scales, quant_mode, globalBS,
                           expert_scales)
    torch.npu.synchronize()
    print(f'rank {rank} epid {rank // tp_world_size} tpid {rank % tp_world_size} npu finished! \n')


if __name__ == "__main__":
    print(f"bs={bs}")
    print(f"global_bs={globalBS}")
    print(f"shared_expert_rank_num={sharedExpertRankNum}")
    print(f"moe_expert_num={moeExpertNum}")
    print(f"k={k}")
    print(f"quant_mode={quant_mode}", flush=True)
    print(f"local_moe_expert_num={local_moe_expert_num}", flush=True)
    print(f"tp_world_size={tp_world_size}", flush=True)
    print(f"ep_world_size={ep_world_size}", flush=True)

    if tp_world_size != 1 and local_moe_expert_num > 1:
        print("Unsupported: tp_world_size != 1 together with local_moe_expert_num > 1")
        exit(0)
    if sharedExpertRankNum > ep_world_size:
        print("sharedExpertRankNum must not exceed ep_world_size")
        exit(0)
    if sharedExpertRankNum > 0 and ep_world_size % sharedExpertRankNum != 0:
        print("ep_world_size must be an integer multiple of sharedExpertRankNum")
        exit(0)
    if moeExpertNum % moe_rank_num != 0:
        print("moeExpertNum must be an integer multiple of moe_rank_num")
        exit(0)

    p_list = []
    for rank in range(rank_per_dev):
        p = Process(target=run_npu_process, args=(rank,))
        p_list.append(p)
    for p in p_list:
        p.start()
    for p in p_list:
        p.join()
    print("run npu success.")