import torch
import torch_npu
import random
import copy
import math
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
# Build a torchair compiler config and an NPU backend for torch.compile.
config = CompilerConfig()
config.experimental_config.keep_inference_input_mutations = True
npu_backend = tng.get_npu_backend(compiler_config=config)
# Problem sizes: total tokens, per-token length, number of ranks and experts.
tokens_num = 16384
tokens_length = 7168
rank_num = 16
expert_num = 16
# int8 tokens and a (rank_num, expert_num) table of per-expert token counts.
tokens = torch.randint(low=-10, high=20, size=(tokens_num, tokens_length), dtype=torch.int8)
expert_token_num_per_rank = torch.ones(rank_num, expert_num, dtype=torch.int32)
# Randomly distribute tokens_num tokens across rank_num x expert_num experts;
# the last entry takes the remainder so the counts sum exactly to tokens_num.
tokens_sum = 0
for i in range(rank_num):
    for j in range(expert_num):
        if i == rank_num - 1 and j == expert_num - 1:
            expert_token_num_per_rank[i][j] = tokens_num - tokens_sum
            break
        if tokens_num >= rank_num * expert_num:
            floor = math.floor(tokens_num / (rank_num * expert_num))
            rand_num = random.randint(1, floor)
        elif tokens_sum >= tokens_num:
            rand_num = 0
        else:
            rand_int = tokens_num - tokens_sum
            rand_num = random.randint(0, rand_int)
        expert_token_num_per_rank[i][j] = rand_num
        tokens_sum += rand_num
per_token_scales = torch.randn(tokens_num, dtype=torch.float32)
expert_token_num_type = 1  # output format of expert_token_num: 1 = per-expert counts, 0 = cumulative sums
idx_type = 0               # output format of permute_token_idx: 0 = gather index, 1 = scatter index
# Deep-copy the host tensors and move the copies to the NPU.
tokens_npu = copy.deepcopy(tokens).npu()
per_token_scales_npu = copy.deepcopy(per_token_scales).npu()
expert_token_num_per_rank_npu = copy.deepcopy(expert_token_num_per_rank).npu()
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tokens, expert_token_num_per_rank, per_token_scales=None,
                expert_token_num_type=1, idx_type=0):
        permute_tokens, permute_per_token_scales, permute_token_idx, expert_token_num = torch_npu.npu_moe_re_routing(
            tokens, expert_token_num_per_rank, per_token_scales=per_token_scales,
            expert_token_num_type=expert_token_num_type, idx_type=idx_type)
        return permute_tokens, permute_per_token_scales, permute_token_idx, expert_token_num
model = Model().npu()
# Compile for the NPU backend with static shapes, then run the op once.
model = torch.compile(model, backend=npu_backend, dynamic=False)
permute_tokens_npu, permute_per_token_scales_npu, permute_token_idx_npu, expert_token_num_npu = model(
    tokens_npu, expert_token_num_per_rank_npu, per_token_scales_npu, expert_token_num_type, idx_type)
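
# --- Optional sanity check: a minimal sketch, not part of the sample above. ---
# Assumptions (hypothetical, not taken from the op's documentation): with
# expert_token_num_type=1 the returned expert_token_num holds per-expert token
# counts, and permute_token_idx is a permutation of the range [0, tokens_num).
expert_token_num_cpu = expert_token_num_npu.cpu().to(torch.int64)
assert expert_token_num_cpu.sum().item() == tokens_num, "per-expert counts should cover every token"
permute_token_idx_cpu = permute_token_idx_npu.cpu().to(torch.int64)
assert torch.equal(torch.sort(permute_token_idx_cpu).values,
                   torch.arange(tokens_num, dtype=torch.int64)), "indices should form a permutation"
print("permute_tokens:", tuple(permute_tokens_npu.shape), permute_tokens_npu.dtype)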