昇腾社区首页
EN
注册

torch_npu.npu_grouped_matmul_finalize_routing

功能描述

GroupedMatMul和MoeFinalizeRouting的融合算子,GroupedMatMul计算后的输出按照索引做combine动作。

接口原型

1
torch_npu.npu_grouped_matmul_finalize_routing(Tensor x, Tensor w, Tensor group_list, *, Tensor? scale=None, Tensor? bias=None, Tensor? pertoken_scale=None, Tensor? shared_input=None, Tensor? logit=None, Tensor? row_index=None, ScalarType? dtype=None, float? shared_input_weight=1.0, int shared_input_offset=0, int? output_bs=0, int? group_list_type=1) -> Tensor

参数说明

  • x:一个2D的Device侧Tensor输入,矩阵计算的左矩阵,不支持非连续的Tensor。数据类型支持int8,数据格式支持ND,维度为(m, k)。m取值范围为[1, 16*1024*8],k只支持2048。
  • w:一个5D的Device侧Tensor输入,矩阵计算的右矩阵,不支持非连续的Tensor。数据类型支持int8,数据格式支持NZ,维度为(e, n1, k1, k0, n0),其中k0=16、n0=32, x shape中的k和w shape中的k1需要满足以下关系:ceilDiv(k, 16) = k1,e取值范围[1, 256]。
  • group_list: 一个1D的Device侧Tensor输入,GroupedMatMul的各分组大小。不支持非连续的Tensor。数据类型支持int64,数据格式支持ND,维度为(e,),e与w的e一致。group_list的值总和要求≤m。
  • scale:一个2D的Device侧Tensor输入,矩阵计算反量化参数,对应weight矩阵,per-channel量化方式,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,维度(e, n),这里的n=n1*n0,n只支持7168。
  • bias:一个2D的Device侧Tensor输入,矩阵计算的bias参数,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND。
  • pertoken_scale:一个1D的Device侧Tensor输入,矩阵计算的反量化参数,对应x矩阵,per-token量化方式,不支持非连续的Tensor。维度为(m,),m与x的m一致。数据类型支持float32,数据格式支持ND。
  • shared_input:一个2D的Device侧Tensor输入,MoE计算中共享专家的输出,需要与MoE专家的输出进行combine操作,不支持非连续的Tensor。数据类型支持bfloat16,数据格式支持ND,维度(batch/dp, n),n与scale的n一致,batch/dp取值范围[1, 2*1024],batch取值范围[1, 16*1024]。
  • logit:一个1D的Device侧Tensor输入,MoE专家对各个token的logit大小,矩阵乘的计算输出与该logit做乘法,然后索引进行combine,不支持非连续的Tensor。数据类型支持float32,数据格式支持ND,维度(m,),m与x的m一致。
  • row_index:一个1D的Device侧Tensor输入,MoE专家输出按照该rowIndex进行combine,其中的值即为combine做scatter add的索引,不支持非连续的Tensor。数据类型支持int64,数据格式支持ND,维度为(m,),m与x的m一致。
  • dtype:ScalarType类型,指定GroupedMatMul计算的输出类型。0表示float32,1表示float16,2表示bfloat16。默认值为0。
  • shared_input_weight:float类型,指共享专家与MoE专家进行combine的系数,shared_input先与该参数乘,然后再和MoE专家结果累加。默认为1.0。
  • shared_input_offset:int类型,共享专家输出的在总输出中的偏移。默认值为0。
  • output_bs:int类型,输出的最高维大小。默认值为0。
  • group_list_type:int类型数组,GroupedMatMul的分组模式。默认为1,表示count模式;若配置为0,表示cumsum模式,即为前缀和。

输出说明

y: 一个2D的Tensor,不支持非连续的Tensor,输出的数据类型固定为float32,维度为(batch, n)。

约束说明

  • 该接口在推理和训练场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • 输入和输出Tensor支持的数据类型组合如下:

    x

    w

    group_list

    scale

    bias

    pertoken_scale

    shared_input

    logit

    row_index

    y

    int8

    int8

    int64

    float32

    None

    float32

    bfloat16

    float32

    int64

    float32

    int8

    int8

    int64

    float32

    None

    float32

    None

    None

    int64

    float32

支持的型号

  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

