RmsNorm & RmsNormGrad
Operator basic information

| Operator name | RmsNorm & RmsNormGrad |
|---|---|
| torch_npu API | torch_npu.npu_rms_norm(x, gamma, epsilon) |
| Supported torch_npu versions | 2.1.0, 2.3.1, 2.4.0, 2.5.1 |
| Supported chip types | |
| Supported data types | float16, bfloat16, float |
Operator IR and torch_npu interface parameters
Operator IR:

```cpp
REG_OP(RmsNorm)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
    .INPUT(gamma, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF16}))
    .OUTPUT(rstd, TensorType({DT_FLOAT, DT_FLOAT, DT_FLOAT}))
    .ATTR(epsilon, Float, 1e-6)
    .OP_END_FACTORY_REG(RmsNorm)
```
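The IR declares two outputs: y, whose dtype follows x, and rstd, which is always DT_FLOAT. The following is a minimal CPU sketch of that contract in plain PyTorch, useful as a reference when checking NPU results. The function name `rms_norm_reference` is hypothetical, and two behaviours are assumptions rather than documented facts: the statistics are reduced over the trailing dimensions that gamma covers, and accumulation happens in float32.

```python
import torch


def rms_norm_reference(x: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-6):
    """Hypothetical CPU reference returning (y, rstd) like the RmsNorm IR.

    Assumptions: reduce over the trailing dims spanned by gamma and
    accumulate in float32; neither is confirmed by the operator docs.
    """
    # Dimensions covered by gamma, e.g. (-1,) for a 1-D gamma
    dims = tuple(range(-gamma.dim(), 0))
    x32 = x.float()
    # rstd = 1 / sqrt(mean(x^2) + epsilon), kept in float32 like the IR's rstd
    rstd = torch.rsqrt(x32.pow(2).mean(dim=dims, keepdim=True) + epsilon)
    # y is cast back to the input dtype, like the IR's y output
    y = (x32 * rstd * gamma.float()).to(x.dtype)
    return y, rstd
```

With a 1-D gamma of shape (8,) and x of shape (2, 3, 8), rstd has shape (2, 3, 1), so it broadcasts directly against x in the backward pass.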
torch_npu interface:

```
torch_npu.npu_rms_norm(Tensor self, Tensor gamma, float epsilon=1e-06) -> (Tensor, Tensor)
```
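Parameters:

- x: Tensor (named `self` in the signature above). Supports 1 to 8 dimensions.
- gamma: Tensor, usually the layer's weight. Its shape must match the trailing dimensions of x.
- epsilon: float, added to the denominator to prevent division by zero. Defaults to 1e-06.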
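The constraints above can be checked up front before dispatching to the NPU. The sketch below is a hypothetical helper (`check_rms_norm_args` is not part of torch_npu) that encodes the parameter list and the data types from the basic-information table; requiring a strictly positive epsilon is this sketch's own rule, derived from epsilon's stated purpose.

```python
import torch

# Data types listed in the operator's basic-information table
SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}


def check_rms_norm_args(x: torch.Tensor, gamma: torch.Tensor, epsilon: float) -> None:
    """Hypothetical validator: raise ValueError on documented-constraint violations."""
    # x supports 1 to 8 dimensions
    if not 1 <= x.dim() <= 8:
        raise ValueError(f"x must have 1-8 dimensions, got {x.dim()}")
    if x.dtype not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported dtype {x.dtype}")
    # gamma must equal the trailing dimensions of x
    trailing = tuple(x.shape[x.dim() - gamma.dim():]) if gamma.dim() <= x.dim() else None
    if trailing != tuple(gamma.shape):
        raise ValueError(
            f"gamma shape {tuple(gamma.shape)} must match the trailing dims of x {tuple(x.shape)}"
        )
    # Sketch's own rule: epsilon must be positive to guard the division
    if epsilon <= 0:
        raise ValueError("epsilon must be positive")
```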
 
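Outputs:

- The first output is a Tensor: y, the final result of the RmsNorm computation.
- The second output is a Tensor: rstd, the intermediate result of rms_norm (the reciprocal of the root mean square of x), saved for the backward computation.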
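Written out, the two outputs correspond to the following, where $d$ is the size of the normalized last dimension, $\epsilon$ is epsilon, and $\gamma$ is gamma. This matches the custom class in the next section; if gamma spans several trailing dimensions, the sum is assumed to run over all of them.

$$
\mathrm{rstd} = \frac{1}{\sqrt{\dfrac{1}{d}\sum_{j=1}^{d} x_j^{2} + \epsilon}}, \qquad
y_i = x_i \cdot \mathrm{rstd} \cdot \gamma_i
$$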
 
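Replacement code in models and operator computation logic
     RmsNorm appears in LLMs such as LLaMA, LLaMA2, and Baichuan. Because older PyTorch releases provide no RmsNorm interface (torch.nn.RMSNorm first appeared in PyTorch 2.4), models usually implement it as a custom class whose forward function defines the computation, for example: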
```python
import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```
     Replace it with:

```python
import torch
import torch_npu
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        # npu_rms_norm returns (y, rstd); the forward pass only needs y
        return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
```
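When the same model code must also run off-NPU (for example in CPU unit tests), one option is to keep both paths in a single module. The sketch below assumes that falling back to the original eager computation is acceptable whenever torch_npu cannot be imported or the input is not on an `npu` device; that dispatch policy is an illustrative choice, not something the operator documentation prescribes.

```python
import torch
from torch import nn

try:
    import torch_npu  # present only in Ascend environments
except ImportError:
    torch_npu = None


class RMSNorm(nn.Module):
    """RMSNorm that uses the fused NPU kernel when possible (illustrative sketch)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch_npu is not None and x.device.type == "npu":
            # Fused path: npu_rms_norm returns (y, rstd); keep y only
            return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.eps)[0]
        # Eager fallback: same math as the original LLaMA-style class
        x32 = x.float()
        normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight
```

On a CPU tensor this takes the eager branch, so the module can be exercised without Ascend hardware.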
     Figure 1 Computation flow
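     After fusion the operator has one extra output, rstd. It is an intermediate result of the computation and is passed as an input to the backward operator. It corresponds to this expression from the original class:

```python
torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
```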
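Saving rstd lets the backward pass (RmsNormGrad) skip recomputing the reduction. The kernel's internals are not described here, so the sketch below is only the standard analytic RMSNorm derivative expressed in terms of rstd, checked against PyTorch autograd on CPU. The name `rms_norm_backward_reference` is hypothetical, and it assumes normalization over the last dimension with a 1-D gamma.

```python
import torch


def rms_norm_backward_reference(dy, x, rstd, gamma):
    """Hypothetical reference gradients of y = x * rstd * gamma.

    Assumes normalization over the last dim and a 1-D gamma; rstd is the
    saved forward output with shape (..., 1).
    """
    g = dy * gamma  # upstream gradient scaled by gamma
    # dx = rstd * (g - x * rstd^2 * mean(g * x)), mean over the normalized dim
    dx = rstd * (g - x * rstd.pow(2) * (g * x).mean(-1, keepdim=True))
    # dgamma sums dy * (x * rstd) over every leading (batch) dimension
    dgamma = (dy * x * rstd).reshape(-1, x.shape[-1]).sum(0)
    return dx, dgamma
```

The second term of dx comes from rstd itself depending on x, which is why the forward output is reused rather than recomputed.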
Small operators in the model replaced by the fused operator

Usage restrictions