昇腾社区首页
中文
注册

aclnnIncreFlashAttentionV2

Atlas 训练系列产品不支持该算子。

Atlas A2训练系列产品支持该算子。

接口原型

每个算子分为两段式接口,必须先调用“aclnnIncreFlashAttentionV2GetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnIncreFlashAttentionV2”接口执行计算。

  • aclnnStatus aclnnIncreFlashAttentionV2GetWorkspaceSize(const aclTensor *query, const aclTensorList *key, const aclTensorList *value, const aclTensor *paddingMask, const aclTensor *attenMask, const aclIntArray *actualSeqLengths, const aclTensor *dequantScale1, const aclTensor *quantScale1, const aclTensor *dequantScale2, const aclTensor *quantScale2, const aclTensor *quantOffset2, int64_t numHeads, double scaleValue, char *inputLayout, int64_t numKeyValueHeads, const aclTensor *attentionOut, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnstatus aclnnIncreFlashAttentionV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

功能描述

  • 算子功能:兼容aclnnIncreFlashAttention接口功能,在其基础上新增量化特性

    对于自回归(Auto-regressive)的语言模型,随着新词的生成,推理输入长度不断增大。在原来全量推理的基础上实现增量推理,query的S轴固定为1,key和value是经过KV Cache后,将之前推理过的state信息,叠加在一起,每个Batch对应S轴的实际长度可能不一样,输入的数据是经过padding后的固定长度数据。

    相比全量场景的FlashAttention算子(aclnnPromptFlashAttention),增量推理的流程与正常全量推理并不完全等价,不过增量推理的精度并无明显劣化。

    KV Cache是大模型推理性能优化的一个常用技术。采样时,Transformer模型会以给定的prompt/context作为初始输入进行推理(可以并行处理),随后逐一生成额外的token来继续完善生成的序列(体现了模型的自回归性质)。在采样过程中,Transformer会执行自注意力操作,为此需要给当前序列中的每个项目(无论是prompt/context还是生成的token)提取键值(KV)向量。这些向量存储在一个矩阵中,通常被称为kv缓存(KV Cache)。

  • 计算公式

    self-attention(自注意力)利用输入样本自身的关系构建了一种注意力模型。其原理是假设有一个长度为n的输入样本序列x,x的每个元素都是一个d维向量,可以将每个d维向量看作一个token embedding,将这样一条序列经过3个权重矩阵变换得到3个维度为n*d的矩阵。

    self-attention的计算公式一般定义如下,其中Q、K、V为输入样本的重要属性元素,是输入样本经过空间变换得到,且可以统一到一个特征空间中。

    本算子中Score函数采用Softmax函数,self-attention计算公式为

    其中Q和KT的乘积代表输入x的注意力,为避免该值变得过大,通常除以d的开根号进行缩放,并对每行进行softmax归一化,与V相乘后得到一个n*d的矩阵。

aclnnIncreFlashAttentionV2GetWorkspaceSize

  • 参数说明:
    • query(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入Q,数据类型支持FLOAT16/INT8/BFLOAT16,数据格式支持ND。
    • key(aclTensorList*,计算输入):Device侧的aclTensorList,公式中的输入K,数据类型支持FLOAT16/INT8/BFLOAT16,数据格式支持ND。
    • value(aclTensorList*,计算输入):Device侧的aclTensorList,公式中的输入V,数据类型支持FLOAT16/INT8/BFLOAT16 ,数据格式支持ND。
    • paddingMask(aclTensor*,计算输入):Device侧的aclTensor,暂不生效,数据类型支持FLOAT16,数据格式支持ND。
    • attenMask(aclTensor*,计算输入):Device侧的aclTensor,可选参数,数据类型支持BOOL、FLOAT16,数据格式支持ND。
    • actualSeqLengths(aclIntArray*,计算输入):Host侧的aclIntArray,可选参数,数据类型支持INT64。
    • dequantScale1(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:UINT64。数据格式支持ND,表示BMM1后面反量化的量化因子,支持pre-tensor(scalar)。 如不使用该功能时可传入nullptr。
    • quantScale1(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT。数据格式支持ND,表示BMM2前面量化的量化因子,支持pre-tensor(scalar)。 如不使用该功能时可传入nullptr。
    • dequantScale2(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:UINT64。数据格式支持ND,表示BMM2后面量化的量化因子,支持pre-tensor(scalar)。 如不使用该功能时可传入nullptr。
    • quantScale2(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT。数据格式支持ND,表示输出量化的量化因子,支持pre-tensor(scalar)。 如不使用该功能时可传入nullptr。
    • quantOffset2(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持:FLOAT。数据格式支持ND,表示输出量化的量化偏移,支持pre-tensor(scalar)。 如不使用该功能时可传入nullptr。
    • numHeads(int64_t,计算输入 ):Host侧的int64_t,代表head个数,数据类型支持INT64。
    • scaleValue(double,计算输入):Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。
    • inputLayout(char*,计算输入):Host侧的字符指针,用于标识输入query、key、value的数据排布格式,当前支持BSH、BNSD。

      query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

    • numKeyValueHeads(int64_t,计算输入 ):Host侧的int64_t,代表key、value中head个数,用于支持GQA(Grouped-Query Attention,分组查询注意力)场景,默认为0,表示和query的head个数相等,数据类型支持INT64。
    • attentionOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的输出,数据类型支持FLOAT16,数据格式支持ND。
    • workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    • 返回161001(ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入、输出或者必选属性,且是空指针,则返回161001。
    • 返回161002(ACLNN_ERR_PARAM_INVALID):query、key、value、paddingMask、attenMask、attentionOut的数据类型和数据格式不在支持的范围内。

aclnnIncreFlashAttentionV2

  • 参数说明:
    • workspace(void*,入参):在Device侧申请的workspace内存起址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由aclnnIncreFlashAttentionV2GetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束与限制

  • 参数key、value 中对应tensor的shape需要完全一致;非连续场景下 key、value 的tensorlist中的tensor的shape也需要完全一致,且batch只能为1。
  • 参数query和attentionOut的shape需要完全一致。
  • 参数query中的N和numHeads值相等,key、value的N和numKeyValueHeads值相等,并且numHeads是numKeyValueHeads的倍数关系。
  • 非连续场景下,参数key、value的tensorlist中tensor的个数等于query的B,shape需要完全一致。
  • D的限制为16k,大于16k会报错拦截。
  • int8量化相关入参数量与输入、输出数据格式的综合限制:
    1. 输入为INT8,输出为INT8的场景:入参dequantScale1、quantScale1、dequantScale2、quantScale2、quantOffset2需要同时存在。
    2. 输入为INT8,输出为FLOAT16的场景:入参dequantScale1、quantScale1、dequantScale2需要同时存在,不能传入quantScale2、quantOffset2(传nullptr)参数。
    3. 输入为FLOAT16,输出为INT8的场景:入参quantScale2、quantOffset2需要同时存在,不能传入dequantScale1、quantScale1、dequantScale2(传nullptr)参数。
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <iostream>
#include <vector>
#include "acl/acl.h"
#include "aclnnop/aclnn_add.h"
#include "aclnn_incre_flash_attention_v2.h"

#define CHECK_RET(cond, return_expr) \
  do {                               \
    if (!(cond)) {                   \
      return_expr;                   \
    }                                \
  } while (0)

#define LOG_PRINT(message, ...)     \
  do {                              \
    printf(message, ##__VA_ARGS__); \
  } while (0)

int64_t GetShapeSize(const std::vector<int64_t>& shape) {
  int64_t shape_size = 1;
  for (auto i : shape) {
    shape_size *= i;
  }
  return shape_size;
}

int Init(int32_t deviceId, aclrtContext* context, aclrtStream* stream) {
  // 固定写法,AscendCL初始化
  auto ret = aclInit(nullptr);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
  ret = aclrtSetDevice(deviceId);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
  ret = aclrtCreateContext(context, deviceId);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret);
  ret = aclrtSetCurrentContext(*context);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret);
  ret = aclrtCreateStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
  return 0;
}

template <typename T>
int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
                    aclDataType dataType, aclTensor** tensor) {
  auto size = GetShapeSize(shape) * sizeof(T);
  // 调用aclrtMalloc申请device侧内存
  auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
  // 调用aclrtMemcpy将Host侧数据拷贝到device侧内存上
  ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
  // 计算连续tensor的strides
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = shape.size() - 2; i >= 0; i--) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  // 调用aclCreateTensor接口创建aclTensor
  *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                            shape.data(), shape.size(), *deviceAddr);
  return 0;
}

int main() {
  // 1. (固定写法)device/context/stream初始化, 参考AscendCL对外接口列表
  // 根据自己的实际device填写deviceId
  int32_t deviceId = 0;
  aclrtContext context;
  aclrtStream stream;
  auto ret = Init(deviceId, &context, &stream);
  // check根据自己的需要处理
  CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

  // 2. 构造输入与输出,需要根据API的接口自定义构造
  std::vector<int64_t> selfShape = {4, 2};
  std::vector<int64_t> otherShape = {4, 2};
  std::vector<int64_t> outShape = {4, 2};
  void* selfDeviceAddr = nullptr;
  void* otherDeviceAddr = nullptr;
  void* outDeviceAddr = nullptr;
  aclTensor* self = nullptr;
  aclTensor* other = nullptr;
  aclScalar* alpha = nullptr;
  aclTensor* out = nullptr;
  std::vector<float> selfHostData = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<float> otherHostData = {1, 1, 1, 2, 2, 2, 3, 3};
  std::vector<float> outHostData = {0, 0, 0, 0, 0, 0, 0, 0};
  float alphaValue = 1.2f;
  // 创建self aclTensor
  ret = CreateAclTensor(selfHostData, selfShape, &selfDeviceAddr, aclDataType::ACL_FLOAT, &self);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建other aclTensor
  ret = CreateAclTensor(otherHostData, otherShape, &otherDeviceAddr, aclDataType::ACL_FLOAT, &other);
  CHECK_RET(ret == ACL_SUCCESS, return ret);
  // 创建alpha aclScalar
  alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
  CHECK_RET(alpha != nullptr, return ret);
  // 创建out aclTensor
  ret = CreateAclTensor(outHostData, outShape, &outDeviceAddr, aclDataType::ACL_FLOAT, &out);
  CHECK_RET(ret == ACL_SUCCESS, return ret);

  // 3.调用CANN算子库API,需要修改为具体的算子接口
  uint64_t workspaceSize = 0;
  aclOpExecutor* executor;
  // 调用aclnnIncreFlashAttentionV2第一段接口
  ret = aclnnIncreFlashAttentionV2GetWorkspaceSize(tensors[q_idx], tensor_key_list, tensor_value_list,
                                                   nullptr, tensors[attenMask_idx], actual_seq_lengths, 
                                                   tensors[dequantScale1_idx],
                                                   tensors[quantScale1_idx],
                                                   tensors[dequantScale2_idx],
                                                   tensors[quantScale2_idx],
                                                   tensors[quantOffset2_idx],
                                                   num_heads,
                                                   scale_value,
                                                   layer_out,
                                                   numKeyValueHeads,
                                                   tensors[out_idx], &workspace_size, &handle);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnAddGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
  // 根据第一段接口计算出的workspaceSize申请device内存
  void* workspaceAddr = nullptr;
  if (workspaceSize > 0) {
    ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret;);
  }
  // 调用aclnnIncreFlashAttentionV2第二段接口
  ret = aclnnIncreFlashAttentionV2(workspace, workspace_size, handle, stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnAdd failed. ERROR: %d\n", ret); return ret);

  // 4.( 固定写法)同步等待任务执行结束
  ret = aclrtSynchronizeStream(stream);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

  // 5. 获取输出的值,将device侧内存上的结果拷贝至Host侧,需要根据具体API的接口定义修改
  auto size = GetShapeSize(outShape);
  std::vector<float> resultData(size, 0);
  ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), outDeviceAddr, size * sizeof(float),
                    ACL_MEMCPY_DEVICE_TO_HOST);
  CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);
  for (int64_t i = 0; i < size; i++) {
    LOG_PRINT("result[%ld] is: %f\n", i, resultData[i]);
  }

  // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
  aclDestroyTensor(self);
  aclDestroyTensor(other);
  aclDestroyScalar(alpha);
  aclDestroyTensor(out);
  
  // 7. 释放device资源,需要根据具体API的接口定义修改
  aclrtFree(selfDeviceAddr);
  aclrtFree(otherDeviceAddr);
  aclrtFree(outDeviceAddr);
  if (workspaceSize > 0) {
    aclrtFree(workspaceAddr);
  }
  aclrtDestroyStream(stream);
  aclrtDestroyContext(context);
  aclrtResetDevice(deviceId);
  aclFinalize();
  return 0;
}