调用示例

  • 单算子模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    import numpy as np
    import torch
    import torch_npu
    import tensorflow as tf
    from scipy.special import softmax
     
    bfloat16 = tf.bfloat16.as_numpy_dtype
    m, k, n = 576, 2048, 7168
    batch = 72
    topK = 8
    group_num = 8
     
    x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
    weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)
    scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
    pertoken_scale = np.random.normal(0, 0.01, (m, 1)).astype(np.float32)
    group_list = np.array([batch] * group_num, dtype=np.int64)
    shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)
    logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
    routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
    logit = softmax(logit_ori[np.arange(batch).reshape(-1, 1).repeat(topK, axis=1), routing], axis=1).astype(np.float32)
    logit = logit.reshape(m)
    row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)
     
    x_clone = torch.from_numpy(x).npu()
    weight_clone = torch.from_numpy(weight).npu()
    weightNz = torch_npu.npu_format_cast(weight_clone, 29)
    scale_clone = torch.from_numpy(scale).npu()
    pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
    group_list_clone = torch.from_numpy(group_list).npu()
    shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
    logit_clone = torch.from_numpy(logit).npu()
    row_index_clone = torch.from_numpy(row_index).npu()
    shared_input_offset = batch // 2
    output_bs = batch
    y = torch_npu.npu_grouped_matmul_finalize_routing(x_clone, weightNz,
                group_list_clone, scale=scale_clone, pertoken_scale=pertoken_scale_clone,
                shared_input=shared_input_clone, logit=logit_clone, row_index=row_index_clone,
                shared_input_offset=shared_input_offset, output_bs=output_bs)
    
  • 图模式调用:
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    import numpy as np
    import torch
    import torch_npu
    import torchair as tng
    import tensorflow as tf
    from scipy.special import softmax
    from torchair.configs.compiler_config import CompilerConfig
     
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
     
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
        def forward(self, x, weight, group_list, scale, pertoken_scale, shared_input, logit, row_index, shared_input_offset, output_bs):
            output = torch_npu.npu_grouped_matmul_finalize_routing(x, weight, group_list,
                        scale=scale, pertoken_scale=pertoken_scale, shared_input=shared_input,
                        logit=logit, row_index=row_index, shared_input_offset=shared_input_offset, output_bs=output_bs)
            return output
     
    bfloat16 = tf.bfloat16.as_numpy_dtype
    m, k, n = 576, 2048, 7168
    batch = 72
    topK = 8
    group_num = 8
     
    x = np.random.randint(-10, 10, (m, k)).astype(np.int8)
    weight = np.random.randint(-10, 10, (group_num, k, n)).astype(np.int8)
    scale = np.random.normal(0, 0.01, (group_num, n)).astype(np.float32)
    pertoken_scale = np.random.normal(0, 0.01, (m, 1)).astype(np.float32)
    group_list = np.array([batch] * group_num, dtype=np.int64)
    shared_input = np.random.normal(0, 0.1, (batch // 4, n)).astype(np.float32)
    logit_ori = np.random.normal(0, 0.1, (batch, group_num)).astype(np.float32)
    routing = np.argsort(logit_ori, axis=1)[:, -topK:].astype(np.int32)
    logit = softmax(logit_ori[np.arange(batch).reshape(-1, 1).repeat(topK, axis=1), routing], axis=1).astype(np.float32)
    logit = logit.reshape(m)
    row_index = (np.argsort(routing.reshape(-1)) // topK).astype(np.int64)
     
    x_clone = torch.from_numpy(x).npu()
    weight_clone = torch.from_numpy(weight).npu()
    weightNz = torch_npu.npu_format_cast(weight_clone, 29)
    scale_clone = torch.from_numpy(scale).npu()
    pertoken_scale_clone = torch.from_numpy(pertoken_scale).npu()
    group_list_clone = torch.from_numpy(group_list).npu()
    shared_input_clone = torch.from_numpy(shared_input).to(torch.bfloat16).npu()
    logit_clone = torch.from_numpy(logit).npu()
    row_index_clone = torch.from_numpy(row_index).npu()
    shared_input_offset = batch // 2
    output_bs = batch
     
    model = Model().npu()
    model = torch.compile(model, backend=npu_backend, dynamic=False)
    y = model(x_clone, weightNz, group_list_clone, scale_clone, pertoken_scale_clone, shared_input_clone, logit_clone, row_index_clone, shared_input_offset, output_bs)