ScaledMaskedSoftmax算子使用指南

表1 ScaledMaskedSoftmax算子基础信息

算子名称

ScaledMaskedSoftmax

torch_npu API接口

torch_npu.npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask)

支持的torch_npu版本

1.11, 2.0, 2.1

支持的昇腾产品

Atlas 200/300/500 推理产品Atlas 推理系列产品Atlas 训练系列产品Atlas A2 训练系列产品

支持的数据类型

float16,bfloat16,float32

算子IR及torch_npu接口参数

模型中替换代码及算子计算逻辑

算子替换的模型中小算子

使用限制

  1. 输入x的shape限制如下:
    1. 必须为4维。
    2. 第三维的取值需要在[32, 4096]范围内。
    3. 第四维的取值需要在[32, 4096]范围内。
    4. 第三维的取值需要能被32整除。
    5. 第四维的取值需要能被32整除。
  2. 输入mask的shape限制如下:
    1. 必须为4维。
    2. 后两维必须与x的后两维相等。
    3. 前两维需要能被广播成x的前两维。

已支持模型典型Case

如下case均包含fp16、fp32、bf16。

id

x

mask

1

[1, 8, 4096, 4096]

[1, 1, 4096, 4096]

2

[4, 32, 2048, 2048]

[4, 1, 2048, 2048]

3

[8, 16, 512, 2048]

[8, 16, 512, 2048]

4

[8, 16, 512, 1536]

[8, 16, 512, 1536]

5

[8, 16, 512, 1024]

[8, 16, 512, 1024]

6

[8, 16, 512, 512]

[8, 16, 512, 512]

7

[8, 16, 512, 256]

[8, 16, 512, 256]

8

[4, 4, 2048, 2048]

[4, 4, 2048, 2048]