aclnnFlashAttentionScore
Atlas 训练系列产品不支持该算子。
Atlas A2 训练系列产品支持该算子。
接口原型
每个算子分为两段式接口,必须先调用“aclnnFlashAttentionScoreGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnFlashAttentionScore”接口执行计算。
- aclnnStatus aclnnFlashAttentionScoreGetWorkspaceSize(const aclTensor *query, const aclTensor *key, const aclTensor *value, const aclTensor *realShiftOptional, const aclTensor *dropMaskOptional, const aclTensor *paddingMaskOptional, const aclTensor *attenMaskOptional, const aclIntArray *prefixOptional, double scaleValueOptional, double keepProbOptional, int64_t preTockensOptional, int64_t nextTockensOptional, int64_t headNum, char *inputLayout, int64_t innerPreciseOptional, int64_t sparseModeOptional, const aclTensor *softmaxMaxOut, const aclTensor *softmaxSumOut, const aclTensor *softmaxOutOut, const aclTensor *attentionOutOut, uint64_t *workspaceSize, aclOpExecutor **executor)
- aclnnStatus aclnnFlashAttentionScore(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, const aclrtStream stream)
aclnnFlashAttentionScoreGetWorkspaceSize
- 参数说明:
- query((aclTensor*,计算输入):即公式中的输入Q。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与key/value的数据类型一致。数据格式支持ND。
- key(aclTensor*,计算输入):即公式中的输入K。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/value的数据类型一致。数据格式支持ND。
- value(aclTensor*,计算输入):即公式中的输入V。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/key的数据类型一致。数据格式支持ND。
- realShiftOptional(aclTensor*,计算输入):即公式中的输入pse。数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query的数据类型一致。数据格式支持ND。
- dropMaskOptional(aclTensor*,计算输入):数据类型支持:UINT8。数据格式支持ND。
- paddingMaskOptional(aclTensor*,计算输入):保留输入,暂未使用。
- attenMaskOptional(aclTensor*,计算输入):数据类型支持:BOOL。数据格式支持ND。
- prefixOptional(aclIntArray*,计算输入):数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型与query/key/value的数据类型一致。数据格式支持ND。
- scaleValueOptional(double,计算输入):公式中d开根号的倒数,数据类型支持:DOUBLE,代表缩放系数,作为计算流中Muls的scalar值。
- keepProbOptional(double,计算输入):数据类型支持:DOUBLE,代表dropMask中1的比例。
- preTockensOptional(int64_t,计算输入):数据类型支持:INT64,用于稀疏计算 ,表示slides window的左边界。
- nextTockensOptional(int64_t,计算输入):数据类型支持:INT64,用于稀疏计算,表示slides window的右边界。
- headNum(int64_t,计算输入):数据类型支持:INT64,代表单卡的head个数。
- inputLayout(char*,计算输入):数据类型支持:String,代表输入query、key、value的数据排布格式,支持BSH、SBH、BSND、BNSD。
query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
- innerPreciseOptional(int64_t,计算输入):保留参数,暂未使用。
- sparseModeOptional(int64_t,计算输入):数据类型支持:INT64。用户不特意指定时可传入默认值:0。
- sparseMode为0时,代表defaultMask模式,如果attenmask未传入则不做mask操作,忽略preTokens和nextTokens(内部赋值为INT_MAX);如果传入,则需要传入完整的attenmask矩阵(S1 * S2),表示preTokens和nextTokens之间的部分需要计算。
- sparseMode为为1时,代表allMask,即传入完整的attenmask矩阵。
- sparseMode为2时,代表leftUpCausal模式的mask,对应以左上顶点为划分的下三角场景,需要传入shape为[2048, 2048]的下三角attenmask矩阵。
- sparseMode为3时,代表rightDownCausal模式的mask,对应以右下顶点为划分的下三角场景,需要传入shape为[2048, 2048]的下三角attenmask矩阵。
- sparseMode为为4时,代表band场景,即计算preTokens和nextTokens之间的部分。
- sparseMode为为5时,代表prefix场景,即在rightDownCasual的基础上,左侧加上一个长为S1,宽为N的矩阵,N的值由新增的输入prefix获取,且每个Batch轴的N值不一样。
- sparseMode为为6、7、8时,分别代表global、dilated、block_local,均暂不支持。
注意:当所有的attenmask的shape小于2048且相同的时候,建议使用default模式,来减少内存使用量;sparse_mode配置为2或3时,不能配置preTokens、nextTokens。
- softmaxMaxOut(aclTensor*,计算输出):数据类型支持:FLOAT。数据格式支持ND。
- softmaxSumOut(aclTensor*,计算输出):数据类型支持:FLOAT。数据格式支持ND。
- softmaxOutOut(aclTensor*,计算输出):保留输出,暂未使用。
- attentionOutOut(aclTensor*,计算输出):数据类型支持:FLOAT、FLOAT16、BFLOAT16。数据类型query的数据类型一致。数据格式支持ND。
- workspaceSize(uint64_t*,出参):返回需要在npu device侧申请的workspace大小。
- executor(aclOpExecutor**,出参):返回op执行器,包含了算子计算流程。
- 返回值:
返回aclnnStatus状态码,具体参见aclnn返回码。
第一段接口完成入参校验,若出现以下错误码,则对应原因为:
- 返回161001(ACLNN_ERR_PARAM_NULLPTR):如果传入参数是必选输入、输出或者必选属性,且是空指针,则返回161001。
- 返回161002(ACLNN_ERR_PARAM_INVALID):query、key、value、realShift、dropMask、paddingMask、attenMask、softmaxMax、softmaxSum、softmaxOut、attentionOut的数据类型和数据格式不在支持的范围内。
aclnnFlashAttentionScore
- 参数说明:
- workspace(void*,入参):在Device侧申请的workspace内存起址。
- workspaceSize(uint64_t,入参):在Device侧申请的workspace大小,由aclnnFlashAttentionScoreGetWorkspaceSize获取。
- executor(aclOpExecutor*,入参):op执行器,包含了算子计算流程。
- stream(aclrtStream,入参):指定执行任务的AscendCL stream流。
- 返回值:
返回aclnnStatus状态码,具体参见aclnn返回码。
约束与限制
关于数据shape的约束,以inputLayout的BSND、BNSD为例(BSH、SBH下H=N*D),其中:
- B:取值范围为1~256。
- N:取值范围为1~256。
- S:取值范围为1~32K,且为16的倍数。
- D:取值为64、96、128、256。
父主题: 融合类算子接口