SoftmaxFlashV2
产品支持情况
|
产品 |
是否支持 |
|---|---|
|
Atlas 350 加速卡 |
√ |
|
|
√ |
|
|
√ |
|
|
√ |
|
|
√ |
|
|
x |
|
|
x |
功能说明
SoftmaxFlash增强版本,对应FlashAttention-2算法。将输入tensor[m0, m1, ...mt, n](t大于等于0)的非尾轴长度相乘的结果看作m,则输入tensor的shape看作[m, n]。对输入tensor[m,n]按行做如下计算,不同的update值对应不同的计算公式,其中x、inmax和insum为输入,M、S、E均为输出。
当输入shape为ND格式时,内部的reduce过程按last轴进行;当输入shape为NZ格式时,内部的reduce过程按照last轴和first轴进行,reduce过程可以参考SoftMax中的图示说明。
为方便理解,通过Python脚本实现的方式,表达其计算公式如下,其中src、inmax、 insum、update为输入,dst、x_sum、x_max、exp_max为输出。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
def softmax_flash_2(src, inmax=None, insum=None, update=None): if update == None: x_max = np.max(src, axis=-1, keepdims=True) x_sub = src - x_max dst = np.exp(x_sub) x_sum = np.sum(dst, axis=-1, keepdims=True) exp_max = None return dst, x_max, x_sum, exp_max else: x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True) dst = np.exp(src - x_max) exp_max = np.exp(inmax - x_max) x_sum = np.sum(dst, axis=-1, keepdims=True) x_sum = exp_max * insum + x_sum return dst, x_max, x_sum, exp_max |
实现原理
以float类型,ND格式,shape为[m, k]的输入Tensor为例,描述SoftmaxFlashV2高阶API内部算法框图,如下图所示。
计算过程根据isUpdate是否使能分为两个分支处理,均在Vector上进行。
- 当isUpdate为False时,分为如下几步:
- reducemax步骤:对输入x的每一行数据求最大值得到[m, 1],计算结果会保存到一个临时空间temp中;
- broadcast步骤:对temp中的数据[m, 1]做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出max;
- sub步骤:对输入x的所有数据按行减去max;
- exp步骤:对sub之后的所有数据求exp,并且输出y;
- reducesum步骤:对exp结果的每一行数据求和得到[m, 1],计算结果会保存到临时空间temp中;
- broadcast步骤:对temp[m, 1]做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],同时输出sum。
- 当isUpdate为True时,分为如下几步:
- reducemax步骤:对输入x的每一行数据求最大值得到[m, 1],计算结果会保存到一个临时空间temp中;
- broadcast步骤:对temp中的数据[m, 1]做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],保存为max;
- max步骤:对输入inmax和上一步计算的max做max操作,得到新的max并输出;
- sub步骤:将输入inmax和新的max相减,然后做exp,计算得到expmax并输出;
- sub步骤:将输入x和新的max按行相减;
- exp步骤:对sub之后的所有数据求exp,并且输出y;
- reducesum步骤:对exp结果的每一行数据求和得到[m, 1],计算结果会保存到临时空间temp中;
- broadcast步骤:对temp数据[m, 1]做一个按datablock为单位的填充,比如float类型下,把[m, 1]扩展成[m, 8],保存到sum中;
- mul步骤:将输入insum和expmax结果相乘;
- add步骤:将相乘结果和sum相加,保存到sum并输出。
函数原型
- 接口框架申请临时空间
- LocalTensor的数据类型相同,不输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
- LocalTensor的数据类型相同,且输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax, const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor, const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
Atlas 200I/500 A2 推理产品 不支持该接口。Atlas 推理系列产品 AI Core不支持该接口。
- LocalTensor的数据类型不同,不输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
- LocalTensor的数据类型相同,不输出ReduceMax
- 通过sharedTmpBuffer入参传入临时空间
- LocalTensor的数据类型相同,不输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outExpSum, const LocalTensor<T>& outMax, const LocalTensor<T>& srcTensor, const LocalTensor<T>& outExpMax, const LocalTensor<T>& inExpSum, const LocalTensor<T>& inMax, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
- LocalTensor的数据类型相同,且输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& outReduceMax, const LocalTensor<T>& expSumTensor, const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor, const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
Atlas 200I/500 A2 推理产品 不支持该接口。Atlas 推理系列产品 AI Core不支持该接口。
- LocalTensor的数据类型不同,不输出ReduceMax
1 2
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG> __aicore__ inline void SoftmaxFlashV2(const LocalTensor<half>& dstTensor, const LocalTensor<float>& expSumTensor, const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, const LocalTensor<half>& expMaxTensor, const LocalTensor<float>& inExpSumTensor, const LocalTensor<float>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})
- LocalTensor的数据类型相同,不输出ReduceMax
由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
- 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。
- 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过SoftmaxFlashV2 Tiling接口中提供的GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize接口获取所需最小和最大临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。
另外提供了一个kernel侧tiling计算的接口,当kernel侧的输入shape与通过host侧TilingData传入的shape不一致时,可使用该接口在kernel侧重新计算tiling。该接口的参数含义请参考SoftmaxFlashV2 Tiling接口。
- kernel侧tiling计算接口
1__aicore__ inline constexpr SoftMaxTiling SoftMaxFlashV2TilingFunc(const SoftMaxShapeInfo& shapeInfo, const uint32_t dataTypeSize1, const uint32_t dataTypeSize2, const uint32_t localWorkSpaceSize, const bool isUpdate = false, const bool isBasicBlock = false, const bool isDataFormatNZ = false, const bool isFlashOutputBrc = false)
参数说明
|
参数名 |
描述 |
||||
|---|---|---|---|---|---|
|
T |
操作数的数据类型。 Atlas 350 加速卡,支持的数据类型为:half、float。 |
||||
|
isUpdate |
是否使能update部分中的计算。 |
||||
|
isReuseSource |
该参数预留,传入默认值false即可。 |
||||
|
isBasicBlock |
srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下,可以使能该参数用于提升性能,默认不使能。是否满足基本块的要求,可以采用如下两种方式之一判断:
针对 |
||||
|
isDataFormatNZ |
当前输入输出的数据格式是否为NZ格式,默认数据格式为ND,即默认取值为false。 针对 |
||||
|
config |
结构体模板参数,此参数可选配,SoftmaxConfig类型,具体定义如下。
其中,参数mode表示输出shape的处理模式,当输入输出的数据格式为NZ格式时,不支持配置mode参数。SoftmaxMode类型,取值如下:
配置示例如下。
此参数一般用于配合kernel侧tiling计算的接口使用。 注意:设置了oriSrcM与oriSrcK后,模板参数isBasicBlock不生效,计算数据是否为基本块由API内部判断并处理。 针对Atlas 350 加速卡,支持该参数。 针对 |
|
参数名 |
输入/输出 |
描述 |
||
|---|---|---|---|---|
|
dstTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 dstTensor的shape和源操作数srcTensor一致。 |
||
|
outReduceMax |
输出 |
目的操作数。用于保存softmax计算过程中reducemax第一次计算的结果。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 outReduceMax的shape与目的操作数maxTensor一致。 对于输出该结果的接口:
|
||
|
expSumTensor、outExpSum |
输出 |
目的操作数。用于保存softmax计算过程中reducesum的结果。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
maxTensor、outMax |
输出 |
目的操作数。用于保存softmax计算过程中reducemax的结果。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
srcTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 last轴长度需要32Byte对齐。 |
||
|
expMaxTensor、outExpMax |
输出 |
目的操作数。用于保存inmax与reducemax差值的e的指数幂的结果。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
inExpSumTensor、inExpSum |
输入 |
源操作数。softmax计算所需要的sum值。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
inMaxTensor、inMax |
输入 |
源操作数。softmax计算所需要的max值。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。
|
||
|
sharedTmpBuffer |
输入 |
临时空间。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 该操作数的数据类型固定uint8_t。 接口内部复杂计算时用于存储中间变量,由开发者提供。 临时空间大小BufferSize的获取方式请参考SoftmaxFlashV2 Tiling接口。 |
||
|
tiling |
输入 |
softmaxflashv2接口计算所需tiling信息,Tiling信息的获取请参考SoftmaxFlashV2 Tiling接口。 |
||
|
softmaxShapeInfo |
输入 |
srcTensor的shape信息。SoftMaxShapeInfo类型,具体定义如下:
需要注意,当输入输出的数据格式为NZ格式时,尾轴长度为reduce轴长度即图2中的W0*W1,非尾轴为H0*H1。 |
返回值说明
无
约束说明
- srcTensor和dstTensor的Tensor的空间可以复用,maxTensor和inMaxTensor的空间可以复用,expSumTensor和inExpSumTensor的空间可以复用。
- 除模板参数config配置为非拓展模式(SoftmaxMode::SOFTMAX_OUTPUT_WITHOUT_BRC)的场景外,expSumTensor、maxTensor、expMaxTensor、inExpSumTensor、inMaxTensor的Tensor空间,last轴长度必须固定32Byte。
- 对于输出ReduceMax的接口:
- 模板参数isReuseSource、isDataFormatNZ、config.isCheckTiling均为预留参数;
- config.mode只支持配置为非拓展模式SOFTMAX_OUTPUT_WITHOUT_BRC,其配置为SOFTMAX_NORMAL模式时,接口功能不执行,不保存各输出;
- 模板参数isUpdate为false时,outReduceMax不输出;
- 除outReduceMax外,其余每个输出的计算结果与不输出ReduceMax的接口相同。
- 操作数地址对齐要求请参见通用地址对齐约束。
- 不支持sharedTmpBuffer与源操作数和目的操作数地址重叠。
- 当参数softmaxShapeInfo中srcM != oriSrcM 或者 srcK != oriSrcK时,开发者需要对GM上的原始输入(oriSrcM, oriSrcK)在M或K方向补齐数据到(srcM, srcK),补齐的数据会参与部分运算,在输入输出复用的场景下,API的计算结果会覆盖srcTensor中补齐的原始数据,在输入输出不复用的场景下,API的计算结果会覆盖dstTensor中对应srcTensor补齐位置的数据。
调用示例
- srcK对齐
本样例中输入srcTensor和输出dstTensor的shape大小为[320,64],输入inSumTensor、inMaxTensor的shape大小为[320,16],输出expMaxTensor的shape大小为[320,16],数据类型均为half,输入输出的数据排布格式为ND,srcTensor和dstTensor空间不复用,不使能基本块,isUpdate为true。完整算子样例请参考softmaxflashv2算子样例。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// dstLocal: 存放SoftMax计算结果的Tensor // expSumLocal:存放softmax计算过程中reducesum的结果 // maxLocal:存放softmax计算过程中reducemax的结果 // srcLocal:存放SoftMax计算的输入Tensor // expMaxLocal:存放inmax与reducemax差值的e的指数幂的结果 // inExpSumLocal:存放softmax计算所需要的sum值 // inMaxLocal:存放softmax计算所需要的max值 // sharedTmpBuffer: 存放SoftMax计算过程中临时缓存的Tensor // softmaxTiling:存放SoftMax计算所需Tiling信息,可通过SoftMaxFlashV2TilingFunc接口获取 AscendC::SoftMaxShapeInfo softmaxInfo( /* 非尾轴长度的乘积 */ 320, /* 尾轴长度,必须32Bytes对齐 */ 64, /* 原始非尾轴长度的乘积 */ 320, /* 原始尾轴长度 */ 64 ); // 通过sharedTmpBuffer入参传入临时空间,不输出ReduceMax,传入模板参数将shape常量化 AscendC::SoftmaxFlashV2<T, true, false, false, false, static_config>(dstLocal, expSumLocal, maxLocal, srcLocal, expMaxLocal, inExpSumLocal, inMaxLocal, sharedTmpBuffer, tiling, softmaxInfo); // 通过sharedTmpBuffer入参传入临时空间,不输出ReduceMax AscendC::SoftmaxFlashV2<T, true>(dstLocal, expSumLocal, maxLocal, srcLocal, expMaxLocal, inExpSumLocal, inMaxLocal, sharedTmpBuffer, tiling, softmaxInfo); // 接口框架申请临时空间,带sumTensor和maxTensor参数 AscendC::SoftmaxFlashV2<T, true>(dstLocal, expSumLocal, maxLocal, srcLocal, expMaxLocal, inExpSumLocal, inMaxLocal, tiling, softmaxInfo);
结果示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
输入数据(srcLocal): [[-10. -10. -10. ... -9.94 -9.94 -9.94 ] [ -9.94 -9.94 -9.94 ... -9.875 -9.875 -9.875] [ -9.875 -9.875 -9.875 ... -9.81 -9.81 -9.81 ] ... [ 9.81 9.81 9.81 ... 9.875 9.875 9.875] [ 9.875 9.875 9.875 ... 9.94 9.94 9.94 ] [ 9.94 9.94 9.94 ... 10. 10. 10. ]] 输出数据(expSumLocal): [[62.03 62.03 62.03 ... 62.03 62.03 62.03] [62.03 62.03 62.03 ... 62.03 62.03 62.03] [62.03 62.03 62.03 ... 62.03 62.03 62.03] ... [62.03 62.03 62.03 ... 62.03 62.03 62.03] [62.03 62.03 62.03 ... 62.03 62.03 62.03] [62.03 62.03 62.03 ... 62.03 62.03 62.03]] 输出数据(maxLocal): [[-9.94 -9.94 -9.94 ... -9.94 -9.94 -9.94 ] [-9.875 -9.875 -9.875 ... -9.875 -9.875 -9.875] [-9.81 -9.81 -9.81 ... -9.81 -9.81 -9.81 ] ... [ 9.875 9.875 9.875 ... 9.875 9.875 9.875] [ 9.94 9.94 9.94 ... 9.94 9.94 9.94 ] [10. 10. 10. ... 10. 10. 10. ]] 输出数据(dstLocal): [[0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ] [0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ] [0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ] ... [0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ] [0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ] [0.015144 0.015144 0.015144 ... 0.01611 0.01611 0.01611 ]]
- srcK非对齐
本示例中srcTensor和输出dstTensor的shape大小为[320,63],数据类型均为half,输入输出的数据排布格式为ND,展示非对齐padding补齐的搬入搬出操作和API调用方式。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#include "kernel_operator.h" // init阶段 height=320, width=63 padWidth = AlignUp(width * sizeof(T), 32) / sizeof(T); // copyin阶段 AscendC::DataCopyExtParams copyParams{static_cast<uint16_t>(height), static_cast<uint32_t>(width * sizeof(T)), 0, 0, 0}; AscendC::DataCopyPadExtParams<T> padParam = {true, 0, static_cast<uint8_t>(padWidth - width), 0}; AscendC::DataCopyPad(srcLocal, srcGlobal, copyParams, padParam); // compute阶段 // 由于发生padding,API调用时shape和原始shape发生了不一致 AscendC::SoftMaxShapeInfo srcShape = {height, padWidth, height, width}; AscendC::SoftmaxFlashV2<T, true>(dstLocal, expSumLocal, maxLocal, srcLocal, expMaxLocal, inExpSumLocal, inMaxLocal, tiling, srcShape); // copyout阶段 AscendC::DataCopyExtParams copyParams{static_cast<uint16_t>(height), static_cast<uint32_t>(width * sizeof(T)), 0, 0, 0}; AscendC::DataCopyPad(dstGlobal, dstLocal, copyParams);
结果示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
输入数据(srcLocal): [[-10. -10. -10. ... -9.94 -9.94 0. ] [ -9.94 -9.94 -9.94 ... -9.875 -9.875 0. ] [ -9.875 -9.875 -9.875 ... -9.81 -9.81 0. ] ... [ 9.81 9.81 9.81 ... 9.875 9.875 0. ] [ 9.875 9.875 9.875 ... 9.94 9.94 0. ] [ 9.94 9.94 9.94 ... 10. 10. 0. ]] 输出数据(expSumLocal): [[61.03 61.03 61.03 ... 61.03 61.03 61.03] [61.03 61.03 61.03 ... 61.03 61.03 61.03] [61.03 61.03 61.03 ... 61.03 61.03 61.03] ... [61.1 61.1 61.1 ... 61.1 61.1 61.1 ] [61.1 61.1 61.1 ... 61.1 61.1 61.1 ] [61.1 61.1 61.1 ... 61.1 61.1 61.1 ]] 输出数据(maxLocal): [[-9.94 -9.94 -9.94 ... -9.94 -9.94 -9.94 ] [-9.875 -9.875 -9.875 ... -9.875 -9.875 -9.875] [-9.81 -9.81 -9.81 ... -9.81 -9.81 -9.81 ] ... [ 9.875 9.875 9.875 ... 9.875 9.875 9.875] [ 9.94 9.94 9.94 ... 9.94 9.94 9.94 ] [10. 10. 10. ... 10. 10. 10. ]] 输出数据(dstLocal): [[0.015396 0.015396 0.015396 ... 0.01639 0.01639 0.01639 ] [0.015396 0.015396 0.015396 ... 0.01639 0.01639 0.01639 ] [0.015396 0.015396 0.015396 ... 0.01639 0.01639 0.01639 ] ... [0.01538 0.01538 0.01538 ... 0.01637 0.01637 0.01637 ] [0.01538 0.01538 0.01538 ... 0.01637 0.01637 0.01637 ] [0.01538 0.01538 0.01538 ... 0.01637 0.01637 0.01637 ]]

