Atlas 训练系列产品不支持该算子。
Atlas A2 训练系列产品支持该算子。
每个算子分为两段式接口,必须先调用“aclnnFlashAttentionScoreGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnFlashAttentionScore”接口执行计算。
query、key、value数据排布格式支持从多种维度解读,其中B(Batch)表示输入样本批量大小、S(Seq-Length)表示输入样本序列长度、H(Head-Size)表示隐藏层的大小、N(Head-Num)表示多头数、D(Head-Dim)表示隐藏层最小的单元尺寸,且满足D=H/N。
注意:当所有的attenmask的shape小于2048且相同的时候,建议使用default模式,来减少内存使用量;sparse_mode配置为2或3时,不能配置preTokens、nextTokens。
返回aclnnStatus状态码,具体参见aclnn返回码。
第一段接口完成入参校验,若出现以下错误码,则对应原因为:
返回aclnnStatus状态码,具体参见aclnn返回码。
关于数据shape的约束,以inputLayout的BSND、BNSD为例(BSH、SBH下H=N*D),其中: