昇腾社区首页
中文
注册

aclnnFlashAttentionScoreGrad

接口原型

每个算子有两段接口,必须先调用“aclnnFlashAttentionScoreGradGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnFlashAttentionScoreGrad”接口执行计算。两段式接口如下:

  • 第一段接口:aclnnStatus aclnnFlashAttentionScoreGradGetWorkspaceSize(const aclTensor *query, const aclTensor *key, const aclTensor *value, const aclTensor *dy, const aclTensor *pseShift, const aclTensor *dropMask, const aclTensor *paddingMask, const aclTensor *attenMask, const aclTensor *softmaxMax, const aclTensor *softmaxSum, const aclTensor *softmaxIn, const aclTensor *attentionIn, const aclTensor *dq, const aclTensor *dk, const aclTensor *dv, const aclTensor *dpse, double scaleValue, double keepProb, int64_t precTockens, int64_t nextTockens, int64_t headNum, string *inputLayout, int32_t innerPrecise, uint64_t *workspaceSize, aclOpExecutor **executor)
  • 第二段接口:aclnnStatus aclnnFlashAttentionScoreGrad(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

功能描述

  • 算子功能:训练场景下计算注意力的反向输出,即aclnnFlashAttentionScore的反向计算。
  • 计算公式:

    已知注意力的正向计算公式为

    为方便表达,以变量S和P表示计算公式:

    那么注意力的反向计算公式为

aclnnFlashAttentionGradGetWorkspaceSize

  • 接口定义:

    aclnnStatus aclnnFlashAttentionScoreGradGetWorkspaceSize(const aclTensor *query, const aclTensor *key, const aclTensor *value, const aclTensor *dy, const aclTensor *pseShift, const aclTensor *dropMask, const aclTensor *paddingMask, const aclTensor *attenMask, const aclTensor *softmaxMax, const aclTensor *softmaxSum, const aclTensor *softmaxIn, const aclTensor *attentionIn, const aclTensor *dq, const aclTensor *dk, const aclTensor *dv, const aclTensor *dpse,double scaleValue, double keepProb, int64_t preTockens, int64_t nextTockens, int64_t headNum, string *inputLayout, int32_t innerPrecise, uint64_t *workspaceSize, aclOpExecutor **executor)

  • 参数说明:
    • query:Device侧的aclTensor,公式中的输入Q,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • key:Device侧的aclTensor,公式中的输入K,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • value:Device侧的aclTensor,公式中的输入V,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • dy:Device侧的aclTensor,公式中的输入dY,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • pseShift:Device侧的aclTensor,公式中的输入pse,可选参数,表示位置编码。数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • dropMask:Device侧的aclTensor,可选属性,数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • paddingMask:Device侧的aclTensor,可选属性,数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • attenMask:Device侧的aclTensor,可选属性,代表下三角全为0上三角全为负无穷的倒三角mask矩阵,数据类型支持BOOL,数据格式支持ND。
    • softmaxMax:Device侧的aclTensor,注意力正向计算的中间输出,数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • softmaxSum:Device侧的aclTensor,注意力正向计算的中间输出,数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • softmaxIn:Device侧的aclTensor,注意力正向计算的中间输出,数据类型支持FLOAT、FLOAT16,数据格式支持ND。
    • attentionIn:Device侧的aclTensor,注意力正向计算的最终输出attentionOut,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • dq:Device侧的aclTensor,公式中的dQ,表示query的梯度,计算输出,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • dk:Device侧的aclTensor,公式中的dK,表示key的梯度,计算输出,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • dv:Device侧的aclTensor,公式中的dV,表示value的梯度,计算输出,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • dpse:Device侧的aclTensor,公式中的d(pse),表示pse的梯度,计算输出,数据类型支持FLOAT、FLOAT16、BFLOAT16,数据格式支持ND。
    • scaleValue:Host侧的double,公式中d开根号的倒数,代表缩放系数,作为计算流中Muls的scalar值,数据类型支持DOUBLE。
    • keepProb:Host侧的double,可选参数,代表dropMask中1的比例,数据类型支持FLOAT32。
    • preTockens:Host侧的int64_t,用于稀疏计算的参数,可选参数,数据类型支持INT64。
    • nextTockens:Host侧的int64_t,用于稀疏计算的参数,可选参数,数据类型支持INT64。
    • headNum:Host侧的int64_t,代表head个数,数据类型支持INT64。
    • inputLayout:Host侧的string,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD。

      query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

    • innerPrecise:Host侧的int32_t,数据类型支持INT32,内部计算精度模式,其中0表示为高精度,1表示为高性能。
    • workspaceSize:返回用户需要在Device侧申请的workspace大小。
    • executor:返回op执行器,包含了算子计算流程。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,出现以下场景时报错:

    • 返回161001(ACLNN_ERR_PARAM_NULLPTR):传入的query、key、value、dy、dq、dk、dv是空指针。
    • 返回161002(ACLNN_ERR_PARAM_INVALID):query、key、value、dy、pseShift、dropMask、paddingMask、attenMask、softmaxMax、softmaxSum、softmaxIn、attentionIn、dq、dk、dv的数据类型和数据格式不在支持的范围内。

aclnnFlashAttentionScoreGrad

  • 接口定义:

    aclnnStatus aclnnFlashAttentionScoreGrad(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

  • 参数说明:
    • workspace:在Device侧申请的workspace内存起址。
    • workspaceSize:在Device侧申请的workspace大小,由第一段接口aclnnFlashAttentionScoreGradGetWorkspaceSize获取。
    • executor:op执行器,包含了算子计算流程。
    • stream:指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束与限制

关于数据shape的约束,以inputLayout的BSND、BNSD为例(BSH、SBH下H=N*D),其中

  • B:取值范围为1~256。
  • N:取值范围为1~256。
  • S:取值范围为1~32K,且为16的倍数。
  • D:取值为64、96、128、256。