1. Sort the input expertIdx, producing the sorted result sortedExpertIdx and the corresponding source indices sortedRowIdx.
2. Use sortedRowIdx to build the position mapping, producing expandedRowIdxOut.
3. In drop mode, compute a per-expert histogram over sortedExpertIdx, producing expertTokensBeforeCapacityOut.
4. Compute the quant result. For dynamic quant: if scale is not supplied, the quantization scale is derived from the input itself; if scale is supplied, it is applied to the input before quantization.
5. From quantResult, gather the values at the positions given by the first NUM_ROWS entries of sortedRowIdx, producing expandedXOut (see the reference sketch after this list).
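The five steps can be pieced together as a short CPU reference sketch in plain PyTorch. This is only an illustration under stated assumptions: `moe_init_routing_v2_reference` is a hypothetical helper, the abs-max formulation in step 4 and the row-major `// k` source-row convention in step 5 are assumptions rather than the operator's documented formulas, and drop/pad capacity handling is omitted.

```python
# Minimal CPU reference sketch of steps 1-5 above (illustrative only).
import torch

def moe_init_routing_v2_reference(x, expert_idx, k, expert_num, active_num, scale=None):
    flat = expert_idx.reshape(-1)                          # (num_rows * k,) token-expert pairs

    # Step 1: sort expert ids; sorted_row_idx records where each pair came from.
    sorted_expert_idx, sorted_row_idx = torch.sort(flat, stable=True)

    # Step 2: position mapping -- expanded_row_idx[p] is the destination slot of
    # source pair p, i.e. the inverse permutation of sorted_row_idx.
    expanded_row_idx = torch.empty_like(sorted_row_idx)
    expanded_row_idx[sorted_row_idx] = torch.arange(
        flat.numel(), dtype=sorted_row_idx.dtype, device=flat.device)

    # Step 3: per-expert histogram (token count before any capacity drop;
    # only reported in drop mode).
    expert_tokens = torch.bincount(sorted_expert_idx, minlength=expert_num)

    # Step 4: dynamic quant (ASSUMED row-wise abs-max scheme; if scale is
    # supplied it is applied to the input first).
    x_q = x if scale is None else x * scale.unsqueeze(-1)
    dyn_scale = (x_q.abs().amax(dim=-1, keepdim=True) / 127.0).clamp_min(1e-12)
    quant_result = torch.round(x_q / dyn_scale).to(torch.int8)

    # Step 5: gather the source rows of the first active_num sorted pairs
    # (row-major convention; the real mapping depends on row_idx_type).
    n = min(active_num, flat.numel())
    src_rows = sorted_row_idx[:n] // k
    expanded_x = quant_result[src_rows]

    return expanded_x, expanded_row_idx, expert_tokens, dyn_scale
```

With `expert_tokens_num_type=1`, as in the examples below, the third output is interpreted as a per-expert token count rather than a cumulative sum.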
```
torch_npu.npu_moe_init_routing_v2(Tensor x, Tensor expert_idx, *, Tensor? scale=None, Tensor? offset=None, int active_num=-1, int expert_capacity=-1, int expert_num=-1, int drop_pad_mode=0, int expert_tokens_num_type=0, bool expert_tokens_num_flag=False, int quant_mode=0, int[2] active_expert_range=[], int row_idx_type=0) -> (Tensor, Tensor, Tensor, Tensor)
```
Single-operator (eager) mode call example:

```python
import torch
import torch_npu

bs = 1
h = 613
k = 475

# Routing attributes.
active_num = 475
expert_capacity = -1
expert_num = 226
drop_pad_mode = 0
expert_tokens_num_type = 1
expert_tokens_num_flag = True
quant_mode = -1
active_expert_range = [23, 35]
row_idx_type = 0

# Input tensors on the NPU device.
x = torch.randn((bs, h), dtype=torch.float32).npu()
expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()
scale = torch.randn((bs,), dtype=torch.float32).npu()
offset = None

expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
    x, expert_idx, scale=scale, offset=offset,
    active_num=active_num, expert_capacity=expert_capacity, expert_num=expert_num,
    drop_pad_mode=drop_pad_mode, expert_tokens_num_type=expert_tokens_num_type,
    expert_tokens_num_flag=expert_tokens_num_flag, active_expert_range=active_expert_range,
    quant_mode=quant_mode, row_idx_type=row_idx_type)
```
Graph mode call example (compiled through torchair):

```python
import torch
import torch.nn as nn
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig

config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)


class MoeInitRoutingV2Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, expert_idx, *, scale=None, offset=None, active_num=-1,
                expert_capacity=-1, expert_num=-1, drop_pad_mode=0,
                expert_tokens_num_type=0, expert_tokens_num_flag=False,
                quant_mode=0, active_expert_range=0, row_idx_type=0):
        return torch.ops.npu.npu_moe_init_routing_v2(
            x, expert_idx, scale=scale, offset=offset,
            active_num=active_num, expert_capacity=expert_capacity, expert_num=expert_num,
            drop_pad_mode=drop_pad_mode, expert_tokens_num_type=expert_tokens_num_type,
            expert_tokens_num_flag=expert_tokens_num_flag,
            active_expert_range=active_expert_range,
            quant_mode=quant_mode, row_idx_type=row_idx_type)


def main():
    bs = 1
    h = 613
    k = 475

    # Routing attributes (same configuration as the eager-mode example).
    active_num = 475
    expert_capacity = -1
    expert_num = 226
    drop_pad_mode = 0
    expert_tokens_num_type = 1
    expert_tokens_num_flag = True
    quant_mode = -1
    active_expert_range = [23, 35]
    row_idx_type = 0

    x = torch.randn((bs, h), dtype=torch.float32).npu()
    expert_idx = torch.randint(0, expert_num, (bs, k), dtype=torch.int32).npu()
    scale = torch.randn((bs,), dtype=torch.float32).npu()
    offset = None

    # Compile the wrapper module with the torchair NPU backend.
    moe_init_routing_v2_model = MoeInitRoutingV2Model().npu()
    moe_init_routing_v2_model = torch.compile(moe_init_routing_v2_model,
                                              backend=npu_backend, dynamic=False)
    expanded_x, expanded_row_idx, expert_token_cumsum_or_count, expanded_scale = moe_init_routing_v2_model(
        x, expert_idx, scale=scale, offset=offset,
        active_num=active_num, expert_capacity=expert_capacity, expert_num=expert_num,
        drop_pad_mode=drop_pad_mode, expert_tokens_num_type=expert_tokens_num_type,
        expert_tokens_num_flag=expert_tokens_num_flag,
        active_expert_range=active_expert_range,
        quant_mode=quant_mode, row_idx_type=row_idx_type)


if __name__ == '__main__':
    main()
```