torch_npu.npu_moe_gating_top_k_softmax

功能描述

MoE计算中,对gating的输出做Softmax计算,取topk操作。

接口原型

npu_moe_gating_top_k_softmax(Tensor x, Tensor? finished=None, int k=1) -> (Tensor, Tensor, Tensor)

参数说明

输出说明

约束说明

无。

支持的PyTorch版本

PyTorch 2.1

支持的型号

Atlas A2 训练系列产品

调用示例

import torch
import torch_npu
x = torch.rand((3, 3), dtype=torch.float32).to("npu")
finished = torch.randint(2, size=(3,), dtype=torch.bool).to("npu")
y, expert_idx, row_idx = torch_npu.npu_moe_gating_top_k_softmax(x, finished, k=2)