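全量FA(Prompt Flash Attention)实现,实现对应公式:
$$atten\_out = \mathrm{Softmax}\left(scale\_value \cdot query \cdot key^{T} + atten\_mask\right) \cdot value$$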
torch_npu.npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, padding_mask=None, Tensor? atten_mask=None, int[]? actual_seq_lengths=None, Tensor? deq_scale1=None, Tensor? quant_scale1=None, Tensor? deq_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147473647, int next_tokens=0, str input_layout="BSH", int num_key_value_heads=0, int[]? actual_seq_lengths_kv=None, int sparse_mode=0) -> Tensor
quant_scale2:当输入为BFLOAT16时,数据类型同时支持FLOAT32和BFLOAT16,否则仅支持FLOAT32。支持per-channel格式:当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D(建议输出layout为BSH时,quant_scale2的shape传入[1,1,H]或[H];输出为BNSD时,建议传入[1,N,1,D]或[N,D];输出为BSND时,建议传入[1,1,N,D]或[N,D])。
共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。
PyTorch 2.1
# Single-operator (eager) invocation
import math

import torch
import torch_npu

# Example dimensions: batch, number of heads, query / key-value sequence
# lengths and head dimension. The tensors below are laid out as BNSD
# (batch, heads, sequence, head_dim), matching input_layout="BNSD".
BATCH, NUM_HEADS, SEQ_Q, SEQ_KV, HEAD_DIM = 1, 8, 164, 1024, 128

# Generate random data and move it to the NPU.
q = torch.randn(BATCH, NUM_HEADS, SEQ_Q, HEAD_DIM, dtype=torch.float16).npu()
k = torch.randn(BATCH, NUM_HEADS, SEQ_KV, HEAD_DIM, dtype=torch.float16).npu()
v = torch.randn(BATCH, NUM_HEADS, SEQ_KV, HEAD_DIM, dtype=torch.float16).npu()

# 1/sqrt(head_dim) scaling, derived from the tensor shape so that changing
# HEAD_DIM keeps it consistent.
scale = 1 / math.sqrt(HEAD_DIM)

# Actual sequence lengths, one entry per batch item; here every sequence is
# used at its full length.
actseqlen = [SEQ_Q] * BATCH
actseqlenkv = [SEQ_KV] * BATCH

# Call the PFA operator.
out = torch_npu.npu_prompt_flash_attention(
    q,
    k,
    v,
    actual_seq_lengths=actseqlen,
    actual_seq_lengths_kv=actseqlenkv,
    num_heads=NUM_HEADS,
    input_layout="BNSD",
    scale_value=scale,
    pre_tokens=65535,
    next_tokens=65535,
)
print(out)

# The output of the code above is similar to:
# tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
#           [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
#           [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
#           ...,
#           [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
#           [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
#           [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
#        device='npu:0', dtype=torch.float16)
# Graph-mode (torch.compile + torchair) invocation
import logging
import math

import torch
import torch_npu
import torchair as tng
from torchair.configs.compiler_config import CompilerConfig
from torchair.core.utils import logger
import torch._dynamo

# Optional Dynamo debug logging. torch reads these *environment variables*
# when it is imported, so they must be exported in the shell before starting
# Python. Assigning Python variables of the same name (e.g.
# `TORCH_LOGS="+dynamo"`) inside the script has no effect.
#   export TORCHDYNAMO_VERBOSE=1
#   export TORCH_LOGS="+dynamo"

# Print torchair's debug messages in graph mode.
logger.setLevel(logging.DEBUG)

config = CompilerConfig()
config.aoe_config.aoe_mode = "2"
# Dump the compiled graph in pbtxt format.
config.debug.graph_dump.type = "pbtxt"
npu_backend = tng.get_npu_backend(compiler_config=config)

# Example dimensions (BNSD layout: batch, heads, sequence, head_dim).
BATCH, NUM_HEADS, SEQ_Q, SEQ_KV, HEAD_DIM = 1, 8, 164, 1024, 128

# Data generation
q = torch.randn(BATCH, NUM_HEADS, SEQ_Q, HEAD_DIM, dtype=torch.float16).npu()
k = torch.randn(BATCH, NUM_HEADS, SEQ_KV, HEAD_DIM, dtype=torch.float16).npu()
v = torch.randn(BATCH, NUM_HEADS, SEQ_KV, HEAD_DIM, dtype=torch.float16).npu()
scale = 1 / math.sqrt(HEAD_DIM)


def _run_pfa():
    """Run PFA on the module-level q/k/v.

    Shared by the compiled model and the eager reference call so that both
    paths use exactly the same arguments.
    """
    return torch_npu.npu_prompt_flash_attention(
        q,
        k,
        v,
        num_heads=NUM_HEADS,
        input_layout="BNSD",
        scale_value=scale,
        pre_tokens=65535,
        next_tokens=65535,
    )


class Model(torch.nn.Module):
    """Wrap the PFA call in a module so it can be compiled with torch.compile."""

    def forward(self):
        return _run_pfa()


def MetaInfershape():
    """Compile ``Model`` with the torchair NPU backend and print its output
    next to the result of the eager single-op call."""
    with torch.no_grad():
        model = torch.compile(
            Model(), backend=npu_backend, dynamic=False, fullgraph=True
        )
        graph_output = model()
    single_op = _run_pfa()
    print("single op output with mask:", single_op, single_op.shape)
    print("graph output with mask:", graph_output, graph_output.shape)


if __name__ == "__main__":
    MetaInfershape()

# The output of the code above is similar to:
# single op output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
#           [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
#           [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
#           ...,
#           [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
#           [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
#           [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
#        device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
# graph output with mask: tensor([[[[ 0.0219,  0.0201,  0.0049,  ...,  0.0118, -0.0011, -0.0140],
#           [ 0.0294,  0.0256, -0.0081,  ...,  0.0267,  0.0067, -0.0117],
#           [ 0.0285,  0.0296,  0.0011,  ...,  0.0150,  0.0056, -0.0062],
#           ...,
#           [ 0.0177,  0.0194, -0.0060,  ...,  0.0226,  0.0029, -0.0039],
#           [ 0.0180,  0.0186, -0.0067,  ...,  0.0204, -0.0045, -0.0164],
#           [ 0.0176,  0.0288, -0.0091,  ...,  0.0304,  0.0033, -0.0173]]]],
#        device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])