SoftmaxFlash增强版本,对应FlashAttention-2算法。当输入shape为ND格式时,内部的reduce过程按last轴进行;当输入shape为NZ格式时,内部的reduce过程按照last轴和first轴进行,reduce过程可以参考SoftMax中的图示说明。
为方便理解,通过python脚本实现的方式,表达其计算公式如下,其中src、inmax、 insum、update为输入,dst、x_sum、x_max、exp_max为输出。
def softmax_flash_2(src, inmax=None, insum=None, update=None): if update == None: x_max = np.max(src, axis=-1, keepdims=True) x_sub = src - x_max dst = np.exp(x_sub) x_sum = np.sum(dst, axis=-1, keepdims=True) exp_max = None return dst, x_max, x_sum, exp_max else: x_max = np.max(np.concatenate((inmax, src), axis=-1), axis=-1, keepdims=True) x_exp = np.exp(src - x_max) exp_max = np.exp(inmax - x_max) x_sum = np.sum(x_exp, axis=-1, keepdims=True) x_sum = exp_max * insum + x_sum return x_exp, x_max, x_sum, exp_max
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,
const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,
const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
template <typename T, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false,
bool isDataFormatNZ = false>
__aicore__ inline void SoftmaxFlashV2(const LocalTensor<T>& dstTensor, const LocalTensor<T>& expSumTensor,
const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, const LocalTensor<T>& expMaxTensor,
const LocalTensor<T>& inExpSumTensor, const LocalTensor<T>& inMaxTensor,
const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling,
const SoftMaxShapeInfo& softmaxShapeInfo = {})
由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过SoftmaxFlashV2 Tiling接口中提供的GetSoftMaxFlashV2MinTmpSize/GetSoftMaxFlashV2MaxTmpSize接口获取所需最小和最大临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。
参数名 |
输入/输出 |
描述 |
---|---|---|
dstTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float dstTensor的shape和源操作数srcTensor一致。 |
sumTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float 用于保存softmax计算过程中reducesum的结果。
|
maxTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float 用于保存softmax计算过程中reducemax的结果。
|
srcTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float last轴长度需要32Byte对齐。 |
expMaxTensor |
输出 |
目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float
|
inSumTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float softmax计算所需要的sum值。
|
inMaxTensor |
输入 |
源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float softmax计算所需要的max值。
|
tiling |
输入 |
softmaxflashv2接口计算所需tiling信息,Tiling信息的获取请参考SoftmaxFlashV2 Tiling接口。 |
softmaxShapeInfo |
输入 |
srcTensor的shape信息。SoftMaxShapeInfo类型,具体定义如下: struct SoftMaxShapeInfo { uint32_t srcM; // 非尾轴长度的乘积 uint32_t srcK; // 尾轴长度,必须32Byte对齐 uint32_t oriSrcM; // 原始非尾轴长度的乘积 uint32_t oriSrcK; // 原始尾轴长度 }; 需要注意,当输入输出的数据格式为NZ格式时,尾轴长度为reduce轴长度即图2中的W0*W1,非尾轴为H0*H1。 |
isUpdate |
输入 |
是否使能update部分中的计算。 |
isReuseSource |
输入 |
dstTensor是否复用srcTensor的空间。 |
isBasicBlock |
输入 |
srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下,可以使能该参数用于提升性能,默认不使能。基本块要求如下:
|
isDataFormatNZ |
输入 |
当前输入输出的数据格式是否为NZ格式,默认数据格式为ND,即默认取值为false。 |
无
Atlas A2训练系列产品
Atlas推理系列产品AI Core
#include "kernel_operator.h namespace AscendC { template <typename T> class KernelSoftmaxFlashV2 { public: __aicore__ inline KernelSoftmaxFlashV2() {} __aicore__ inline void Init(__gm__ uint8_t* srcGm, __gm__ uint8_t* inMaxGm, __gm__ uint8_t* inSumGm, __gm__ uint8_t* dstGm, const SoftMaxTiling& tilingData) { elementNumPerBlk = 32 / sizeof(T); srcGlobal.SetGlobalBuffer((__gm__ T*)srcGm); dstGlobal.SetGlobalBuffer((__gm__ T*)dstGm); maxGlobal.SetGlobalBuffer((__gm__ T*)inMaxGm); sumGlobal.SetGlobalBuffer((__gm__ T*)inSumGm); pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T)); pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk * sizeof(T)); pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk * sizeof(T)); pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk * sizeof(T)); pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T)); tiling = tilingData; } __aicore__ inline void Process() { CopyIn(); Compute(); CopyOut(); } private: __aicore__ inline void CopyIn() { LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>(); LocalTensor<T> insumLocal = sumQueue.AllocTensor<T>(); LocalTensor<T> inmaxLocal = maxQueue.AllocTensor<T>(); DataCopy(srcLocal, srcGlobal, height * width); DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk); DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk); inQueueSrc.EnQue(srcLocal); sumQueue.EnQue(insumLocal); maxQueue.EnQue(inmaxLocal); } __aicore__ inline void Compute() { LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>(); LocalTensor<T> insumLocal = sumQueue.DeQue<T>(); LocalTensor<T> inmaxLocal = maxQueue.DeQue<T>(); LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>(); LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>(); SoftMaxShapeInfo srcShape = { height, width, height, width }; SoftmaxFlashV2<T, true>(dstLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor, insumLocal, inmaxLocal, tiling, srcShape); outQueueDst.EnQue<T>(dstLocal); maxQueue.FreeTensor(inmaxLocal); sumQueue.FreeTensor(insumLocal); inQueueSrc.FreeTensor(srcLocal); } __aicore__ inline void CopyOut() { LocalTensor<T> dstLocal = outQueueDst.DeQue<T>(); DataCopy(dstGlobal, dstLocal, height * width); outQueueDst.FreeTensor(dstLocal); } private: TPipe pipe; TQue<QuePosition::VECIN, 1> inQueueSrc; TQue<QuePosition::VECIN, 1> maxQueue; TQue<QuePosition::VECIN, 1> sumQueue; TQue<QuePosition::VECIN, 1> expMaxQueue; TQue<QuePosition::VECOUT, 1> outQueueDst; GlobalTensor<T> srcGlobal, dstGlobal; GlobalTensor<T> maxGlobal, sumGlobal; uint32_t elementNumPerBlk = 0; uint32_t width = 64; uint32_t height = 320; SoftMaxTiling tiling; }; } // namespace AscendC extern "C" __global__ __aicore__ void softmax_flashv2_generic_kernel_half(__gm__ uint8_t* srcGm, __gm__ uint8_t* inMaxGm, __gm__ uint8_t* inSumGm, __gm__ uint8_t* dstGm, __gm__ uint8_t* tiling) { GET_TILING_DATA(tilingData, tiling); AscendC::KernelSoftmaxFlashV2<half> op; op.Init(srcGm, inMaxGm, inSumGm, dstGm, tilingData.softmaxTilingData); op.Process(); }