torch_npu.npu_grouped_matmul_finalize_routing

功能描述

GroupedMatMul和MoeFinalizeRouting的融合算子,GroupedMatMul计算后的输出按照索引做combine动作。

接口原型

1
torch_npu.npu_grouped_matmul_finalize_routing(Tensor x, Tensor w, Tensor group_list, *, Tensor? scale=None, Tensor? bias=None, Tensor? pertoken_scale=None, Tensor? shared_input=None, Tensor? logit=None, Tensor? row_index=None, ScalarType? dtype=None, float? shared_input_weight=1.0, int shared_input_offset=0, int? output_bs=0, int? group_list_type=1) -> Tensor

参数说明

输出说明

y: 一个2D的Tensor,不支持非连续的Tensor,输出的数据类型固定为float32,维度为(batch, n)。

约束说明

支持的型号

调用示例