功能描述
- 算子功能:快速高斯误差线性单元激活函数(Fast Gaussian Error Linear Units activation function),对输入的每个元素计算FastGelu的前向结果。
- 计算公式
- 公式1:

该公式仅支持:
- 公式2:

该公式仅支持:
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
- Atlas A3 训练系列产品/Atlas A3 推理系列产品
接口原型
| torch_npu.npu_fast_gelu(Tensor input) -> Tensor
|
参数说明
input:Tensor类型,即公式中的x。数据格式支持ND,支持非连续的Tensor。输入最大支持8维。
- Atlas 训练系列产品:数据类型支持float16、float32。
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件:数据类型支持float16、float32、bfloat16。
- Atlas A3 训练系列产品/Atlas A3 推理系列产品:数据类型支持float16、float32、bfloat16。
- Atlas 推理系列产品:数据类型仅支持float16、float32。
输出说明
一个Tensor类型的输出,代表fast_gelu的计算结果。
约束说明
- 该接口支持推理、训练场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- input输入不能为None。
支持的型号
- Atlas 训练系列产品
- Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件
- Atlas A3 训练系列产品/Atlas A3 推理系列产品
- Atlas 推理系列产品
调用示例
- 单算子调用
| import os
import torch
import torch_npu
import numpy as np
data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
x = torch.from_numpy(data_var).to(torch.float32).npu()
y = torch_npu.npu_fast_gelu(x).cpu().numpy()
|
- 图模式调用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 | import os
import torch
import torch_npu
import numpy as np
import torch.nn as nn
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
os.environ["ENABLE_ACLNN"] = "false"
torch_npu.npu.set_compile_mode(jit_compile=True)
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
def forward(self, x):
y = torch_npu.npu_fast_gelu(x)
return y
npu_mode = Network()
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)
npu_mode = torch.compile(npu_mode, fullgraph=True, backend=npu_backend, dynamic=False)
data_var = np.random.uniform(0, 1, [4, 2048, 16, 128]).astype(np.float32)
x = torch.from_numpy(data_var).to(torch.float32)
y =npu_mode(x).cpu().numpy()
|