torch_npu.npu_moe_finalize_routing

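Function Description

Performs the final step of MoE routing by merging the MoE FFN outputs into the final result.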

API Prototype

npu_moe_finalize_routing(Tensor expanded_permuted_rows, Tensor skip1, Tensor? skip2, Tensor bias, Tensor scales, Tensor expanded_src_to_dst_row, Tensor export_for_source_row) -> Tensor

Parameters

Output

out: a device-side Tensor holding the final result of merging the MoE FFN outputs.
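To make the merge step concrete, here is a minimal CPU reference sketch in NumPy. It is not stated on this page and is an assumption to verify: it follows the combine formula given in the Ascend CANN documentation for the MoeFinalizeRouting operator, out[i] = skip1[i] + skip2[i] + sum over k of scales[i, k] * (expanded_permuted_rows[expanded_src_to_dst_row[k * NUM_ROWS + i]] + bias[expert_for_source_row[i, k]]). The function name moe_finalize_routing_ref is hypothetical and is not part of torch_npu.

```python
import numpy as np


def moe_finalize_routing_ref(expanded_permuted_rows, skip1, skip2, bias, scales,
                             expanded_src_to_dst_row, expert_for_source_row):
    """Hypothetical CPU reference for the assumed MoE finalize-routing formula."""
    num_rows, top_k = scales.shape
    # Start from the skip inputs; skip2 is optional (Tensor? in the prototype).
    out = skip1.astype(np.float32)
    if skip2 is not None:
        out = out + skip2
    for k in range(top_k):
        # Assumed layout: the k-th expanded copy of token i has source index k * num_rows + i.
        src_idx = k * num_rows + np.arange(num_rows)
        # Look up where each of those copies landed in the permuted buffer.
        dst_rows = expanded_src_to_dst_row[src_idx]
        # Expert chosen for each token at slot k, used to pick its bias row.
        experts = expert_for_source_row[:, k]
        # Weight the expert output (plus its bias) by the routing score and accumulate.
        out = out + scales[:, k:k + 1] * (expanded_permuted_rows[dst_rows] + bias[experts])
    return out
```

Comparing this sketch with the NPU result (after `.cpu().numpy()`) under a tolerance is one way to check the assumed index layout against the real operator.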

Constraints

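The larger of the tail-axis sizes H and K must be less than 8 KB. Here H is the hidden size (the last dimension of expanded_permuted_rows, token_len in the example) and K is the number of experts selected per token (the last dimension of scales, top_k in the example).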

Supported PyTorch Versions

Supported Products

Atlas A2 training series products

Example

import torch
import torch_npu

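# Sizes: num_rows tokens, each routed to top_k of expert_num experts; token_len is the hidden size
expert_num = 16
token_len = 10
top_k = 4
num_rows = 50
device = torch.device('npu')
dtype = torch.float32
# Expert FFN outputs, one row per (token, expert) pair, in permuted order
expanded_permuted_rows = torch.randn((num_rows * top_k, token_len), device=device, dtype=dtype)
# Skip inputs added to the output; skip2 is optional and may be None
skip1 = torch.randn((num_rows, token_len), device=device, dtype=dtype)
skip2_optional = torch.randn((num_rows, token_len), device=device, dtype=dtype)
# One bias row per expert
bias = torch.randn((expert_num, token_len), device=device, dtype=dtype)
# Routing weights for each token's top_k experts
scales = torch.randn((num_rows, top_k), device=device, dtype=dtype)
# Expert ID selected for each (token, k) pair
expert_for_source_row = torch.randint(low=0, high=expert_num, size=(num_rows, top_k), device=device, dtype=torch.int32)
# Row index of each expanded source row in expanded_permuted_rows (random values, for illustration only)
expanded_src_to_dst_row = torch.randint(low=0, high=num_rows * top_k, size=(num_rows * top_k,), device=device, dtype=torch.int32)

output = torch_npu.npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row, expert_for_source_row)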
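The example fills expanded_src_to_dst_row with torch.randint, which can repeat indices, so it only exercises the call. In a real MoE pipeline one would expect this mapping to be a permutation that groups the expanded rows by expert; that expectation, and the k-major source index k * NUM_ROWS + i used below, are assumptions not specified on this page. The sketch builds such a permutation from expert_for_source_row; build_src_to_dst_row is a hypothetical helper, not a torch_npu API.

```python
import numpy as np


def build_src_to_dst_row(expert_for_source_row):
    """Hypothetical helper: map each expanded source row to its slot in an expert-sorted buffer."""
    # Flatten k-major so that flat index k * num_rows + i refers to (token i, slot k) (assumed layout).
    experts_flat = expert_for_source_row.T.reshape(-1)
    # Stable sort by expert id: order[p] is the source index placed at permuted row p.
    order = np.argsort(experts_flat, kind="stable")
    # Invert the permutation: src_to_dst[s] is the permuted row that holds source s.
    src_to_dst = np.empty_like(order)
    src_to_dst[order] = np.arange(order.size)
    return src_to_dst.astype(np.int32)


expert_for_source_row = np.array([[1, 0], [0, 1]], dtype=np.int32)
print(build_src_to_dst_row(expert_for_source_row))  # [2 0 1 3]
```

The result can be moved to the device with torch.from_numpy(...).to(device); expanded_permuted_rows would then need to hold the expert outputs in the same expert-sorted order.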