昇腾社区首页
中文
注册

npu_fused_infer_attention_score

支持的型号

Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件

Atlas A3 训练系列产品/Atlas A3 推理系列产品

功能说明

推理场景下,Ascend Extension for PyTorch提供的torch_npu.npu_fused_infer_attention_score(参考Ascend Extension for PyTorch 自定义API参考),适配增量和全量推理场景的FlashAttention算子,既可以支持全量计算场景(PromptFlashAttention),也可支持增量计算场景(IncreFlashAttention)。当Query矩阵的S为1,进入IncreFlashAttention分支,其余场景进入PromptFlashAttention分支。

该接口在图模式场景下,如果开启Tiling调度优化功能(config.experimental_config.tiling_schedule_optimize),模型中actual_seq_length类参数会存在Host到Device的拷贝开销,模型执行性能会下降。为此,TorchAir提供了相应的定制化接口,保障该算子Tiling调度优化效果。

本接口在Tiling下沉模式下,提供actual_seq_length类参数直接传DeviceTensor的能力。原理是actual_seq_length类参数用于Tiling分核和Kernel计算,Tiling下沉时AI CPU中的Tiling分核和AI Core中的Kernel计算均在Device侧,直接传入Device可以减少Host到Device拷贝,从而降低开销。

函数原型

npu_fused_infer_attention_score(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, Tensor? atten_mask=None, Tensor? actual_seq_lengths=None, Tensor? actual_seq_lengths_kv=None, Tensor? dequant_scale1=None, Tensor? quant_scale1=None, Tensor? dequant_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, Tensor? antiquant_scale=None, Tensor? antiquant_offset=None, Tensor? key_antiquant_scale=None, Tensor? key_antiquant_offset=None, Tensor? value_antiquant_scale=None, Tensor? value_antiquant_offset=None, Tensor? block_table=None, Tensor? query_padding_size=None, Tensor? kv_padding_size=None, Tensor? key_shared_prefix=None, Tensor? value_shared_prefix=None, Tensor? actual_shared_prefix_len=None, Tensor? query_rope=None, Tensor? key_rope=None, Tensor? key_rope_antiquant_scale=None, int num_heads=1, float scale=1.0, int pre_tokens=2147483647, int next_tokens=2147483647, str input_layout="BSH", int num_key_value_heads=0, int sparse_mode=0, int inner_precise=0, int block_size=0, int antiquant_mode=0, int key_antiquant_mode=0, int value_antiquant_mode=0, bool softmax_lse_flag=False) -> (Tensor, Tensor)

参数说明

  • actual_seq_length类参数:本接口是指actual_seq_lengths、actual_seq_lengths_kv、actual_shared_prefix_len参数。
  • 与torch_npu.npu_fused_infer_attention_score接口相比,参数区别在actual_seq_length类参数类型支持Tensor,而非int型数组。

参数

输入/输出

说明

是否必选

其他参数

输入

torch_npu.npu_fused_infer_attention_score接口同名参数要求一致。

-

actual_seq_lengths

输入

Tensor类型,代表不同Batch中query的有效Sequence Length,数据类型支持int64。

actual_seq_lengths_kv

输入

Tensor类型,代表不同Batch中key/value的有效Sequence Length,数据类型支持int64。

actual_shared_prefix_len

输入

Tensor类型,代表key_shared_prefix/value_shared_prefix的有效Sequence Length。数据类型支持int64。

返回值说明

与torch_npu.npu_fused_infer_attention_score接口返回值说明一致。

约束说明

  • 该接口只支持图模式,不支持Eager模式下调用。
  • 不支持reduce-overhead执行模式。
  • 其他约束说明与torch_npu.npu_fused_infer_attention_score接口保持一致。

调用示例

import torch
import torch_npu
import math
import torchair as tng

from torchair.configs.compiler_config import CompilerConfig
import torch._dynamo
TORCHDYNAMO_VERBOSE=1
TORCH_LOGS="+dynamo"

# 支持入图的打印宏
import logging
from torchair.core.utils import logger
logger.setLevel(logging.DEBUG)
config = CompilerConfig()
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)
from torch.library import Library, impl

# 数据生成
q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu()
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
k_prefix = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
v_prefix = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu()
actualSeqLengthkvs = [512]
actualSeqLengthkvs = torch.tensor(actualSeqLengthkvs).npu()
actualSeqLengths = [50]
actualSeqLengths = torch.tensor(actualSeqLengths).npu()
actualSeqLengthsPrefix = [50]
actualSeqLengthsPrefix = torch.tensor(actualSeqLengthsPrefix).npu()
scale = 1/math.sqrt(128.0)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self):
        return tng.ops.npu_fused_infer_attention_score(q, k, v, actual_seq_lengths = actualSeqLengths, actual_seq_lengths_kv = actualSeqLengthkvs, key_shared_prefix = k_prefix, value_shared_prefix = v_prefix, actual_shared_prefix_len = actualSeqLengthsPrefix, num_heads = 8, input_layout = "BNSD", scale=scale, pre_tokens=65535, next_tokens=65535)
def MetaInfershape():
    with torch.no_grad():
        model = Model()
        model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True)
        graph_output = model()
    print("graph output with mask:", graph_output[0], graph_output[0].shape)
if __name__ == "__main__":
    MetaInfershape()

# 执行上述代码的输出类似如下
graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
        [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
        [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
        ...,
        [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
        [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
        [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
        device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])