SoftmaxFlashV3
Applicability
| Product | Supported |
|---|---|
| | √ |
| | √ |
| | x |
| | x |
| | x |
| | x |
Function
This API is an enhanced version of SoftmaxFlash and implements the Softmax PASA algorithm. For an input tensor of shape [m0, m1, ..., mt, n] (t ≥ 0), let m be the product of the non-last axis lengths m0, m1, ..., mt. The input is then treated as a tensor of shape [m, n]. The last axis of the input tensor x is split into splitMeanCnt blocks, and the i-th block is denoted x_cnti. In the formulas below, x, inmax, insum, and inmean are inputs, and M, S, E, and A are outputs.
- If update is false, the formulas are as follows.

- If update is true, the formulas are as follows.

Currently, this API supports only inputs in ND format. The internal reduction is performed along the last axis.
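The shape handling described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library implementation: the helper names are hypothetical, and the sketch assumes the last axis is split into splitMeanCnt contiguous blocks of equal length baseK (the reference code below notes width = baseK * splitMeanCnt).

```python
from math import prod


def as_2d_shape(shape):
    """Collapse [m0, m1, ..., mt, n] into (m, n), where m = m0 * m1 * ... * mt."""
    *leading, n = shape
    return prod(leading), n


def split_last_axis(row, split_mean_cnt):
    """Split one row into split_mean_cnt contiguous blocks of length baseK (assumed layout)."""
    if len(row) % split_mean_cnt != 0:
        raise ValueError("row length must equal baseK * splitMeanCnt")
    base_k = len(row) // split_mean_cnt
    return [row[i * base_k:(i + 1) * base_k] for i in range(split_mean_cnt)]


def block_means(row, split_mean_cnt):
    """One mean per block: a per-row analogue of rowMeanLocal under the contiguous-block assumption."""
    return [sum(block) / len(block) for block in split_last_axis(row, split_mean_cnt)]


print(as_2d_shape([2, 4, 1024]))        # (8, 1024)
print(block_means(list(range(16)), 4))  # [1.5, 5.5, 9.5, 13.5]
```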
```python
import numpy as np


# Reference pseudocode. BlockReduceSum and Add stand for Ascend C vector
# instructions; repeatSize, elementNumPerBlk, and BlkcntPerRepeat are
# hardware-dependent constants that are not defined here.
def softmax_flash_3(src, height, width, loopCnt, alpha, baseK,
                    inmax=None, insum=None, inmean=None, update=False):
    scalar = alpha / (1 - alpha)
    # (m, n) -> (m, 64)
    tmpbuffer0 = BlockReduceSum(repeatSize, repeatSize, elementNumPerBlk)
    remain = int(width / repeatSize - BlkcntPerRepeat)
    tmpbuffer0 = Add(tmpbuffer0, src, remain, repeatSize * elementNumPerBlk, width)
    # (m, 64) -> (m, 8)
    tmpbuffer0 = BlockReduceSum(1, elementNumPerBlk, elementNumPerBlk)
    # width = baseK * splitMeanCnt
    rowMeanLocal = tmpbuffer0 / baseK
    rowMeanGlobal = np.mean(src, axis=-1, keepdims=True)
    # One correction value per block, applied to each block of the last axis
    rowMeanGlobalTmp = (rowMeanGlobal - rowMeanLocal) * scalar
    src = src - rowMeanGlobalTmp
    if not update:
        x_mean = rowMeanGlobal
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        exp_max = None
        return dst, x_max, x_sum, x_mean, exp_max
    else:
        # Running mean over the loopCnt chunks processed so far
        x_mean = (rowMeanGlobal + inmean * (loopCnt - 1)) / loopCnt
        maxTmp = np.max(src, axis=-1, keepdims=True)
        shiftCurr = (rowMeanGlobal - x_mean) * scalar
        shiftPrev = (inmean - x_mean) * scalar
        x_max = shiftCurr + maxTmp
        maxTmp = shiftPrev + inmax
        x_max = np.max(np.concatenate((x_max, maxTmp), axis=-1), axis=-1, keepdims=True)
        maxTmp = x_max - shiftCurr
        x_sub = src - maxTmp
        dst = np.exp(x_sub)
        # Factor that rescales the previous sum to the new maximum
        exp_max = np.exp(inmax - x_max + shiftPrev)
        x_sum = np.sum(dst, axis=-1, keepdims=True)
        x_sum = exp_max * insum + x_sum
        return dst, x_max, x_sum, x_mean, exp_max
```
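With alpha = 0, scalar is 0, so rowMeanGlobalTmp, shiftCurr, and shiftPrev all vanish and the update branch of the reference code reduces to the familiar online-softmax merge. The sketch below illustrates only that special case for a single row in pure Python; it is not the library implementation. The function name and the chunked call pattern are hypothetical, and it assumes loopCnt is the 1-based index of the current chunk, as the running-mean formula implies.

```python
import math


# Hypothetical single-row sketch of softmax_flash_3 with alpha = 0
# (scalar == 0, so every mean-shift term is zero).
def flash_v3_row_alpha0(chunk, loop_cnt, in_max=None, in_sum=None, in_mean=None):
    chunk_mean = sum(chunk) / len(chunk)
    chunk_max = max(chunk)
    if in_max is None:
        # update == False: statistics of the first chunk only
        dst = [math.exp(v - chunk_max) for v in chunk]
        return dst, chunk_max, sum(dst), chunk_mean, None
    # update == True: merge with the statistics of the earlier chunks
    x_mean = (chunk_mean + in_mean * (loop_cnt - 1)) / loop_cnt
    x_max = max(chunk_max, in_max)
    dst = [math.exp(v - x_max) for v in chunk]
    exp_max = math.exp(in_max - x_max)  # rescales the previous sum
    x_sum = exp_max * in_sum + sum(dst)
    return dst, x_max, x_sum, x_mean, exp_max


row = [0.5, -1.0, 2.0, 0.25, 3.0, -0.5, 1.5, 0.0]
_, m1, s1, a1, _ = flash_v3_row_alpha0(row[:4], loop_cnt=1)
_, m2, s2, a2, e2 = flash_v3_row_alpha0(row[4:], loop_cnt=2,
                                         in_max=m1, in_sum=s1, in_mean=a1)
print(m2 == max(row))
print(math.isclose(s2, sum(math.exp(v - max(row)) for v in row)))
print(math.isclose(a2, sum(row) / len(row)))
```

Processing the row in two chunks gives the same maximum, exponential sum, and mean as processing it in one pass. The running mean matches the global mean only because both chunks have the same length.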
Prototype
- Allocate the temporary space through the API framework.
```cpp
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor,
    const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor,
    const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor,
    const LocalTensor<U>& inMaxTensor, const SoftMaxTiling& tiling, const SoftMaxParams& params)
```
- Pass the temporary space through the sharedTmpBuffer input parameter.
```cpp
template <typename T, typename U, bool isUpdate = false, bool isReuseSource = false, bool isBasicBlock = false, bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
__aicore__ inline void SoftmaxFlashV3(const LocalTensor<T>& dstTensor, const LocalTensor<U>& meanTensor,
    const LocalTensor<U>& expSumTensor, const LocalTensor<U>& maxTensor, const LocalTensor<T>& srcTensor,
    const LocalTensor<T>& expMaxTensor, const LocalTensor<U>& inMeanTensor, const LocalTensor<U>& inExpSumTensor,
    const LocalTensor<U>& inMaxTensor, const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling,
    const SoftMaxParams& params)
```
Because the internal implementation of this API involves complex computation, extra temporary space is required to store intermediate variables. The temporary space can either be allocated through the API framework or passed in by developers through the sharedTmpBuffer input parameter.
- When the API framework allocates the temporary space, developers do not need to allocate it themselves, but they must reserve enough space for it.
- When the temporary space is passed through sharedTmpBuffer, that tensor serves as the temporary space and the API framework does not allocate any. Developers manage the sharedTmpBuffer space themselves and can reuse the buffer after the API call, which avoids repeated allocation and deallocation and improves flexibility and buffer utilization.
If the API framework is used, developers must reserve the temporary space; if sharedTmpBuffer is used, developers must allocate space for that tensor. In either case, obtain the required minimum and maximum temporary space sizes (BufferSize) through the GetSoftMaxFlashV3MaxMinTmpSize API described in SoftmaxFlashV3 Tiling. The minimum size guarantees correct functionality, while the maximum size is used to improve performance.
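The trade-off between the two sizes can be written as a simple selection policy: never go below the minimum, and use as much of the available space as possible up to the maximum. The helper below is hypothetical and not part of the Ascend C API. min_size and max_size stand for the values reported by GetSoftMaxFlashV3MaxMinTmpSize, available stands for the free buffer space, and the sketch assumes that any size between the minimum and the maximum is valid.

```python
def choose_tmp_buffer_size(min_size, max_size, available):
    """Pick a temporary-buffer size in bytes (hypothetical helper)."""
    if min_size > max_size:
        raise ValueError("min_size must not exceed max_size")
    if available < min_size:
        # Below the minimum, correct results are not guaranteed.
        raise ValueError("available space is below the required minimum")
    # Use as much as is available, but never more than the maximum.
    return min(max_size, available)


# Illustrative numbers only
print(choose_tmp_buffer_size(1024, 4096, 8192))  # 4096
print(choose_tmp_buffer_size(1024, 4096, 2048))  # 2048
```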
Parameters
| Parameter | Description |
|---|---|
| T | Data type of the input srcTensor and the output dstTensor and expMaxTensor operands. For the … For the … |
| U | Data type of the inputs inMeanTensor, inExpSumTensor, and inMaxTensor and the outputs meanTensor, expSumTensor, and maxTensor. For the … For the … |
| isUpdate | Whether the computation is performed with update set to true. |
| isReuseSource | Reserved. Pass the default value false. |
| isBasicBlock | Reserved. Pass the default value false. |
| isDataFormatNZ | Reserved. Pass the default value false. |
| config | Reserved. Use the default value SOFTMAX_DEFAULT_CFG. |
| Parameter | Input/Output | Description |
|---|---|---|
| dstTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The shape of dstTensor is the same as that of the source operand srcTensor. |
| meanTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It stores the mean computed during the softmax computation. |
| expSumTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It stores the reduce-sum result of the softmax computation. |
| maxTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It stores the reduce-max result of the softmax computation. |
| srcTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The length of the last axis must be 32-byte aligned. |
| expMaxTensor | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It stores the exp_max result of the computation. |
| inMeanTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It holds the input mean value (inmean in the formulas) required for the softmax computation. |
| inExpSumTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It holds the input sum value (insum in the formulas) required for the softmax computation. |
| inMaxTensor | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. It holds the input max value (inmax in the formulas) required for the softmax computation. |
| sharedTmpBuffer | Input | Temporary space. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The data type is fixed at uint8_t. It stores intermediate variables generated during the internal computation and is provided by developers. For details about how to obtain the temporary space size (BufferSize), see SoftmaxFlashV3 Tiling. |
| tiling | Input | Tiling information required for SoftmaxFlashV3 computation. For details about how to obtain it, see SoftmaxFlashV3 Tiling. |
| params | Input | Shape information and computation parameters of srcTensor, of type SoftMaxParams. Note: This API does not support the non-aligned scenario, so srcM must equal oriSrcM and srcK must equal oriSrcK. |
Returns
None
Restrictions
- For details about the operand address alignment requirements, see General Address Alignment Restrictions.
- For the input srcTensor, the last axis length n must be greater than or equal to 512 and a multiple of 64, and the product m of the non-last axis lengths must be a multiple of 8.
- The following tensor pairs can share the same space: srcTensor and dstTensor; meanTensor and inMeanTensor; maxTensor and inMaxTensor; expSumTensor and inExpSumTensor.
- For meanTensor, expSumTensor, maxTensor, expMaxTensor, inMeanTensor, inExpSumTensor, and inMaxTensor, the length of the last axis is fixed at 32 bytes.
- The address of sharedTmpBuffer must not overlap that of the source or destination operand.
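The shape restrictions above can be checked before launching a kernel. The following is a hypothetical host-side sketch, not part of the Ascend C API. The divisibility check on splitMeanCnt is an assumption taken from the reference code comment width = baseK * splitMeanCnt, not a documented restriction.

```python
def flash_v3_shape_errors(m, n, split_mean_cnt):
    """Return the restrictions that an [m, n] input violates (hypothetical helper)."""
    errors = []
    if n < 512 or n % 64 != 0:
        errors.append("n must be >= 512 and a multiple of 64")
    if m % 8 != 0:
        errors.append("m must be a multiple of 8")
    # Assumption from the reference code comment: width = baseK * splitMeanCnt.
    if n % split_mean_cnt != 0:
        errors.append("n must be divisible by splitMeanCnt")
    return errors


def stat_elems_per_row(dtype_bytes):
    """Elements in the fixed 32-byte last axis of the mean/sum/max/expMax tensors."""
    return 32 // dtype_bytes


print(flash_v3_shape_errors(8, 1024, 8))             # []
print(stat_elems_per_row(4), stat_elems_per_row(2))  # 8 16
```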
Example
```cpp
#include "kernel_operator.h"

template <typename T, typename U>
class KernelSoftmaxFlashV3 {
public:
    __aicore__ inline KernelSoftmaxFlashV3() {}
    __aicore__ inline void Init(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm, __gm__ uint8_t *inSumGm,
                                __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, const SoftMaxTiling &tilingData)
    {
        srcGlobal.SetGlobalBuffer((__gm__ T *)srcGm);
        dstGlobal.SetGlobalBuffer((__gm__ T *)dstGm);
        maxGlobal.SetGlobalBuffer((__gm__ U *)inMaxGm);
        sumGlobal.SetGlobalBuffer((__gm__ U *)inSumGm);
        meanGlobal.SetGlobalBuffer((__gm__ U *)inMeanGm);
        pipe.InitBuffer(inQueueSrc, 1, height * width * sizeof(T));
        // Per-row statistics (max/sum/mean) occupy one 32-byte block per row.
        elementNumPerBlk1 = 32 / sizeof(U);
        pipe.InitBuffer(maxQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(sumQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        pipe.InitBuffer(meanQueue, 1, height * elementNumPerBlk1 * sizeof(U));
        // expMaxTensor has type T, so its 32-byte block holds 32 / sizeof(T) elements.
        elementNumPerBlk2 = 32 / sizeof(T);
        pipe.InitBuffer(expMaxQueue, 1, height * elementNumPerBlk2 * sizeof(T));
        pipe.InitBuffer(outQueueDst, 1, height * width * sizeof(T));
        tiling = tilingData;
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.AllocTensor<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.AllocTensor<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.AllocTensor<U>();
        AscendC::DataCopy(srcLocal, srcGlobal, height * width);
        AscendC::DataCopy(insumLocal, sumGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmaxLocal, maxGlobal, height * elementNumPerBlk1);
        AscendC::DataCopy(inmeanLocal, meanGlobal, height * elementNumPerBlk1);
        inQueueSrc.EnQue(srcLocal);
        sumQueue.EnQue(insumLocal);
        maxQueue.EnQue(inmaxLocal);
        meanQueue.EnQue(inmeanLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcLocal = inQueueSrc.DeQue<T>();
        AscendC::LocalTensor<U> insumLocal = sumQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmaxLocal = maxQueue.DeQue<U>();
        AscendC::LocalTensor<U> inmeanLocal = meanQueue.DeQue<U>();
        AscendC::LocalTensor<T> expMaxTensor = expMaxQueue.AllocTensor<T>();
        AscendC::LocalTensor<T> dstLocal = outQueueDst.AllocTensor<T>();
        // The shape is passed twice because this API requires srcM == oriSrcM and srcK == oriSrcK.
        AscendC::SoftMaxParams params = {height, width, height, width, loopCnt, splitMeanCnt, alpha};
        // isUpdate = true; the in* tensors are reused as outputs, which the API permits.
        AscendC::SoftmaxFlashV3<T, U, true>(dstLocal, inmeanLocal, insumLocal, inmaxLocal, srcLocal, expMaxTensor,
                                            inmeanLocal, insumLocal, inmaxLocal, tiling, params);
        outQueueDst.EnQue<T>(dstLocal);
        expMaxQueue.FreeTensor(expMaxTensor);
        maxQueue.FreeTensor(inmaxLocal);
        sumQueue.FreeTensor(insumLocal);
        meanQueue.FreeTensor(inmeanLocal);
        inQueueSrc.FreeTensor(srcLocal);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> dstLocal = outQueueDst.DeQue<T>();
        AscendC::DataCopy(dstGlobal, dstLocal, height * width);
        outQueueDst.FreeTensor(dstLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueSrc;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> meanQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> maxQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> sumQueue;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> expMaxQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueDst;
    AscendC::GlobalTensor<T> srcGlobal, dstGlobal;
    AscendC::GlobalTensor<U> meanGlobal, maxGlobal, sumGlobal;
    uint32_t elementNumPerBlk1 = 0;
    uint32_t elementNumPerBlk2 = 0;
    uint32_t width = 1024;
    uint32_t height = 8;
    uint32_t loopCnt = 2;
    uint32_t splitMeanCnt = 8;
    float alpha = 0.9375;
    SoftMaxTiling tiling;
};

extern "C" __global__ __aicore__ void softmax_flashv3_kernel(__gm__ uint8_t *srcGm, __gm__ uint8_t *inMaxGm,
    __gm__ uint8_t *inSumGm, __gm__ uint8_t *inMeanGm, __gm__ uint8_t *dstGm, __gm__ uint8_t *tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelSoftmaxFlashV3<half, float> op;
    op.Init(srcGm, inMaxGm, inSumGm, inMeanGm, dstGm, tilingData.softmaxTilingData);
    op.Process();
}
```
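The Init method sizes each queue from height and width. The Python sketch below recomputes those byte counts for the example's values (height = 8, width = 1024, T = half, U = float, so sizeof(T) = 2 and sizeof(U) = 4). The helper is hypothetical and only mirrors the pipe.InitBuffer arguments. Because the example does not pass sharedTmpBuffer, the temporary space that the API framework allocates must be reserved on top of this total.

```python
def example_buffer_bytes(height, width, sizeof_t, sizeof_u):
    """Mirror the byte counts passed to pipe.InitBuffer in the example (hypothetical helper)."""
    src = height * width * sizeof_t                    # inQueueSrc
    dst = height * width * sizeof_t                    # outQueueDst
    stats = 3 * height * (32 // sizeof_u) * sizeof_u   # maxQueue, sumQueue, meanQueue
    exp_max = height * (32 // sizeof_t) * sizeof_t     # expMaxQueue
    return {"src": src, "dst": dst, "stats": stats, "expMax": exp_max,
            "total": src + dst + stats + exp_max}


# Example values: height = 8, width = 1024, T = half (2 bytes), U = float (4 bytes)
print(example_buffer_bytes(8, 1024, 2, 4)["total"])  # 33792
```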