昇腾社区首页
EN
注册

torch_npu.npu_fast_gelu

功能描述

  • 算子功能:快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function),对输入的每个元素计算FastGelu的前向结果。
  • 计算公式
    • 公式1:

      该公式仅支持:

      • Atlas 训练系列产品
      • Atlas 推理系列产品
    • 公式2:

      该公式仅支持:

      • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
      • Atlas A3 训练系列产品/Atlas A3 推理系列产品

接口原型

1
torch_npu.npu_fast_gelu(Tensor input) -> Tensor

参数说明

input:Tensor类型,即公式中的x。数据格式支持ND,支持非连续的Tensor。输入最大支持8维。

  • Atlas 训练系列产品:数据类型支持float16、float32。
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float32、bfloat16。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、float32、bfloat16。
  • Atlas 推理系列产品:数据类型仅支持float16、float32。

输出说明

一个Tensor类型的输出,代表fast_gelu的计算结果。

约束说明

  • 该接口支持推理、训练场景下使用。
  • 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
  • input输入不能为None。

支持的型号

  • Atlas 训练系列产品
  • Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • Atlas 推理系列产品

调用示例

  • 单算子调用
    1
    2
    3
    4
    5
    6
    7
    import os
    import torch
    import torch_npu
    import numpy as np
    data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
    x = torch.from_numpy(data_var).to(torch.float32).npu()
    y = torch_npu.npu_fast_gelu(x).cpu().numpy()
    
  • 图模式调用
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    import os
    import torch
    import torch_npu
    import numpy as np
    import torch.nn as nn
    import torchair as tng
    from torchair.configs.compiler_config import CompilerConfig
    
    os.environ["ENABLE_ACLNN"] = "false"
    torch_npu.npu.set_compile_mode(jit_compile=True)
    class Network(nn.Module):
        def __init__(self):
            super(Network, self).__init__()
        def forward(self, x): 
            y = torch_npu.npu_fast_gelu(x)
            return y
            
    npu_mode = Network()
    config = CompilerConfig()
    npu_backend = tng.get_npu_backend(compiler_config=config)
    npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
    data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
    x = torch.from_numpy(data_var).to(torch.float32)
    y =npu_mode(x).cpu().numpy()