torch_npu.npu_prompt_flash_attention
功能描述
全量FA实现,实现对应公式:
接口原型
1 | torch_npu.npu_prompt_flash_attention(Tensor query, Tensor key, Tensor value, *, Tensor? pse_shift=None, padding_mask=None, Tensor? atten_mask=None, int[]? actual_seq_lengths=None, Tensor? deq_scale1=None, Tensor? quant_scale1=None, Tensor? deq_scale2=None, Tensor? quant_scale2=None, Tensor? quant_offset2=None, int num_heads=1, float scale_value=1.0, int pre_tokens=2147473647, int next_tokens=0, str input_layout="BSH", int num_key_value_heads=0, int[]? actual_seq_lengths_kv=None, int sparse_mode=0) -> Tensor |
参数说明

query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
- query:Tensor类型,公式中的输入Q,数据类型与key的数据类型需满足数据类型推导规则,即保持与key、value的数据类型一致。不支持非连续的Tensor,数据格式支持ND。
Atlas 推理系列加速卡产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int8。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int8。
- key:Tensor类型,公式中的输入K,数据类型与query的数据类型需满足数据类型推导规则,即保持与query、value的数据类型一致。不支持非连续的Tensor,数据格式支持ND。
Atlas 推理系列加速卡产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int8。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int8。
- value:Tensor类型,公式中的输入V,数据类型与query的数据类型需满足数据类型推导规则,即保持与query、key的数据类型一致。不支持非连续的Tensor,数据格式支持ND。
Atlas 推理系列加速卡产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int8。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int8。
- *:代表其之前的变量是位置相关,需要按照顺序输入,必选;之后的变量是键值对赋值的,位置无关,可选(不输入会使用默认值)。
- pse_shift:Tensor类型,可选参数。不支持非连续的Tensor,数据格式支持ND。输入shape类型需为(B, N, Q_S, KV_S)或(1, N, Q_S, KV_S),其中Q_S为query的shape中的S,KV_S为key和value的shape中的S。对于pse_shift的KV_S为非32字节对齐的场景,建议padding到32字节来提高性能,多余部分的填充值不做要求。如不使用该功能时传入None。综合约束请见约束说明。
Atlas 推理系列加速卡产品 :暂不支持该参数。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16。当pse_shift为float16时,要求query为float16或int8;当pse_shift为bfloat16时,要求query为bfloat16。在query、key、value为float16且pse_shift存在的情况下,默认走高精度模式。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16。当pse_shift为float16时,要求query为float16或int8;当pse_shift为bfloat16时,要求query为bfloat16。在query、key、value为float16且pse_shift存在的情况下,默认走高精度模式。
- padding_mask:预留参数,暂未使用,默认值为None。
- atten_mask:Tensor类型,代表下三角全为0上三角全为负无穷的倒三角mask矩阵,数据类型支持bool、int8和uint8。数据格式支持ND,不支持非连续的Tensor。如果不使用该功能传入None。通常建议shape输入(Q_S, KV_S)、(B, Q_S, KV_S)、(1, Q_S, KV_S)、(B, 1, Q_S, KV_S)、(1, 1, Q_S, KV_S),其中Q_S为query的shape中的S,KV_S为key和value的shape中的S,对于attenMask的KV_S为非32字节对齐的场景,建议padding到32字节对齐来提高性能,多余部分填充成1。综合约束请见约束说明。
- actual_seq_lengths:int类型数组,代表不同Batch中query的有效seqlen,数据类型支持int64。如果不指定seqlen可以传入None,表示和query的shape的s长度相同。限制:该入参中每个batch的有效seqlen应该不大于query中对应batch的seqlen。seqlen的传入长度为1时,每个Batch使用相同seqlen;传入长度大于等于Batch数时取seqlen的前Batch个数。其它长度不支持。
Atlas 推理系列加速卡产品 :暂不支持该参数,传入None。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :支持TND格式。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :支持TND格式。当query的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlen和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
- deq_scale1:Tensor类型,表示BMM1后面的反量化因子,支持per-tensor。数据类型支持uint64、float32,数据格式支持ND。 如不使用该功能时可传入None。
Atlas 推理系列加速卡产品 暂不支持该参数。 - quant_scale1:Tensor类型,数据类型支持float32。数据格式支持ND,表示BMM2前面的量化因子,支持per-tensor。 如不使用该功能时可传入None。
Atlas 推理系列加速卡产品 暂不支持该参数。 - deq_scale2:Tensor类型,数据类型支持uint64、float32。数据格式支持ND,表示BMM2后面的反量化因子,支持per-tensor。 如不使用该功能时可传入None。
Atlas 推理系列加速卡产品 暂不支持该参数。 - quant_scale2:Tensor类型,数据格式支持ND,表示输出的量化因子,支持per-tensor、per-channel。如不使用该功能时可传入None。
Atlas 推理系列加速卡产品 :暂不支持该参数。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float32、bfloat16。当输入为bfloat16时,同时支持float32和bfloat16 ,否则仅支持float32 。per-channel格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D(建议输出layout为BSH时,quant_scale2 shape传入(1, 1, H)或(H,);输出为BNSD时,建议传入(1, N, 1, D)或(N, D);输出为BSND时,建议传入(1, 1, N, D)或(N, D))。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float32、bfloat16。当输入为bfloat16时,同时支持float32和bfloat16 ,否则仅支持float32 。per-channel格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D(建议输出layout为BSH时,quant_scale2 shape传入(1, 1, H)或(H,);输出为BNSD时,建议传入(1, N, 1, D)或(N, D);输出为BSND时,建议传入(1, 1, N, D)或(N, D))。
- quant_offset2:Tensor类型,数据格式支持ND,表示输出的量化偏移,支持per-tensor、per-channel。若传入quant_offset2,需保证其类型和shape信息与quant_scale2一致。如不使用该功能时可传入None。
Atlas 推理系列加速卡产品 :暂不支持该参数。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float32、bfloat16。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float32、bfloat16。
- num_heads:int类型数组,代表query的head个数,数据类型支持int64。
- scale_value:浮点型,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持float。数据类型与query的数据类型需满足数据类型推导规则。用户不特意指定时可传入默认值1.0。
- pre_tokens:int类型,用于稀疏计算,表示attention需要和前几个Token计算关联,数据类型支持int64。用户不特意指定时可传入默认值2147483647。
Atlas 推理系列加速卡产品 仅支持默认值2147483647。 - next_tokens:int类型,用于稀疏计算,表示attention需要和后几个Token计算关联。数据类型支持int64。用户不特意指定时可传入默认值0。
Atlas 推理系列加速卡产品 仅支持0和2147483647。 - input_layout:字符串类型,用于标识输入query、key、value的数据排布格式,当前支持BSH、BSND、BNSD、BNSD、BNSD_BSND(输入为BNSD时,输出格式为BSND)、TND(不支持pse、全量化、后量化)。用户不特意指定时可传入默认值"BSH"。
- num_key_value_heads:int类型,代表key、value中head个数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,数据类型支持int64。用户不特意指定时可传入默认值0,表示key/value和query的head个数相等。限制:需要满足num_heads整除num_key_value_heads,num_heads与num_key_value_heads的比值不能大于64,且在BSND、BNSD、BNSD_BSND场景下,需要与shape中的key/value的N轴shape值相同,否则报错。
Atlas 推理系列加速卡产品 仅支持默认值0。 - actual_seq_lengths_kv:int类型数组,代表不同batch中key/value的有效seqlenKV。数据类型支持int64。限制:该入参中每个batch的有效seqlenKV应该不大于key/value中对应batch的seqlenKV。seqlenKV的传入长度为1时,每个Batch使用相同seqlenKV;传入长度大于等于Batch数时取seqlenKV的前Batch个数,其它长度不支持。
Atlas 推理系列加速卡产品 :暂不支持该参数,传入None。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :支持TND格式。当key/value的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值Atlas A3 训练系列产品/Atlas A3 推理系列产品 :支持TND格式。当key/value的input_layout为TND时,该入参必须传入,且以该入参元素的数量作为Batch值。该入参中每个元素的值表示当前Batch与之前所有Batch的seqlenKV和,因此后一个元素的值必须大于等于前一个元素的值,且不能出现负值。
- sparse_mode:int类型,表示sparse的模式,数据类型支持int64。
Atlas 推理系列加速卡产品 仅支持默认值0。- sparse_mode为0时,代表defaultMask模式,如果atten_mask未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的atten_mask矩阵(S1 * S2),表示pre_tokens和next_tokens之间的部分需要计算。
- sparse_mode为1时,代表allMask。
- sparse_mode为2时,代表leftUpCausal模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为3时,代表rightDownCausal模式的mask,均对应以左顶点为划分的下三角场景,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为4时,代表band模式的mask,需要传入优化后的atten_mask矩阵(2048*2048)。
- sparse_mode为5、6、7、8时,分别代表prefix、global、dilated、block_local,均暂不支持。用户不特意指定时可传入默认值0。综合约束请见约束说明。
输出说明
atten_out:Tensor类型,计算的最终结果,shape与query保持一致。
Atlas 推理系列加速卡产品 :数据类型支持float16。Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 :数据类型支持float16、bfloat16、int8。Atlas A3 训练系列产品/Atlas A3 推理系列产品 :数据类型支持float16、bfloat16、int8。
约束说明
- 该接口仅在推理场景下使用。
- 该接口支持图模式(目前仅支持PyTorch 2.1版本)。
- 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
- 入参为空的处理:算子内部需要判断参数query是否为空,如果是空则直接返回。参数query不为空Tensor,参数key、value为空tensor(即S2为0),则填充全零的对应shape的输出(填充attention_out)。attention_out为空Tensor时,AscendCLNN框架会处理。
- query、key、value输入,功能使用限制如下:
产品型号
轴约束
Atlas 推理系列加速卡产品 - 支持B轴小于等于128。
- 支持N轴小于等于256。
- 支持S轴小于等于65535(64k)。
- 支持D轴小于等于512。
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 Atlas A3 训练系列产品/Atlas A3 推理系列产品 - 支持B轴小于等于65536(64k),D轴32byte不对齐时仅支持到128。
- 支持N轴小于等于256。
- S支持小于等于20971520(20M)。长序列场景下,如果计算量过大可能会导致PFA算子执行超时(aicore error类型报错,errorStr为timeout or trap error),此场景下建议做S切分处理,注:这里计算量会受B、S、N、D等的影响,值越大计算量越大。典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
- B=1,Q_N=20,Q_S=1048576,D = 256,KV_N=1,KV_S=1048576。
- B=1,Q_N=2,Q_S=10485760,D = 256,KV_N=2,KV_S=10485760。
- B=20,Q_N=1,Q_S=1048576,D = 256,KV_N=1,KV_S=1048576。
- B=1,Q_N=10,Q_S=1048576,D = 512,KV_N=1,KV_S=1048576。
- 支持D轴小于等于512。input_layout为BSH或者BSND时,要求N*D小于65535。
- TND场景下query,key,value输入的综合限制:
- B=1,Q_N=20,Q_S=1048576,D = 256,KV_N=1,KV_S=1048576。
- T小于等于65536;
- N等于8/16/32/64/128,且Q_N、K_N、V_N相等;
- Q_D、K_D等于192,V_D等于128/192;
- 数据类型仅支持BFLOAT16;
- sparse模式仅支持sparse=0且不传mask,或sparse=3且传入mask;
- 当sparse=3时,要求每个batch单独的actualSeqLengths < actualSeqLengthsKv。
- 参数sparse_mode当前仅支持值为0、1、2、3、4的场景,取其它值时会报错。
- sparse_mode=0时,atten_mask如果为None,则忽略入参pre_tokens、next_tokens(内部赋值为INT_MAX)。
- sparse_mode=2、3、4时,atten_mask的shape需要为(S, S)或(1, S, S)或(1, 1, S, S),其中S的值需要固定为2048,且需要用户保证传入的atten_mask为下三角,不传入atten_mask或者传入的shape不正确报错。
- sparse_mode=1、2、3的场景忽略入参pre_tokens、next_tokens并按照相关规则赋值。
- int8量化相关入参数量与输入、输出数据格式的综合限制:
- 输入为int8,输出为int8的场景:入参deq_scale1、quant_scale1、deq_scale2、quant_scale2需要同时存在,quant_offset2可选,不传时默认为0。
- 输入为int8,输出为float16的场景:入参deq_scale1、quant_scale1、deq_scale2需要同时存在,若存在入参quant_offset2或quant_scale2(即不为None),则报错并返回。
- 输入为float16或bfloat16,输出为int8的场景:入参quant_scale2需存在,quant_offset2可选,不传时默认为0,若存在入参deq_scale1或quant_scale1或deq_scale2(即不为None),则报错并返回。
- 入参quant_offset2和quant_scale2支持per-tensor/per-channel两种格式和float32/bfloat16两种数据类型。若传入quant_offset2,需保证其类型和shape信息与quant_scale2一致。当输入为bfloat16时,同时支持float32和bfloat16,否则仅支持float32。per-channel格式,当输出layout为BSH时,要求quant_scale2所有维度的乘积等于H;其他layout要求乘积等于N*D。当输出layout为BSH时,quant_scale2 shape建议传入(1, 1, H)或(H,);当输出为BNSD时,建议传入(1, N, 1, D)或(N, D);当输出为BSND时,建议传入(1, 1, N, D)或(N, D)。per-tensor格式,建议D轴对齐到32Byte。
- per-channel格式,入参quant_scale2和quant_offset2暂不支持左padding、Ring Attention或者D非32Byte对齐的场景。
- 输出为int8时,暂不支持sparse为band且pre_tokens/next_tokens为负数。
- pse_shift功能使用限制如下:
- 支持query数据类型为float16或bfloat16或int8场景下使用该功能。
- query,key,value数据类型为float16且pse_shift存在时,强制走高精度模式,对应的限制继承自高精度模式的限制。
- Q_S需大于等于query的S长度,KV_S需大于等于key的S长度。
- 输出为int8,入参quant_offset2传入非None和非空tensor值,并且sparse_mode、pre_tokens和next_tokens满足以下条件,矩阵会存在某几行不参与计算的情况,导致计算结果误差,该场景会拦截:
- sparseMode=0,atten_mask如果非None,每个batch actual_seq_lengths-actual_seq_lengths_kv-pre_tokens>0或nextTokens<0时,满足拦截条件。
- sparseMode=1或2,不会出现满足拦截条件的情况。
- sparseMode=3,每个batch actual_seq_lengths_kv- actual_seq_lengths<0,满足拦截条件。
- sparseMode= 4,preTokens<0或每个batch next_tokens+actual_seq_lengths_kv-actual_seq_lengths<0时,满足拦截条件。
- kv伪量化参数分离当前暂不支持。
- 暂不支持D不对齐场景。
支持的型号
Atlas A2 训练系列产品/Atlas 800I A2 推理产品/A200I A2 Box 异构组件 Atlas A3 训练系列产品/Atlas A3 推理系列产品 Atlas 推理系列加速卡产品
调用示例
- 单算子调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
import torch import torch_npu import math # 生成随机数据,并发送到npu q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu() k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() scale = 1/math.sqrt(128.0) actseqlen = [164] actseqlenkv = [1024] # 调用PFA算子 out = torch_npu.npu_prompt_flash_attention(q, k, v, actual_seq_lengths = actseqlen, actual_seq_lengths_kv = actseqlenkv, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535) # 执行上述代码的输出类似如下 tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16)
- 图模式调用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
# 入图方式 import torch import torch_npu import math import torchair as tng from torchair.configs.compiler_config import CompilerConfig import torch._dynamo TORCHDYNAMO_VERBOSE=1 TORCH_LOGS="+dynamo" # 支持入图的打印宏 import logging from torchair.core.utils import logger logger.setLevel(logging.DEBUG) config = CompilerConfig() config.debug.graph_dump.type = "pbtxt" npu_backend = tng.get_npu_backend(compiler_config=config) from torch.library import Library, impl # 数据生成 q = torch.randn(1, 8, 164, 128, dtype=torch.float16).npu() k = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() v = torch.randn(1, 8, 1024, 128, dtype=torch.float16).npu() scale = 1/math.sqrt(128.0) class Model(torch.nn.Module): def __init__(self): super().__init__() def forward(self): return torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535) def MetaInfershape(): with torch.no_grad(): model = Model() model = torch.compile(model, backend=npu_backend, dynamic=False, fullgraph=True) graph_output = model() single_op = torch_npu.npu_prompt_flash_attention(q, k, v, num_heads = 8, input_layout = "BNSD", scale_value=scale, pre_tokens=65535, next_tokens=65535) print("single op output with mask:", single_op, single_op.shape) print("graph output with mask:", graph_output, graph_output.shape) if __name__ == "__main__": MetaInfershape() # 执行上述代码的输出类似如下 single op output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128]) graph output with mask: tensor([[[[ 0.0219, 0.0201, 0.0049, ..., 0.0118, -0.0011, -0.0140], [ 0.0294, 0.0256, -0.0081, ..., 0.0267, 0.0067, -0.0117], [ 0.0285, 0.0296, 0.0011, ..., 0.0150, 0.0056, -0.0062], ..., [ 0.0177, 0.0194, -0.0060, ..., 0.0226, 0.0029, -0.0039], [ 0.0180, 0.0186, -0.0067, ..., 0.0204, -0.0045, -0.0164], [ 0.0176, 0.0288, -0.0091, ..., 0.0304, 0.0033, -0.0173]]]], device='npu:0', dtype=torch.float16) torch.Size([1, 8, 164, 128])
父主题: torch_npu