昇腾社区首页
中文
注册

aclnnFlashAttentionUnpaddingScoreGradV4

支持的产品型号

  • Atlas A2 训练系列产品
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品

函数原型

每个算子分为两段式接口,必须先调用“aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize”接口获取计算所需workspace大小以及包含了算子计算流程的执行器,再调用“aclnnFlashAttentionUnpaddingScoreGradV4”接口执行计算。

  • aclnnStatus aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize(const aclTensor *query, const aclTensor *queryRope, const aclTensor *keyIn, const aclTensor *keyInRope, const aclTensor *value, const aclTensor *dy, const aclTensor *pseShiftOptional, const aclTensor *dropMaskOptional, const aclTensor *paddingMaskOptional, const aclTensor *attenMaskOptional, const aclTensor *softmaxMaxOptional, const aclTensor *softmaxSumOptional, const aclTensor *softmaxInOptional, const aclTensor *attentionInOptional, const aclIntArray *prefixOptional, const aclIntArray *actualSeqQLenOptional, const aclIntArray *actualSeqKvLenOptional, const aclIntArray *qStartIdxOptional, const aclIntArray *kvStartIdxOptional, double scaleValue, double keepProb, int64_t preTokens, int64_t nextTokens, int64_t headNum, char *inputLayout, int64_t innerPrecise, int64_t sparseMode, int64_t pseType, const aclTensor *dqOut, const aclTensor *dqRopeOut, const aclTensor *dkOut, const aclTensor *dkRopeOut, const aclTensor *dvOut, const aclTensor *dpseOut, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnFlashAttentionUnpaddingScoreGradV4(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, const aclrtStream stream)

功能说明

  • 算子功能:训练场景下计算注意力的反向输出,即aclnnFlashAttentionVarLenScoreV4的反向计算。该接口相较于aclnnFlashAttentionUnpaddingScoreGradV2接口,新增queryRope、keyRope、dqRope和dkRope参数。

  • 计算公式:

    已知注意力的正向计算公式:

    Y=Dropout(Softmax(Mask(QKTd+pse),atten_mask),keep_prob)VY=Dropout(Softmax(Mask(\frac{QK^T}{\sqrt{d}}+pse),atten\_mask),keep\_prob)V

    其中:

    Q=[query,queryRope],K=[key,keyRope]Q=[query, queryRope], K=[key, keyRope]

    为方便表达,以变量SSPP表示计算公式:

    S=Mask(QKTd+pse),atten_maskS=Mask(\frac{QK^T}{\sqrt{d}}+pse),atten\_mask P=Dropout(Softmax(S),keep_prob)P=Dropout(Softmax(S),keep\_prob) Y=PVY=PV

    则注意力的反向计算公式为:

    dV=PTdYdV=P^TdY dQ=((dS)K)ddQ=\frac{((dS)*K)}{\sqrt{d}} dqRope=((dS)kRope)ddqRope=\frac{((dS)*kRope)}{\sqrt{d}} dK=((dS)TQ)ddK=\frac{((dS)^T*Q)}{\sqrt{d}} dkRope=((dS)TqRope)ddkRope=\frac{((dS)^T*qRope)}{\sqrt{d}}

aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize

  • 参数说明:

    • query(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入Q的nope部分,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • queryRope(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入Q的rope部分,即旋转位置编码,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • keyIn(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入K的nope部分,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • keyInRope(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入K的rope部分,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • value(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入V,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • dy(aclTensor*,计算输入):Device侧的aclTensor,公式中的输入dY,数据类型支持 BFLOAT16,数据格式支持ND;综合约束请见约束说明

    • pseShiftOptional(aclTensor*,计算输入):公式中的输入pse, 必须为nullptr。

    • dropMaskOptional(aclTensor*,计算输入):必须为nullptr。

    • paddingMaskOptional(aclTensor*,计算输入):Device侧的aclTensor,暂不支持该传参。

    • qStartIdxOptional(aclIntArray*,计算输入):Host侧的aclIntArray,数据类型支持INT64,代表外切场景,当前分块的Q的sequence在全局中的起始索引,数据格式支持ND。

    • kvStartIdxOptional(aclIntArray*,计算输入):Host侧的aclIntArray,数据类型支持INT64,代表外切场景,当前分块的Q的sequence在全局中的起始索引,数据格式支持ND。

    • attenMaskOptional(aclTensor*,计算输入):Device侧的aclTensor,数据类型支持BOOL(8bit的BOOL)、UINT8,数据格式支持ND,支持shape范围为[S1Max,S2Max];综合约束请见约束说明。不可为nullptr。

    • softmaxMaxOptional(aclTensor*,计算输入):Device侧的aclTensor,注意力正向计算的中间输出,数据类型支持FLOAT,数据格式支持ND;综合约束请见约束说明

    • softmaxSumOptional(aclTensor*,计算输入):Device侧的aclTensor,注意力正向计算的中间输出,数据类型支持FLOAT,数据格式支持ND。综合约束请见约束说明

    • softmaxInOptional(aclTensor*,计算输入):Device侧的aclTensor,注意力正向计算的中间输出,预留参数暂未使用,调用时该参数需传空。

    • attentionInOptional(aclTensor*,计算输入):Device侧的aclTensor,注意力正向计算的最终输出,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据类型和shape与query一致,数据格式支持ND。

    • prefixOptional(aclIntArray*,计算输入):Device侧的aclTensor,代表prefix稀疏计算场景每个Batch的N值,数据类型支持INT64,数据格式支持ND;综合约束请见约束说明。如不使用该参数,可传入nullptr。

    • actualSeqQLenOptional(aclIntArray*,计算输入):数据类型支持INT64,数据格式支持ND,描述了每个Batch对应的query S大小;综合约束请见约束说明

    • actualSeqKvLenOptional(aclIntArray*,计算输入):数据类型支持INT64,数据格式支持ND,描述了每个Batch对应的key/value S大小。

    • dqOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的dQ,表示query的梯度,计算输出,数据类型支持 BFLOAT16,数据格式支持ND。

    • dqRopeOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的dqRope,表示queryRope的梯度,计算输出,数据类型支持 BFLOAT16,数据格式支持ND。

    • dkOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的dK,表示keyIn的梯度,计算输出,数据类型支持 BFLOAT16,数据格式支持ND。

    • dkRopeOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的dkRope,表示keyInRope的梯度,计算输出,数据类型支持 BFLOAT16,数据格式支持ND。

    • dvOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的dV,表示value的梯度,计算输出,数据类型支持 BFLOAT16,数据格式支持ND。

    • dpseOut(aclTensor*,计算输出):Device侧的aclTensor,公式中的d(pse),表示pse的梯度,计算输出,数据类型支持FLOAT16、BFLOAT16、FLOAT32,数据格式支持ND,预留参数暂未使用,但在pseShiftOptional不为空时,shape和数据类型与pseShiftOptional一致。

    • scaleValue(double,计算输入):Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。一般设置为D^-0.5。

    • keepProb(double,计算输入):Host侧的double,代表dropMaskOptional中1的比例,数据类型支持DOUBLE;综合约束请见约束说明。必须设置为1.0。

    • preTokens(int64_t,计算输入):Host侧的int64_t,用于稀疏计算的参数,数据类型支持INT64。用户不特意指定时建议传入2147483647。

    • nextTokens(int64_t,计算输入):Host侧的int64_t,用于稀疏计算的参数,数据类型支持INT64。用户不特意指定时建议传入2147483647。

    • headNum(int64_t,计算输入):Host侧的int64_t,代表head个数,数据类型支持INT64;综合约束请见约束说明

    • inputLayout(string*,计算输入):Host侧的string,代表输入query、keyIn、value的数据排布格式,支持TND。

      说明: query、keyIn、value数据排布格式支持从多种维度解读,其中T (Total S Length) 表示所有batch对应的S的总长、B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

    • innerPrecise(int32_t,计算输入):保留参数,暂未使用。

    • sparseMode(int64_t,计算输入):Host侧的int,表示sparse的模式,数据类型支持INT64。

      • sparseMode为0时,代表defaultMask模式,如果attenMaskOptional未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的attenMaskOptional矩阵(S1Max * S2Max),表示preTokens和nextTokens之间的部分需要计算。
      • sparseMode为1时,代表allMask,即传入完整的attenMaskOptional矩阵。
      • sparseMode为2时,代表leftUpCausal模式的mask,对应以左顶点为划分的下三角场景,需要传入优化后的attenMaskOptional矩阵(2048*2048)。
      • sparseMode为3时,代表rightDownCausal模式的mask,对应以右下顶点为划分的下三角场景,需要传入优化后的attenMaskOptional矩阵(2048*2048)。
      • sparseMode为4时,代表band场景,即计算preTokens和nextTokens之间的部分。
      • sparseMode为5时,不支持。
      • sparseMode为6时,不支持。
      • sparseMode为7时,代表rightDownCausal_Band场景,该场景由长序列外切产生,需要正确配置preTokens和nextTokens参数;传入shape为[2048, 2048]的下三角attenMaskOptional矩阵。
      • sparseMode为8时,代表band_LeftUpCausal场景,该场景由长序列外切产生,需要正确配置preTokens和nextTokens参数;传入shape为[2048, 2048]的下三角attenMaskOptional矩阵。

      用户不特意指定时建议传入0。sparse不同模式的详细说明请参见sparse模式说明

      说明: 当所有的attenMaskOptional的shape小于2048且相同的时候,建议使用default模式,来减少内存使用量;sparseMode配置为1、2、3、5时,用户配置的preTokens、nextTokens不会生效;sparseMode配置为0、4时,须保证attenMaskOptional与preTokens、nextTokens的范围一致。

    • pseType(int64_t,计算输入):Host侧的int64_t,数据类型支持INT64,用户不特意指定时可传入1,跟当前aclnnFlashAttentionUnpaddingScoreGrad实现一致,支持配置值为0、1,不支持2、3。

      pseType 含义 备注
      0 外部传入pse 先mul再add -
      1 外部传入pse 先add再mul aclnnFlashAttentionUnpaddingScoreGrad实现一致。
    • workspaceSize(uint64_t*,出参):返回用户需要在Device侧申请的workspace大小。

    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。

  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,若出现以下错误码,则对应原因为:
    - 返回161001(ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入,输出或者必选属性,且是空指针,则返回161001。
    - 返回161002(ACLNN_ERR_PARAM_INVALID):query、queryRope、keyIn、keyInRope、value、dy、pseShiftOptional、dropMaskOptional、paddingMaskOptional、attenMaskOptional、softmaxMaxOptional、softmaxSumOptional、softmaxInOptional、attentionInOptional、dqOut、dkOut、dvOut的数据类型和数据格式不在支持的范围内。

aclnnFlashAttentionUnpaddingScoreGradV4

  • 参数说明:

    • workspace(void*,入参):在Device侧申请的workspace内存地址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由第一段接口aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束说明

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配

  • 输入query、queryRope、key、keyRope、value、dy的B:batchsize必须相等。

  • 输入query、queryRope、key、keyRope、value、dy的input_layout必须是TND。

  • 在query/key/value的d大小相同的情况下,query/dy的shape必须一致。

  • query/key的d大小必须相同,d必须是8的整数倍。

  • queryRope/keyRope的d大小必须相同,d必须是8的整数倍,且需小于等于query/key的d。

  • 支持输入query/dy的N和key/value的N不相等,但必须成比例关系,即Nq/Nkv必须是非0整数,Nq取值范围1~256。

  • 关于数据shape的约束:

    • B:取值范围为1~2K。带prefixOptional的时候B最大支持1K。
    • N:取值范围为1~256。
    • S:取值范围为1~1M。
    • D:取值范围为1~512。
    • KeepProb:取值范围为1。
  • 部分场景下,如果计算量过大可能会导致算子执行超时(aicore error类型报错,errorStr为:timeout or trap error),此时建议做轴切分处理,注:这里的计算量会受B、S、N、D等参数的影响,值越大计算量越大。

  • prefixOptional稀疏计算仅支持压缩场景,sparseMode=6,当Sq > Skv时,prefix的N值取值范围[0, Skv],当Sq <= Skv时,prefix的N值取值范围[Skv-Sq, Skv]。

  • sparse_mode=7时,不支持可选输入realShiftOptional。

  • sparse_mode=8时,当每个sequence的q、kv等长时支持可选输入realShiftOptional,针对全局做pse生成。支持q方向进行外切,需要外切前每个sequence的q、kv等长,外切后传入的actualSeqQLenOptional[0] - actualSeqKvLenOptional[0] + qStartIdxOptional - kvStartIdxOptional == 0(本功能属实验性功能)。

  • actualSeqQLenOptional输入支持某个Batch上的S长度为0,此时不支持可选输入pseShiftOptional。

  • 关于softmaxMax与softmaxSum参数的约束:输入格式固定为为[T, N, 8],注:T=B*S

  • headNum的取值必须和传入的Query中的N值保持一致。

  • pseType只能为0或者1。

  • pseShiftOptional必须为空。

  • dropMaskOptional必须为空。

  • attenMaskOptional不能为空。

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

  #include <iostream>
  #include <vector>
  #include "acl/acl.h"
  #include "aclnnop/aclnn_flash_attention_score_grad.h"

  #define CHECK_RET(cond, return_expr) \
    do {                               \
      if (!(cond)) {                   \
        return_expr;                   \
      }                                \
    } while (0)

  #define LOG_PRINT(message, ...)     \
    do {                              \
      printf(message, ##__VA_ARGS__); \
    } while (0)

  int64_t GetShapeSize(const std::vector<int64_t>& shape) {
    int64_t shapeSize = 1;
    for (auto i : shape) {
      shapeSize *= i;
    }
    return shapeSize;
  }

  void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) {
    auto size = GetShapeSize(shape);
    std::vector<float> resultData(size, 0);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]),
                           *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
      LOG_PRINT("mean result[%ld] is: %f\n", i, resultData[i]);
    }
  }

  int Init(int32_t deviceId, aclrtStream* stream) {
    // 固定写法,AscendCL初始化
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
  }

  template <typename T>
  int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,
                      aclDataType dataType, aclTensor** tensor) {
    auto size = GetShapeSize(shape) * sizeof(T);
    // 调用aclrtMalloc申请device侧内存
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);

    // 计算连续tensor的strides
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
      strides[i] = shape[i + 1] * strides[i + 1];
    }

    // 调用aclCreateTensor接口创建aclTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
  }

  int main() {
    // 1. (固定写法)device/stream初始化,参考AscendCL对外接口列表
    // 根据自己的实际device填写deviceId
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = Init(deviceId, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. 构造输入与输出,需要根据API的接口自定义构造
    std::vector<int64_t> qShape = {256, 1, 128};
    std::vector<int64_t> qRopeShape = {256, 1, 64};
    std::vector<int64_t> kShape = {256, 1, 128};
    std::vector<int64_t> kRopeShape = {256, 1, 64};
    std::vector<int64_t> vShape = {256, 1, 128};
    std::vector<int64_t> dxShape = {256, 1, 128};
    std::vector<int64_t> attenmaskShape = {256, 256};
    std::vector<int64_t> softmaxMaxShape = {256, 1, 8};
    std::vector<int64_t> softmaxSumShape = {256, 1, 8};
    std::vector<int64_t> attentionInShape = {256, 1, 128};

    std::vector<int64_t> dqShape = {256, 1, 128};
    std::vector<int64_t> dqRopeShape = {256, 1, 64};
    std::vector<int64_t> dkShape = {256, 1, 128};
    std::vector<int64_t> dkRopeShape = {256, 1, 64};
    std::vector<int64_t> dvShape = {256, 1, 128};

    void* qDeviceAddr = nullptr;
    void* qRopeDeviceAddr = nullptr;
    void* kDeviceAddr = nullptr;
    void* kRopeDeviceAddr = nullptr;
    void* vDeviceAddr = nullptr;
    void* dxDeviceAddr = nullptr;
    void* attenmaskDeviceAddr = nullptr;
    void* softmaxMaxDeviceAddr = nullptr;
    void* softmaxSumDeviceAddr = nullptr;
    void* attentionInDeviceAddr = nullptr;
    void* dqDeviceAddr = nullptr;
    void* dqRopeDeviceAddr = nullptr;
    void* dkDeviceAddr = nullptr;
    void* dkRopeDeviceAddr = nullptr;
    void* dvDeviceAddr = nullptr;

    aclTensor* q = nullptr;
    aclTensor* qRope = nullptr;
    aclTensor* k = nullptr;
    aclTensor* kRope = nullptr;
    aclTensor* v = nullptr;
    aclTensor* dx = nullptr;
    aclTensor* pse = nullptr;
    aclTensor* dropMask = nullptr;
    aclTensor* padding = nullptr;
    aclTensor* attenmask = nullptr;
    aclTensor* softmaxMax = nullptr;
    aclTensor* softmaxSum = nullptr;
    aclTensor* softmaxIn = nullptr;
    aclTensor* attentionIn = nullptr;
    aclTensor* dq = nullptr;
    aclTensor* dqRope = nullptr;
    aclTensor* dk = nullptr;
    aclTensor* dkRope = nullptr;
    aclTensor* dv = nullptr;
    aclTensor* dpse = nullptr;

    std::vector<short> qHostData(32768, 1);
    std::vector<short> qRopeHostData(16384, 1);
    std::vector<short> kHostData(32768, 1);
    std::vector<short> kRopeHostData(16384, 1);
    std::vector<short> vHostData(32768, 1);
    std::vector<short> dxHostData(32768, 1);
    std::vector<uint8_t> attenmaskHostData(65536, 0);
    std::vector<float> softmaxMaxHostData(2048, 3.0);
    std::vector<float> softmaxSumHostData(2048, 3.0);
    std::vector<short> attentionInHostData(32768, 1);
    std::vector<short> dqHostData(32768, 0);
    std::vector<short> dqRopeHostData(16384, 0);
    std::vector<short> dkHostData(32768, 0);
    std::vector<short> dkRopeHostData(16384, 0);
    std::vector<short> dvHostData(32768, 0);

    ret = CreateAclTensor(qHostData, qShape, &qDeviceAddr, aclDataType::ACL_BF16, &q);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(qRopeHostData, qRopeShape, &qRopeDeviceAddr, aclDataType::ACL_BF16, &qRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kHostData, kShape, &kDeviceAddr, aclDataType::ACL_BF16, &k);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kRopeHostData, kRopeShape, &kRopeDeviceAddr, aclDataType::ACL_BF16, &kRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(vHostData, vShape, &vDeviceAddr, aclDataType::ACL_BF16, &v);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dxHostData, dxShape, &dxDeviceAddr, aclDataType::ACL_BF16, &dx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(attenmaskHostData, attenmaskShape, &attenmaskDeviceAddr, aclDataType::ACL_UINT8, &attenmask);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxMaxHostData, softmaxMaxShape, &softmaxMaxDeviceAddr, aclDataType::ACL_FLOAT, &softmaxMax);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxSumHostData, softmaxSumShape, &softmaxSumDeviceAddr, aclDataType::ACL_FLOAT, &softmaxSum);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(attentionInHostData, attentionInShape, &attentionInDeviceAddr, aclDataType::ACL_BF16, &attentionIn);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dqHostData, dqShape, &dqDeviceAddr, aclDataType::ACL_BF16, &dq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dqRopeHostData, dqRopeShape, &dqRopeDeviceAddr, aclDataType::ACL_BF16, &dqRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dkHostData, dkShape, &dkDeviceAddr, aclDataType::ACL_BF16, &dk);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dkRopeHostData, dkRopeShape, &dkRopeDeviceAddr, aclDataType::ACL_BF16, &dkRope);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dvHostData, dvShape, &dvDeviceAddr, aclDataType::ACL_BF16, &dv);
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    std::vector<int64_t> prefixOp = {0};
    aclIntArray* prefix = aclCreateIntArray(prefixOp.data(), 1);
    std::vector<int64_t>  acSeqQLenOp = {256};
    std::vector<int64_t>  acSeqKvLenOp = {256};
    aclIntArray* acSeqQLen = aclCreateIntArray(acSeqQLenOp.data(), acSeqQLenOp.size());
    aclIntArray* acSeqKvLen = aclCreateIntArray(acSeqKvLenOp.data(), acSeqKvLenOp.size());
    std::vector<int64_t> qStartIdxOp = {0};
    std::vector<int64_t> kvStartIdxOp = {0};
    aclIntArray *qStartIdx = aclCreateIntArray(qStartIdxOp.data(), 1);
    aclIntArray *kvStartIdx = aclCreateIntArray(kvStartIdxOp.data(), 1);
    double scaleValue = 0.088388;
    double keepProb = 1;
    int64_t preTokens = 65536;
    int64_t nextTokens = 65536;
    int64_t headNum = 1;
    int64_t innerPrecise = 0;
    int64_t sparseMod = 0;
    int64_t pseType = 1;
    char layOut[5] = {'T', 'N', 'D', 0};

    // 3. 调用CANN算子库API,需要修改为具体的API名称
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;

    // 调用aclnnFlashAttentionUnpaddingScoreGradV4第一段接口
    ret = aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize(q, qRope, k, kRope, v, dx, pse, dropMask, padding,
              attenmask, softmaxMax, softmaxSum, softmaxIn, attentionIn, prefix, acSeqQLen, acSeqKvLen, qStartIdx, kvStartIdx,
              scaleValue, keepProb, preTokens, nextTokens, headNum, layOut, innerPrecise, sparseMod, pseType,
              dq, dqRope, dk, dkRope, dv, dpse, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnFlashAttentionUnpaddingScoreGradV4GetWorkspaceSize failed. ERROR: %d\n", ret); return ret);

    // 根据第一段接口计算出的workspaceSize申请device内存
    void* workspaceAddr = nullptr;
    if (workspaceSize > 0) {
      ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
      CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }

    // 调用aclnnFlashAttentionUnpaddingScoreGradV4第二段接口
    ret = aclnnFlashAttentionUnpaddingScoreGradV4(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnFlashAttentionUnpaddingScoreGradV4 failed. ERROR: %d\n", ret); return ret);

    // 4. (固定写法)同步等待任务执行结束
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改
    PrintOutResult(dqShape, &dqDeviceAddr);
    PrintOutResult(dqRopeShape, &dqRopeDeviceAddr);
    PrintOutResult(dkShape, &dkDeviceAddr);
    PrintOutResult(dkRopeShape, &dkRopeDeviceAddr);
    PrintOutResult(dvShape, &dvDeviceAddr);

    // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改
    aclDestroyTensor(q);
    aclDestroyTensor(qRope);
    aclDestroyTensor(k);
    aclDestroyTensor(kRope);
    aclDestroyTensor(v);
    aclDestroyTensor(dx);
    aclDestroyTensor(attenmask);
    aclDestroyTensor(softmaxMax);
    aclDestroyTensor(softmaxSum);
    aclDestroyTensor(attentionIn);
    aclDestroyTensor(dq);
    aclDestroyTensor(dqRope);
    aclDestroyTensor(dk);
    aclDestroyTensor(dkRope);
    aclDestroyTensor(dv);

    // 7. 释放device资源
    aclrtFree(qDeviceAddr);
    aclrtFree(qRopeDeviceAddr);
    aclrtFree(kDeviceAddr);
    aclrtFree(kRopeDeviceAddr);
    aclrtFree(vDeviceAddr);
    aclrtFree(dxDeviceAddr);
    aclrtFree(attenmaskDeviceAddr);
    aclrtFree(softmaxMaxDeviceAddr);
    aclrtFree(softmaxSumDeviceAddr);
    aclrtFree(attentionInDeviceAddr);
    aclrtFree(dqDeviceAddr);
    aclrtFree(dqRopeDeviceAddr);
    aclrtFree(dkDeviceAddr);
    aclrtFree(dkRopeDeviceAddr);
    aclrtFree(dvDeviceAddr);
    if (workspaceSize > 0) {
      aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();

    return 0;
  }