SoftmaxFlashV3
Function Usage
An enhanced version of SoftmaxFlash that implements the Softmax PASA algorithm. For an input tensor [m0, m1, ..., mt, n] (t ≥ 0), let m be the product of the non-last-axis lengths m0, m1, ..., mt; the input is then treated as a tensor of shape [m, n], and this API performs the computation on it row by row. Different values of update correspond to different formulas, where x, inmax, insum, and inmean are inputs, and M, S, and E are outputs.
- If update is false, the formulas are as follows (rowmax, rowsum, and rowmean reduce along the last axis):

  E = rowmean(x)
  M = rowmax(x)
  dst = exp(x - M)
  S = rowsum(dst)

- If update is true, the formulas are as follows, where s = alpha / (1 - alpha) and t is loopCnt:

  E = (rowmean(x) + inmean × (t - 1)) / t
  shift_cur = (rowmean(x) - E) × s
  shift_prev = (inmean - E) × s
  M = max(rowmax(x) + shift_cur, inmax + shift_prev)
  dst = exp(x - (M - shift_cur))
  expmax = exp(inmax + shift_prev - M)
  S = expmax × insum + rowsum(dst)
Currently, this API supports input in ND format only. The internal reduction is performed along the last axis.
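The shape flattening described above can be illustrated in plain NumPy (this is not the device API, only a sketch of the [m0, ..., mt, n] → [m, n] view and the row-wise, last-axis computation):

```python
import numpy as np

x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)  # [m0, m1, n] with t = 1
m = x.shape[0] * x.shape[1]          # product of the non-last-axis lengths
x2d = x.reshape(m, x.shape[-1])      # treated as [m, n]
# row-wise (last-axis) softmax on the flattened view
e = np.exp(x2d - x2d.max(axis=-1, keepdims=True))
p = e / e.sum(axis=-1, keepdims=True)
assert p.shape == (6, 4)
assert np.allclose(p.sum(axis=-1), 1.0)  # each row normalizes to 1
```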
```python
import numpy as np

# Reference implementation. BlockReduceSum and Add stand for hardware
# intrinsics and are shown as pseudocode; the rest is plain NumPy.
def softmax_flash_3(src, height, width, loopCnt, alpha, baseK,
                    inmax=None, insum=None, inmean=None, update=False):
    scalar = alpha / (1 - alpha)
    # (m, n) -> (m, 64)
    tmpbuffer0 = BlockReduceSum(repeatSize, repeatSize, elementNumPerBlk)
    remain = int(width / repeatSize - BlkcntPerRepeat)
    tmpbuffer0 = Add(tmpbuffer0, src, remain, repeatSize * elementNumPerBlk, width)
    # (m, 64) -> (m, 8)
    tmpbuffer0 = BlockReduceSum(1, elementNumPerBlk, elementNumPerBlk)
    # width = baseK * splitMeanCnt
    rowMeanLocal = tmpbuffer0 / baseK
    rowMeanGlobal = np.mean(src, axis=-1, keepdims=True)
    rowMeanGlobalTmp = (rowMeanGlobal - rowMeanLocal) * scalar
    src = src - rowMeanGlobalTmp
    if not update:
        x_mean = rowMeanGlobal
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar  # zero on the first pass
        x_max = shiftCurr + maxTmp
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, x_mean, exp_max
    else:
        x_mean = (rowMeanGlobal + inmean * (loopCnt - 1)) / loopCnt
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        shiftPrev = (inmean - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = shiftPrev + inmax
        x_max = np.max(np.concatenate((x_max, maxTmp), axis=-1), axis=-1, keepdims=True)
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        exp_max = np.exp(inmax - x_max + shiftPrev)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        x_sum = exp_max * insum + x_sum
        return dst, x_max, x_sum, x_mean, exp_max
```
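Stripped of the PASA mean-shift terms (shiftCurr/shiftPrev) and the hardware block reductions, the update branch above is the standard online-softmax recurrence: keep a running row max and rescale the running sum of exponentials by exp_max whenever the max grows. A minimal runnable NumPy sketch (illustrative only, not the device API):

```python
import numpy as np

def online_softmax_step(chunk, in_max=None, in_sum=None):
    # Fold a new chunk of columns into the running row-wise max and the
    # rescaled running sum of exponentials.
    cur_max = np.max(chunk, axis=-1, keepdims=True)
    if in_max is None:                      # first chunk (update == false)
        new_max = cur_max
        new_sum = np.sum(np.exp(chunk - new_max), axis=-1, keepdims=True)
        return new_max, new_sum, None
    new_max = np.maximum(in_max, cur_max)
    exp_max = np.exp(in_max - new_max)      # rescale factor for the old sum
    new_sum = exp_max * in_sum + np.sum(np.exp(chunk - new_max), axis=-1, keepdims=True)
    return new_max, new_sum, exp_max

x = np.random.default_rng(0).normal(size=(4, 16)).astype(np.float32)
m, s, _ = online_softmax_step(x[:, :8])         # first loop iteration
m, s, _ = online_softmax_step(x[:, 8:], m, s)   # second iteration, update == true
full_sum = np.sum(np.exp(x - np.max(x, axis=-1, keepdims=True)), axis=-1, keepdims=True)
assert np.allclose(s, full_sum)  # online result matches the one-shot softmax sum
```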
Prototype
- Allocate the temporary space through the API framework.
```cpp
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false,
          bool isBasicBlock = false, bool isDataFormatNZ = false,
          const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor,
    const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor,
    const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor,
    const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor,
    const SoftMaxTiling& tiling, const SoftMaxParams& params)
```
- Pass the temporary space through the sharedTmpBuffer input parameter.
```cpp
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false,
          bool isBasicBlock = false, bool isDataFormatNZ = false,
          const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor,
    const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor,
    const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor,
    const LocalTensor<U>& inExpSumTensor, const LocalTensor<U>& inMaxTensor,
    const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, const SoftMaxParams& params)
```
Due to the complex computation involved in the internal implementation of this API, additional temporary space is required to store intermediate variables generated during computation. The temporary space can be allocated through the API framework or passed by developers through the sharedTmpBuffer input parameter.
- When the API framework is used for temporary space allocation, developers do not need to allocate the space, but must reserve the required size for the space.
- When the temporary space is passed through the sharedTmpBuffer input parameter, that tensor serves as the temporary space, and the API framework does not need to allocate it. This lets developers manage the sharedTmpBuffer space themselves and reuse the buffer after the API call, so that the buffer is not repeatedly allocated and deallocated, improving flexibility and buffer utilization.
If the API framework is used, developers must reserve the temporary space; if sharedTmpBuffer is used, developers must allocate the tensor's space. To obtain the temporary space size (BufferSize), query the required minimum and maximum sizes with the GetSoftMaxFlashV3MaxMinTmpSize API described in SoftmaxFlashV3 Tiling. The minimum size guarantees correct functionality; the maximum size improves performance.
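The min/max trade-off above can be expressed as a simple host-side selection. A sketch with a hypothetical helper name (the actual sizes come from GetSoftMaxFlashV3MaxMinTmpSize):

```python
def select_tmp_buffer_size(min_size: int, max_size: int, available: int) -> int:
    # Use as much of the available scratch space as helps performance,
    # but never less than the functional minimum.
    if available < min_size:
        raise ValueError("insufficient temporary space for SoftmaxFlashV3")
    return min(available, max_size)

assert select_tmp_buffer_size(1024, 8192, 4096) == 4096   # capped by what is free
assert select_tmp_buffer_size(1024, 8192, 16384) == 8192  # capped by the max useful size
```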
Parameters
| Parameter | Description |
|---|---|
| T | Data types of the input srcTensor and expMaxTensor and the output dstTensor operands. |
| U | Data types of the input inMeanTensor, inExpSumTensor, and inMaxTensor and the output meanTensor, expSumTensor, and maxTensor operands. |
| isUpdate | Whether to set update to true in the computation. |
| isReuseSource | Reserved for future use. The default value false must be used. |
| isBasicBlock | Reserved for future use. The default value false must be used. |
| isDataFormatNZ | Reserved for future use. The default value false must be used. |
| config | Reserved for future use. The default value SOFTMAX_DEFAULT_CFG must be used. |
| Parameter | Input/Output | Description |
|---|---|---|
| dstTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The shape of dstTensor is the same as that of the source operand srcTensor. |
| meanTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It is used to store the mean result during softmax computation. |
| expSumTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It is used to store the reduce-sum result during softmax computation. |
| maxTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It is used to store the reduce-max result during softmax computation. |
| srcTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The length of the last axis must be 32-byte aligned. |
| expMaxTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. |
| inMeanTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Mean value required for softmax computation. |
| inExpSumTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Sum value required for softmax computation. |
| inMaxTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Max value required for softmax computation. |
| sharedTmpBuffer | Input | Temporary space. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The data type of this operand is fixed at uint8_t. It is used to store intermediate variables during complex internal API computation and is provided by developers. For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlashV3 Tiling. |
| tiling | Input | Tiling information required for SoftmaxFlashV3 computation. For details about how to obtain it, see SoftmaxFlashV3 Tiling. |
| params | Input | Shape information and computation parameters of srcTensor, of the SoftMaxParams type. Note: this API does not support the non-aligned scenario; therefore, srcM is equal to oriSrcM and srcK is equal to oriSrcK. |
Returns
None
Availability
Precautions
- For details about the alignment requirements of the operand address offset, see General Restrictions.
- For the input srcTensor, the last-axis length n must be greater than or equal to 512 and a multiple of 64, and the product m of the non-last-axis lengths must be a multiple of 8.
- The tensor spaces of srcTensor and dstTensor, of meanTensor and inMeanTensor, of maxTensor and inMaxTensor, and of expSumTensor and inExpSumTensor can each be reused.
- For meanTensor, expSumTensor, maxTensor, expMaxTensor, inMeanTensor, inExpSumTensor, and inMaxTensor, the last-axis length must be 32 bytes.
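The shape constraints above can be checked on the host before launching the kernel. A minimal sketch (the helper name is hypothetical, not part of the AscendC API):

```python
def is_valid_softmax_flash_v3_shape(m: int, n: int) -> bool:
    # n (last axis): at least 512 and a multiple of 64; m: a multiple of 8
    return n >= 512 and n % 64 == 0 and m % 8 == 0

assert is_valid_softmax_flash_v3_shape(8, 1024)       # matches the example kernel below
assert not is_valid_softmax_flash_v3_shape(8, 480)    # last axis too short
assert not is_valid_softmax_flash_v3_shape(6, 1024)   # m not a multiple of 8
```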
Example
```cpp
#include "kernel_operator.h"

template <typename T, typename U> class KernelSoftmaxFlashV3 {
public:
    __aicore__ inline KernelSoftmaxFlashV3() {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
        __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ U *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ U *)inSumGm);
        meanGlobal.SetGlobalBuffer((__gm__ U *)inMeanGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        elementNumPerBlk1 = 32 / sizeof(U);
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(meanQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        elementNumPerBlk2 = 32 / sizeof(T);
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk2 * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.AllocTensor<U>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmeanLocal, meanGlobal, height * elementNumPerBlk1);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
        meanQueue.EnQue(inmeanLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.DeQue<U>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        AscendC::SoftMaxParams params = {height, width, height, width, loopCnt, splitMeanCnt, alpha};
        // isUpdate is true: the in* tensors are reused for both the inputs and the outputs.
        AscendC::SoftmaxFlashV3<T, U, true>(dstLocal, inmeanLocal, insumLocal, inmaxLocal, srcLocal,
            expMaxTensor, inmeanLocal, insumLocal, inmaxLocal, tiling, params);
        outQueueDst.EnQue<T>(dstLocal);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        meanQueue.FreeTensor(inmeanLocal);
        expMaxQueue.FreeTensor(expMaxTensor); // release the expMax buffer
        inQueueSrc.FreeTensor(srcLocal);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> meanQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<U> meanGlobal, maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk1 = 0;
    uint32_t elementNumPerBlk2 = 0;
    uint32_t width = 1024;
    uint32_t height = 8;
    uint32_t loopCnt = 2;
    uint32_t splitMeanCnt = 8;
    float alpha = 0.9375;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flashv3_kernel(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm,
    __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV3<half, float> op;
    op.Init(srcGm, inMaxGm, inSumGm, inMeanGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}
```