torch_npu.npu_prompt_flash_attention

功能描述

全量FA实现,实现对应公式:

接口原型

torch_npu.npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, padding_mask=None, Tensor? atten_mask=None, int[]? actual_seq_lengths=None, Tensor? deq_scale1=None, Tensor? quant_scale1=None, Tensor? deq_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147473647, int next_tokens=0, str input_layout="BSH", int num_key_value_heads=0, int[]? actual_seq_lengths_kv=None, int sparse_mode=0) -> Tensor

参数说明

输出说明

共一个输出,为计算的最终结果atten_out,类型为Tensor,shape与query保持一致。

约束说明

支持的型号

调用示例