SwiGlu Operator Usage Guide
Operator name | SwiGlu |
---|---|
torch_npu API | torch_npu.npu_swiglu(x, dim) |
Supported torch_npu versions | 1.11, 2.0, 2.1 |
Supported Ascend products | Atlas A2 training series products |
Supported data types | float16, bfloat16, float |
Operator IR and torch_npu Interface Parameters
- Operator IR:
    REG_OP(SwiGlu)
        .INPUT(x, "T")
        .OUTPUT(y, "T")
        .DATATYPE(T, TensorType({DT_BF16, DT_FLOAT16, DT_FLOAT}))
        .ATTR(dim, Int, -1)
        .OP_END_FACTORY_REG(SwiGlu)
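The IR declares one input `x`, one output `y` of the same type `T`, and an integer attribute `dim` that defaults to -1. As a minimal sketch of what this implies for shapes, assuming the operator splits `x` into two equal halves along `dim` (as the `torch.chunk(x, 2, -1)` reference in this guide does), the helper below computes the expected output shape. The name `swiglu_output_shape` is hypothetical and is not part of torch_npu.

```python
from typing import Sequence, Tuple


def swiglu_output_shape(shape: Sequence[int], dim: int = -1) -> Tuple[int, ...]:
    """Return the shape expected from SwiGlu for an input of `shape`.

    Assumption: the input is split into two equal halves along `dim`,
    so that dimension must be even and is halved in the output.
    """
    rank = len(shape)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for a rank-{rank} tensor")
    axis = dim % rank  # normalize negative dims such as -1
    if shape[axis] % 2 != 0:
        raise ValueError(f"size {shape[axis]} along dim {dim} must be even")
    out = list(shape)
    out[axis] //= 2
    return tuple(out)
```

Under that assumption, an input of shape (2, 8, 4096) with the default `dim=-1` would produce an output of shape (2, 8, 2048).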
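Replacement Code in the Model and Operator Computation Logic
- The SwiGlu operator is common in LLMs such as LLaMA, LLaMA2, and Baichuan. Because torch does not provide a SwiGlu operator interface, it usually appears in models as a custom class whose computation logic is defined under the forward function, for example: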
    import torch


    class SwiGlu(torch.nn.Module):
        def __init__(self, dim: int = -1):
            """
            Initialize the SwiGlu.

            Args:
                dim (int, optional): The splitting dimension of the input tensor. Default is -1.

            Attributes:
                dim (int): The splitting dimension of the input tensor.
            """
            super().__init__()
            self.dim = dim

        def _swiglu(self, x):
            """
            Apply the SwiGlu to the input tensor.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: silu(first half) * second half.
            """
            # Split x into two halves along self.dim, then gate the second half.
            x = torch.chunk(x, 2, self.dim)
            return torch.nn.functional.silu(x[0]) * x[1]

        def forward(self, x):
            """
            Forward pass through the SwiGlu.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying SwiGlu.
            """
            output = self._swiglu(x)
            return output
Replace the entire body of the forward function with a call to the torch_npu interface, as follows:
    import torch
    import torch_npu


    class SwiGlu(torch.nn.Module):
        def __init__(self, dim: int = -1):
            """
            Initialize the SwiGlu.

            Args:
                dim (int, optional): The splitting dimension of the input tensor. Default is -1.

            Attributes:
                dim (int): The splitting dimension of the input tensor.
            """
            super().__init__()
            self.dim = dim

        def forward(self, x):
            """
            Forward pass through the SwiGlu.

            Args:
                x (torch.Tensor): The input tensor.

            Returns:
                torch.Tensor: The output tensor after applying SwiGlu.
            """
            # Call the fused NPU operator with the configured split dimension.
            return torch_npu.npu_swiglu(x, dim=self.dim)
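When swapping in a fused operator like this, an independent reference for the arithmetic can help with sanity checks. The following is a minimal pure-Python sketch, assuming the same semantics as the custom module in this guide: split the input in half, apply SiLU to the first half, and multiply by the second half. The names `silu` and `swiglu_reference` are illustrative helpers, not torch_npu APIs, and the sketch works on a flat sequence of floats rather than a tensor.

```python
import math
from typing import List, Sequence


def silu(z: float) -> float:
    """SiLU (swish) activation: z * sigmoid(z), written to avoid overflow."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)  # safe for negative z
    return z * e / (1.0 + e)


def swiglu_reference(x: Sequence[float]) -> List[float]:
    """Reference SwiGlu on a 1-D sequence: silu(first half) * second half."""
    if len(x) % 2 != 0:
        raise ValueError("input length must be even to split into two halves")
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    return [silu(g) * v for g, v in zip(gate, value)]
```

One possible use is to run the NPU operator on a small tensor, flatten one row of the input, and compare the result against this reference within a tolerance suited to the data type (looser for float16 and bfloat16 than for float).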
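- The computation flow is shown in the following figure:
Figure 1 Flowchart
Small operators in the model replaced by this operator
Usage Restrictions
Currently, only Atlas A2 training series products are supported.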
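The same flow can also be written as a formula. This is a sketch based on the reference implementation in this guide (`torch.chunk` followed by `silu(x[0]) * x[1]`); the symbols x_1, x_2, and σ (the logistic sigmoid) are notation introduced here for illustration only.

```latex
\begin{aligned}
[x_1,\ x_2] &= \operatorname{chunk}(x,\ 2,\ \text{dim}) \\
\operatorname{SiLU}(z) &= z \cdot \sigma(z) = \frac{z}{1 + e^{-z}} \\
y &= \operatorname{SiLU}(x_1) \odot x_2
\end{aligned}
```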