Atlas 训练系列产品不支持该算子。
Atlas A2训练系列产品支持该算子。
每个算子有两段接口,必须先调用“aclnnFlashAttentionScoreGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnFlashAttentionScore”接口执行计算。两段式接口如下:
aclnnStatus aclnnFlashAttentionScoreGetWorkspaceSize(const aclTensor *query, const aclTensor *key, const aclTensor *value, const aclTensor *realShift, const aclTensor *dropMask, const aclTensor *paddingMask, const aclTensor *attenMask, double scaleValue, double keepProb, int64_t preTockens, int64_t nextTockens, int64_t headNum, string *inputLayout, int32_t innerPrecise, const aclTensor *softmaxMax, const aclTensor *softmaxSum, const aclTensor *softmaxOut, const aclTensor *attentionOut, uint64_t *workspaceSize, aclOpExecutor **executor)
query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
返回aclnnStatus状态码,具体参见aclnn返回码。
第一段接口完成入参校验,出现以下场景时报错:
aclnnStatus aclnnFlashAttentionScore(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
返回aclnnStatus状态码,具体参见aclnn返回码。
关于数据shape的约束,以inputLayout的BSND、BNSD为例(BSH、SBH下H=N*D),其中