torch_npu.npu_moe_distribute_dispatch

功能描述

接口原型

torch_npu.npu_moe_distribute_dispatch(Tensor x, Tensor expert_ids, str group_ep, int ep_world_size, int ep_rank_id, int moe_expert_num, *, Tensor? scales=None, Tensor? x_active_mask=None, Tensor? expert_scales=None, str group_tp="", int tp_world_size=0, int tp_rank_id=0, int expert_shard_type=0, int shared_expert_num=1, int shared_expert_rank_num=0, int quant_mode=0, int global_bs=0, int expert_token_nums_type=1) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)

参数说明

参数里Shape使用的变量如下:

  • A:表示本卡接收的最大token数量,取值范围如下
    • 对于共享专家,要满足A=BS*ep_world_size*shared_expert_num/shared_expert_rank_num。
    • 对于MoE专家,当global_bs为0时,要满足A>=BS*ep_world_size*min(local_expert_num, K);当global_bs非0时,要满足A>=global_bs* min(local_expert_num, K)。
  • H:表示hidden size隐藏层大小。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:取值范围(0,7168],且保证是32的整数倍。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值为7168
  • BS:表示待发送的token数量。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:取值范围为0<BS≤256。
    • Atlas A3 训练系列产品/Atlas A3 推理系列产品:取值范围为0<BS≤512
  • K:表示选取topK个专家,取值范围为0<K≤8同时满足0<K≤moe_expert_num。
  • server_num:表示服务器的节点数,取值只支持2、4、8。
    • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:仅该场景的shape使用了该变量。
  • local_expert_num:表示本卡专家数量。
    • 对于共享专家卡,local_expert_num=1
    • 对于MoE专家卡,local_expert_num=moe_expert_num/(ep_world_size-shared_expert_rank_num),当local_expert_num>1时,不支持TP域通信。

输出说明

约束说明

支持的型号

调用示例