昇腾社区首页
中文
注册

aclnnFlashAttentionScore

Atlas 训练系列产品不支持该算子。

Atlas A2 训练系列产品支持该算子。

接口原型

每个算子分为两段式接口,必须先调用“aclnnFlashAttentionScoreGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnFlashAttentionScore”接口执行计算。

  • aclnnStatus aclnnFlashAttentionScoreGetWorkspaceSize(const aclTensor *query, const aclTensor *key, const aclTensor *value, const aclTensor *realShiftOptional, const aclTensor *dropMaskOptional, const aclTensor *paddingMaskOptional, const aclTensor *attenMaskOptional, const aclIntArray *prefixOptional, double scaleValueOptional, double keepProbOptional, int64_t preTockensOptional, int64_t nextTockensOptional, int64_t headNum, char *inputLayout, int64_t innerPreciseOptional, int64_t sparseModeOptional, const aclTensor *softmaxMaxOut, const aclTensor *softmaxSumOut, const aclTensor *softmaxOutOut, const aclTensor *attentionOutOut, uint64_t *workspaceSize, aclOpExecutor **executor)
  • aclnnStatus aclnnFlashAttentionScore(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, const aclrtStream stream)

功能描述

  • 算子功能:训练场景下,使用FlashAttention算法实现self-attention(自注意力)的计算。
  • 计算公式:

    注意力的正向计算公式如下:

aclnnFlashAttentionScoreGetWorkspaceSize

  • 参数说明:
    • query((aclTensor*,计算输入):即公式中的输入Q。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与key/value的数据类型一致。数据格式支持ND。
    • key(aclTensor*,计算输入):即公式中的输入K。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/value的数据类型一致。数据格式支持ND。
    • value(aclTensor*,计算输入):即公式中的输入V。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/key的数据类型一致。数据格式支持ND。
    • realShiftOptional(aclTensor*,计算输入):即公式中的输入pse。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query的数据类型一致。数据格式支持ND。

      alibi位置编码场景下,数据格式每个batch参数不同[B N S S]/每个batch参数相同[1 N S S]。

    • dropMaskOptional(aclTensor*,计算输入):数据类型支持:UINT8。数据格式支持ND。
    • paddingMaskOptional(aclTensor*,计算输入):保留输入,暂未使用。
    • attenMaskOptional(aclTensor*,计算输入):数据类型支持:BOOL。数据格式支持ND。
    • prefixOptional(aclIntArray*,计算输入):数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/key/value的数据类型一致。数据格式支持ND。
    • scaleValueOptional(double,计算输入):公式中d开根号的倒数,数据类型支持:DOUBLE,代表缩放系数,作为计算流中Muls的scalar值。
    • keepProbOptional(double,计算输入):数据类型支持:DOUBLE,代表dropMask中1的比例。
    • preTockensOptional(int64_t,计算输入):数据类型支持:INT64,用于稀疏计算 ,表示slides window的左边界。
    • nextTockensOptional(int64_t,计算输入):数据类型支持:INT64,用于稀疏计算,表示slides window的右边界。
    • headNum(int64_t,计算输入):数据类型支持:INT64,代表单卡的head个数。
    • inputLayout(char*,计算输入):数据类型支持:String,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD。

      query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。

    • innerPreciseOptional(int64_t,计算输入):保留参数,暂未使用。
    • sparseModeOptional(int64_t,计算输入):数据类型支持:INT64。用户不特意指定时可传入默认值:0。
      • sparseMode为0时,代表defaultMask模式,如果attenmask未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的attenmask矩阵(S1 * S2),表示preTokens和nextTokens之间的部分需要计算。
      • sparseMode为为1时,代表allMask,即传入完整的attenmask矩阵。
      • sparseMode为2时,代表leftUpCausal模式的mask,对应以左上顶点为划分的下三角场景,需要传入shape为[2048, 2048]的下三角attenmask矩阵。
      • sparseMode为3时,代表rightDownCausal模式的mask,对应以右下顶点为划分的下三角场景,需要传入shape为[2048, 2048]的下三角attenmask矩阵。
      • sparseMode为为4时,代表band场景,即计算preTokens和nextTokens之间的部分。
      • sparseMode为为5时,代表prefix场景,即在rightDownCasual的基础上,左侧加上一个长为S1,宽为N的矩阵,N的值由新增的输入prefix获取,且每个Batch轴的N值不一样。
      • sparseMode为为6、7、8时,分别代表global、dilated、block_local,均暂不支持

      注意:当所有的attenmask的shape小于2048且相同的时候,建议使用default模式,来减少内存使用量;sparse_mode配置为2或3时,不能配置preTokens、nextTokens。

    • softmaxMaxOut(aclTensor*,计算输出):数据类型支持:FLOAT。数据格式支持ND。
    • softmaxSumOut(aclTensor*,计算输出):数据类型支持:FLOAT。数据格式支持ND。
    • softmaxOutOut(aclTensor*,计算输出):保留输出,暂未使用。
    • attentionOutOut(aclTensor*,计算输出):数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型query的数据类型一致。数据格式支持ND。
    • workspaceSize(uint64_t*,出参):返回需要在npu device侧申请的workspace大小。
    • executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

    第一段接口完成入参校验,若出现以下错误码,则对应原因为:

    • 返回161001(ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入、输出或者必选属性,且是空指针,则返回161001。
    • 返回161002(ACLNN_ERR_PARAM_INVALID):query、key、value、realShift、dropMask、paddingMask、attenMask、softmaxMax、softmaxSum、softmaxOut、attentionOut的数据类型和数据格式不在支持的范围内。

aclnnFlashAttentionScore

  • 参数说明:
    • workspace(void*,入参):在Device侧申请的workspace内存起址。
    • workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由aclnnFlashAttentionScoreGetWorkspaceSize获取。
    • executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
    • stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码

约束与限制

关于数据shape的约束,以inputLayout的BSND、BNSD为例(BSH、SBH下H=N*D),其中:

  • B:取值范围为1~256。
  • N:取值范围为1~256。
  • S:取值范围为1~32K,且为16的倍数。
  • D:取值为64、96、128、256。