ScaledMaskedSoftmax & ScaledMaskedSoftmaxGrad
算子基础信息
| 
          算子名称  | 
        
          ScaledMaskedSoftmax & ScaledMaskedSoftmaxGrad  | 
       
|---|---|
| 
          torch_npu api接口  | 
        
          torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)  | 
       
| 
          支持的torch_npu版本  | 
        
          2.1.0, 2.3.1, 2.4.0  | 
       
| 
          支持的芯片类型  | 
        
          | 
       
| 
          支持的数据类型  | 
        
          float16, bfloat16, float  | 
       
算子IR及torch_npu接口参数
           1 2 3 4 5 6 7  | 
          
           REG_OP(ScaledMaskedSoftmax) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_BF16})) .OPTIONAL_INPUT(mask, TensorType({DT_BOOL, DT_UINT1})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_BF16})) .ATTR(scale, Float, 1.0) .ATTR(fixed_triu_mask, Bool, false) .OP_END_FACTORY_REG(ScaledMaskedSoftmax)  | 
         
torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)
| 
           名称  | 
         
           类型  | 
         
           Dtype  | 
         
           Shape要求  | 
         
           默认值  | 
        
|---|---|---|---|---|
| 
           x  | 
         
           输入  | 
         
           bfloat16, float16, float32  | 
         
           必须为4维,且后两维都需要在[32, 4096]范围内,且能被32整除  | 
         
           -  | 
        
| 
           mask  | 
         
           输入  | 
         
           bool  | 
         
           必须为4维,且后两维和x一致,且能被广播成x的shape  | 
         
           -  | 
        
| 
           scale  | 
         
           属性  | 
         
           float  | 
         
           对输入x缩放  | 
         
           1.0  | 
        
| 
           fixed_triu_mask  | 
         
           属性  | 
         
           bool  | 
         
           是否生成可用的上三角bool掩码  | 
         
           False  | 
        
模型中替换代码及算子计算逻辑
           1 2 3 4 5 6 7 8 9 10 11 12 13  | 
          
           if self.input_in_float16 and self.softmax_in_fp32: input = input.float() if self.scale is not None: input = input * self.scale mask_output = self.mask_func(input, mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: probs = probs.half() else: probs = probs.bfloat16()  | 
         
           1
            | 
          
           probs = torch_npu.npu_scaled_masked_softmax(input , mask, self.scale, fixed_triu_mask )  | 
         
           1 2 3  | 
          
           if fixed_triu_mask: mask = torch.triu(mask.shape, diagonal=1) y = torch.softmax((x * scale).masked_fill(mask, -inf), dim=-1)  | 
         
    算子替换的模型中小算子

使用限制
- 输入x的shape限制如下:
      
- 必须为4维
 - 第三维的取值需要在[32, 4096]范围内
 - 第四维的取值需要在[32, 4096]范围内
 - 第三维的取值需要能被32整除
 - 第四维的取值需要能被32整除
 
 - 输入mask的shape限制如下:
      
- 必须为4维
 - 后两维必须与x的后两维相等
 - 前两维需要能被广播成x的前两维
 
 
已支持模型典型case
| 
          id  | 
        
          x  | 
        
          mask  | 
       
|---|---|---|
| 
          1  | 
        
          [1, 8, 4096, 4096]  | 
        
          [1, 1, 4096, 4096]  | 
       
| 
          2  | 
        
          [4, 32, 2048, 2048]  | 
        
          [4, 1, 2048, 2048]  | 
       
| 
          3  | 
        
          [8, 16, 512, 2048]  | 
        
          [8, 16, 512, 2048]  | 
       
| 
          4  | 
        
          [8, 16, 512, 1536]  | 
        
          [8, 16, 512, 1536]  | 
       
| 
          5  | 
        
          [8, 16, 512, 1024]  | 
        
          [8, 16, 512, 1024]  | 
       
| 
          6  | 
        
          [8, 16, 512, 512]  | 
        
          [8, 16, 512, 512]  | 
       
| 
          7  | 
        
          [8, 16, 512, 256]  | 
        
          [8, 16, 512, 256]  | 
       
| 
          8  | 
        
          [4, 4, 2048, 2048]  | 
        
          [4, 4, 2048, 2048]  |