昇腾社区首页
中文
注册

torch_npu.npu_ffn

功能说明

  • 算子功能:该FFN算子提供MoeFFN和FFN的计算功能。在没有专家分组(expert_tokens为空)时是FFN,有专家分组时是MoeFFN。

  • 计算公式:

说明:[object Object] 激活层为geglu/swiglu/reglu时,性能使能需要满足门槛要求,即整网中FFN结构所对应的小算子中vector耗时30us且占比10%以上的用例方可尝试FFN融合算子;或在不知道小算子性能的情况下,尝试使能FFN,若性能劣化则不使能FFN。

函数原型

[object Object]

参数说明

  • x (Tensor):输入参数,公式中的xx,数据类型支持float16bfloat16int8,数据格式支持NDND,支持输入的维度最少是2维[M,K1][M, K1],最多是8维。

  • weight1 (Tensor):专家的权重数据,公式中的W1W1,数据类型支持float16bfloat16int8,数据格式支持NDND,输入在有/无专家时分别为[E,K1,N1]/[K1,N1][E, K1, N1]/[K1, N1]

  • weight2 (Tensor):专家的权重数据,公式中的W2W2,数据类型支持float16bfloat16int8,数据格式支持NDND,输入在有/无专家时分别为[E,K2,N2]/[K2,N2][E, K2, N2]/[K2, N2]

    说明:[object Object] MM表示token个数,对应transform中的BS(BB:Batch,表示输入样本批量大小,SS:Seq-Length,表示输入样本序列长度);K1K1表示第一个matmul的输入通道数,对应transform中的HH(Head-Size,表示隐藏层的大小);N1N1表示第一个matmul的输出通道数;K2K2表示第二个matmul的输入通道数;N2N2表示第二个matmul的输出通道数,对应transform中的HHEE表示有专家场景的专家数。

  • activation (str):代表使用的激活函数,即输入参数中的activation。当前仅支持fastgelu、gelu、relu、silu、geglu、swiglu、reglu

  • expert_tokens (list):可选参数。代表各专家的token数,数据类型支持int32,数据格式支持NDND,若不为空时可支持的最大长度为256个。

  • expert_tokens_index (list):可选参数。代表各专家计算token的索引值,数据类型支持int32,数据格式支持NDND,若不为空时可支持的最大长度为256个。

  • bias1 (Tensor):可选参数。权重数据修正值,公式中的b1b1,数据类型支持float16float32int32,数据格式支持NDND,输入在有/无专家时分别为[E,N1]/[N1][E, N1]/[N1]

  • bias2 (Tensor):可选参数。权重数据修正值,公式中的b2b2,数据类型支持float16float32int32,数据格式支持NDND,输入在有/无专家时分别为[E,N2]/[N2][E, N2]/[N2]

  • scale (Tensor):可选参数,量化参数,量化缩放系数,数据类型支持float32,数据格式支持NDND。per-tensor下输入在有/无专家时均为一维向量,输入元素个数在有/无专家时分别为[E]/[1][E]/[1];per-channel下输入在有/无专家时为二维向量/一维向量,输入元素个数在有/无专家时分别为[E,N1]/[N1][E, N1]/[N1]

  • offset (Tensor):可选参数,量化参数,量化偏移量,数据类型支持float32,数据格式支持NDND,一维向量,输入元素个数在有/无专家时分别为[E]/[1][E]/[1]

  • deq_scale1 (Tensor):可选参数,量化参数,第一组matmul的反量化缩放系数,数据类型支持int64float32bfloat16,数据格式支持NDND,输入在有/无专家时分别为[E,N1]/[N1][E, N1]/[N1]

  • deq_scale2 (Tensor):可选参数,量化参数,第二组matmul的反量化缩放系数,数据类型支持int64float32bfloat16,数据格式支持NDND,输入在有/无专家时分别为[E,N2]/[N2][E, N2]/[N2]

  • antiquant_scale1 (Tensor):可选参数,伪量化参数,第一组matmul的缩放系数,数据类型支持float16bfloat16,数据格式支持NDND,per-channel下输入在有/无专家时分别为[E,N1]/[N1][E, N1]/[N1]

  • antiquant_scale2 (Tensor):可选参数,伪量化参数,第二组matmul的缩放系数,数据类型支持float16bfloat16,数据格式支持NDND,per-channel下输入在有/无专家时分别为[E,N2]/[N2][E, N2]/[N2]

  • antiquant_offset1 (Tensor):可选参数,伪量化参数,第一组matmul的偏移量,数据类型支持float16bfloat16,数据格式支持NDND,per-channel下输入在有/无专家时分别为[E,N1]/[N1][E, N1]/[N1]

  • antiquant_offset2 (Tensor):可选参数,伪量化参数,第二组matmul的偏移量,数据类型支持float16bfloat16,数据格式支持NDND,per-channel下输入在有/无专家时分别为[E,N2]/[N2][E, N2]/[N2]

  • inner_precise (int):可选参数,表示高精度或者高性能选择。数据类型支持int64。该参数仅对float16生效,bfloat16int8不区分高精度和高性能。

    • inner_precise为0时,代表开启高精度模式,算子内部采用float32数据类型计算。
    • inner_precise为1时,代表高性能模式。

    inner_precise参数在bfloat16非量化场景,只能配置为0;float16非量化场景,可以配置为0或者1;量化或者伪量化场景,0和1都可配置,但是配置后不生效。

  • output_dtype (ScalarType):可选参数,该参数只在量化场景生效,其他场景不生效。表示输出Tensor的数据类型,支持输入float16bfloat16。默认值为None,代表输出Tensor数据类型为float16

