Implements the fused computation of the "Transformer Attention Score". The computation formula implemented is as follows:
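A sketch of the fused computation, expressed in terms of the interface parameters below (the exact placement of pse and the masking/dropout details are inferred from the parameter list rather than reproduced from the original figure):

attention_out = Dropout(Softmax(Mask(scale * (pse + Q * K^T), atten_mask)), keep_prob) * V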
torch_npu.npu_fusion_attention(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, int inner_precise=0, int[]? prefix=None, int[]? actual_seq_qlen=None, int[]? actual_seq_kvlen=None, int sparse_mode=0, bool gen_mask_parallel=True, bool sync=False) -> (Tensor, Tensor, Tensor, Tensor, int, int, int)
Note: 0 and 1 are currently reserved configuration values. When an entire row is masked during computation and this causes a loss of precision, you can try setting this parameter to 2 to improve precision, but doing so may degrade performance.
then pass actual_seq_qlen as: 2 4 6 8 10
then pass actual_seq_kvlen as: 2 4 6 8 10
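These values are the running (cumulative) sums of the individual sequence lengths. A minimal sketch, assuming per-sequence lengths of 2 2 2 2 2 as implied by the example values above:

import numpy as np

seq_lens = [2, 2, 2, 2, 2]                        # assumed length of each sequence in the batch
actual_seq_qlen = np.cumsum(seq_lens).tolist()    # -> [2, 4, 6, 8, 10]
actual_seq_kvlen = np.cumsum(seq_lens).tolist()   # -> [2, 4, 6, 8, 10]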
atten_mask works as follows: at positions where the mask is True, the corresponding values of the product of query (Q) and the transpose of key (K) are masked out, as illustrated below:
Positions of the matrix where atten_mask is True are masked out; the effect is as follows:
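A minimal plain-PyTorch sketch of this principle (this is not the fused kernel; filling masked positions with -inf before the softmax is one common way to realize the masking, and the shapes here are illustrative only):

import torch

q = torch.randn(4, 8)     # S1 x D
k = torch.randn(4, 8)     # S2 x D
# True marks the positions to be masked (here: everything above the diagonal)
atten_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)

scores = q @ k.transpose(-1, -2)                          # Q * K^T
scores = scores.masked_fill(atten_mask, float("-inf"))    # masked positions are suppressed
probs = torch.softmax(scores, dim=-1)                     # masked positions receive zero attention weight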
atten_mask should be passed as a lower-triangular matrix, as illustrated below:
atten_mask should be passed as a band-shaped matrix, as illustrated below:
Note: when next_tockens is negative, pre_tockens must be greater than the absolute value of next_tockens.
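A band mask of this shape can be constructed the same way as in the test case at the end of this section (a sketch; the band widths 128/128 are example values, not requirements):

import numpy as np
import torch

pre_tockens, next_tockens = 128, 128
shape = [1, 8, 256, 256]                                  # B, N, S1, S2
mask_u = np.triu(np.ones(shape), k=next_tockens + 1)      # mask positions more than next_tockens to the right
mask_l = np.tril(np.ones(shape), k=-pre_tockens - 1)      # mask positions more than pre_tockens to the left
atten_mask = torch.from_numpy(mask_u + mask_l).bool()     # True = masked; the diagonal band stays visible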
In this scenario the values of next_tockens and pre_tockens are ignored (internally set to INT_MAX); the matrix is illustrated below:
The atten_mask passed in is the optimized, compressed lower-triangular matrix (2048*2048); the compressed lower-triangular matrix is illustrated below (likewise hereafter):
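Consistent with the test case at the end of this section, such a compressed 2048*2048 causal mask can be built as follows (a sketch; True marks the masked positions):

import numpy as np
import torch
import torch_npu

atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).bool().npu()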
In this scenario the values of next_tockens and pre_tockens are ignored, and the atten_mask matrix must be in BNSS or B1SS data format, as illustrated below:
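For shape reference only, a B1SS mask broadcasts one mask across all heads while BNSS carries a mask per head (a minimal sketch; the sizes are assumptions):

import torch

B, N, S1, S2 = 1, 8, 256, 256
atten_mask_b1ss = torch.zeros(B, 1, S1, S2, dtype=torch.bool)   # B1SS: one mask shared by the N heads
atten_mask_bnss = torch.zeros(B, N, S1, S2, dtype=torch.bool)   # BNSS: an independent mask per head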
The atten_mask matrix to pass in is illustrated below:
There are 7 outputs in total:
(Tensor, Tensor, Tensor, Tensor, int, int, int)
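The first Tensor is the attention result; it is the output that the test case at the end of this section compares against the reference (fa_result[0]). A minimal usage sketch (the call mirrors the test case; unpacking only the first output is a common pattern, not a requirement):

import torch
import torch_npu

query = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
key = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()
value = torch.randn(1, 8, 256, 256, dtype=torch.float16).npu()

outputs = torch_npu.npu_fusion_attention(query, key, value, head_num=8,
                                          input_layout="BNSD", scale=0.08838)
attention_out = outputs[0]   # the attention result; the remaining six outputs follow the signature above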
Atlas A2 training series products
import math
import unittest
import numpy as np
import torch
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestNPUFlashAttention(TestCase):
    # Unfused golden reference computed with plain torch ops
    def supported_op_exec(self, query, key, value, atten_mask):
        scale = 0.08838
        qk = torch.matmul(query, key.transpose(2, 3)).mul(scale)
        qk = qk + atten_mask * (-10000.0)
        softmax_res = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32).to(torch.float16)
        attention_out = torch.matmul(softmax_res, value)
        return attention_out

    # Calls the fused NPU operator under test
    def custom_op_exec(self, query, key, value, sparse_params):
        scale = 0.08838
        atten_mask = None
        if sparse_params[0] == 0:
            shape = [1, 8, 256, 256]
            atten_mask_u = np.triu(np.ones(shape), k=sparse_params[1] + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-sparse_params[2] - 1)
            atten_masks = atten_mask_u + atten_mask_l
            atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
        if sparse_params[0] == 2 or sparse_params[0] == 3 or sparse_params[0] == 4:
            atten_masks = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1))
            atten_mask = torch.tensor(atten_masks).to(torch.float16).bool().npu()
        return torch_npu.npu_fusion_attention(
            query, key, value, head_num=8, input_layout="BNSD", scale=scale,
            sparse_mode=sparse_params[0], atten_mask=atten_mask,
            pre_tockens=sparse_params[1], next_tockens=sparse_params[2])

    # Builds the full-shape mask used by the golden reference for each sparse_mode
    def get_atten_mask(self, sparse_mode=0, pre_tokens=65536, next_tokens=65536):
        atten_masks = []
        shape = [1, 8, 256, 256]
        if sparse_mode == 0:
            atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
            atten_masks = atten_mask_u + atten_mask_l
        elif sparse_mode == 1:
            atten_masks = np.zeros(shape)
            pre_tokens = 65536
            next_tokens = 65536
        elif sparse_mode == 2:
            atten_masks = np.triu(np.ones(shape), k=1)
        elif sparse_mode == 3:
            atten_masks = np.triu(np.ones(shape), k=1)
        elif sparse_mode == 4:
            atten_mask_u = np.triu(np.ones(shape), k=next_tokens + 1)
            atten_mask_l = np.tril(np.ones(shape), k=-pre_tokens - 1)
            atten_masks = atten_mask_u + atten_mask_l
        atten_mask = torch.tensor(atten_masks).to(torch.float16)
        return atten_mask

    # sparse_params = [sparse_mode, pre_tokens, next_tokens]
    def check_result(self, query, key, value, sparse_params):
        atten_mask = self.get_atten_mask(sparse_params[0], sparse_params[1], sparse_params[2])
        output = self.supported_op_exec(query, key, value, atten_mask)
        fa_result = self.custom_op_exec(query.npu(), key.npu(), value.npu(), sparse_params)
        self.assertRtolEqual(output, fa_result[0], prec=0.01, prec16=0.01)

    def test_npu_flash_attention(self, device="npu"):
        query = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        key = torch.randn(1, 8, 256, 256, dtype=torch.float16)
        value = torch.randn(1, 8, 256, 256, dtype=torch.float16)

        # sparse_params: [sparse_mode, pre_tokens, next_tokens]
        sparse_params_list = [
            [0, 128, 128],
            [1, 65536, 65536],
            [2, 65536, 0],
            [3, 65536, 0],
            [4, 128, 128]
        ]
        for sparse_params in sparse_params_list:
            self.check_result(query, key, value, sparse_params)


if __name__ == "__main__":
    run_tests()