| Operator name | SwiGlu |
|---|---|
| torch_npu API | `torch_npu.npu_swiglu(x, dim)` |
| Supported torch_npu versions | 1.11, 2.0, 2.1 |
| Supported Ascend products | Atlas A2 training series products |
| Supported data types | float16, bfloat16, float |
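As a quick illustration, here is a minimal call sketch, assuming torch_npu is installed and an NPU device is available (`.npu()` moves the tensor to the device; the size of the split dimension must be even, since the input is split into two halves along `dim`):

```python
import torch
import torch_npu

x = torch.randn(2, 32, dtype=torch.float16).npu()
y = torch_npu.npu_swiglu(x, dim=-1)
print(y.shape)  # torch.Size([2, 16]) -- the split dimension is halved
```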
```cpp
REG_OP(SwiGlu)
    .INPUT(x, "T")
    .OUTPUT(y, "T")
    .DATATYPE(T, TensorType({DT_BF16, DT_FLOAT16, DT_FLOAT}))
    .ATTR(dim, Int, -1)
    .OP_END_FACTORY_REG(SwiGlu)
```
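For reference, SwiGlu splits the input tensor into two halves along `dim` and gates the second half with the SiLU of the first:

$$
\mathrm{SwiGlu}(x) = \mathrm{SiLU}(x_1) \odot x_2,\qquad (x_1, x_2) = \mathrm{chunk}(x, 2, dim),\qquad \mathrm{SiLU}(t) = t \cdot \sigma(t)
$$

The native PyTorch implementation below follows this definition.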
```python
import torch


class SwiGlu(torch.nn.Module):
    def __init__(self, dim: int = -1):
        """
        Initialize the SwiGlu.

        Args:
            dim (int, optional): The splitting dimension of the input tensor. Default is -1.

        Attributes:
            dim (int): The splitting dimension of the input tensor.
        """
        super().__init__()
        self.dim = dim

    def _swiglu(self, x):
        """
        Apply SwiGlu to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The gated tensor.
        """
        # Split into two halves along the configured dimension, then gate.
        x = torch.chunk(x, 2, self.dim)
        return torch.nn.functional.silu(x[0]) * x[1]

    def forward(self, x):
        """
        Forward pass through the SwiGlu.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying SwiGlu.
        """
        output = self._swiglu(x)
        return output
```
Replace the body of the forward function with the torch_npu API, as follows:
```python
import torch
import torch_npu


class SwiGlu(torch.nn.Module):
    def __init__(self, dim: int = -1):
        """
        Initialize the SwiGlu.

        Args:
            dim (int, optional): The splitting dimension of the input tensor. Default is -1.

        Attributes:
            dim (int): The splitting dimension of the input tensor.
        """
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """
        Forward pass through the SwiGlu.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying SwiGlu.
        """
        # The fused NPU kernel replaces the chunk + silu + mul sequence.
        return torch_npu.npu_swiglu(x, dim=self.dim)
```
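For a quick numerical sanity check, the fused module can be compared against the eager reference formula (a sketch, assuming an available NPU device and the `SwiGlu` module defined above):

```python
import torch
import torch_npu

x = torch.randn(4, 64, dtype=torch.float16).npu()
# Eager reference: split, SiLU-gate, multiply.
x1, x2 = torch.chunk(x, 2, -1)
ref = torch.nn.functional.silu(x1) * x2
out = SwiGlu(dim=-1)(x)
print(torch.allclose(out, ref, atol=1e-3, rtol=1e-3))  # expect True
```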
Currently, only Atlas A2 training series products are supported.