返回值

Tensor

一个Tensor类型的输出,公式中的输出yy,数据类型支持float16bfloat16,数据格式支持NDND,输出维度与x一致。

约束说明

  • 该接口支持推理场景下使用。

  • 该接口支持图模式(PyTorch 2.1版本)。

  • 有专家时,专家数据的总数需要与xMM保持一致。

  • 激活层为geglu/swiglu/reglu时,仅支持无专家分组时的float16高性能场景(float16场景指类型为Tensor的必选参数数据类型都为float16的场景),且N1=2K2N1=2*K2

  • 激活层为gelu/fastgelu/relu/silu时,支持有专家或无专家分组的float16高精度及高性能场景,bfloat16场景,量化场景及伪量化场景,且N1=K2N1=K2

  • 所有场景下需满足K1=N2K1<65536K2<65536MK1=N2、K1<65536、K2<65536、M轴在32Byte对齐后小于int32的最大值。

  • 非量化场景不能输入量化参数和伪量化参数,量化场景不能输入伪量化参数,伪量化场景不能输入量化参数。

  • 量化场景参数类型:xint8weightint8biasint32scalefloat32offsetfloat32,其余参数类型根据y不同分两种情况:

    • yfloat16deq_scale支持数据类型uint64int64float32
    • ybfloat16deq_scale支持数据类型bfloat16
    • 要求deq_scale1deq_scale2的数据类型保持一致。
  • 量化场景支持scale的per-channel模式参数类型:xint8weightint8biasint32scalefloat32offsetfloat32,其余参数类型根据y不同分两种情况:

    • yfloat16deq_scale支持数据类型uint64int64
    • ybfloat16deq_scale支持数据类型bfloat16
    • 要求deq_scale1deq_scale2的数据类型保持一致。
  • 伪量化场景支持两种不同参数类型:

    • yfloat16xfloat16biasfloat16antiquant_scalefloat16antiquant_offsetfloat16weight支持数据类型int8
    • ybfloat16xbfloat16biasfloat32antiquant_scalebfloat16antiquant_offsetbfloat16weight支持数据类型int8

支持的型号

  • [object Object]Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件[object Object]
  • [object Object]Atlas A3 训练系列产品/Atlas A3 推理系列产品[object Object]

调用示例

  • 单算子模式调用

    [object Object]
  • 图模式调用

    [object Object]