RmsNorm
Applicability
| Product | Supported |
|---|---|
|  | √ |
|  | √ |
|  | x |
|  | √ |
|  | x |
|  | x |
Function
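Normalizes input data of shape [B, S, H] using RmsNorm, computed along the last (H) axis as follows:

$$\mathrm{RmsNorm}(x_i) = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{H}\sum_{j=1}^{H} x_j^{2} + \varepsilon}$$

where γ is the scaling coefficient and ε is a small constant added to prevent division by zero.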
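The computation can be checked on the host against a plain C++ reference: each row of length H is divided by the root mean square of that row (with ε added under the square root) and multiplied element-wise by γ. The sketch below is not Ascend C and is not the library's implementation; it assumes float data in a flat, row-major [B, S, H] buffer (ND format), and the function name `RmsNormReference` is hypothetical. It accumulates in double, which may differ slightly from the precision the kernel uses internally.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference for RmsNorm over a flat [B, S, H] buffer.
// For every one of the B*S rows: y[j] = x[j] / sqrt(mean(x^2) + epsilon) * gamma[j].
std::vector<float> RmsNormReference(const std::vector<float>& src,
                                    const std::vector<float>& gamma,
                                    std::size_t b, std::size_t s, std::size_t h,
                                    float epsilon) {
    assert(src.size() == b * s * h);
    assert(gamma.size() == h);
    std::vector<float> dst(src.size());
    for (std::size_t row = 0; row < b * s; ++row) {
        const float* x = src.data() + row * h;
        float* y = dst.data() + row * h;
        // Mean of squares along the H axis (accumulated in double).
        double sumSq = 0.0;
        for (std::size_t j = 0; j < h; ++j) {
            sumSq += static_cast<double>(x[j]) * x[j];
        }
        const double rms = std::sqrt(sumSq / static_cast<double>(h) + epsilon);
        // Normalize and apply the per-column scaling coefficient gamma.
        for (std::size_t j = 0; j < h; ++j) {
            y[j] = static_cast<float>(x[j] / rms * gamma[j]);
        }
    }
    return dst;
}
```

For example, a row {3, 4, 0, 0} with γ = 1 and ε = 0 has RMS = sqrt(25 / 4) = 2.5, so the output row is {1.2, 1.6, 0, 0}.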
Prototype
- Passing the temporary space through the sharedTmpBuffer input parameter:

      template <typename T, bool isBasicBlock = false>
      __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal,
          const LocalTensor<T>& gammaLocal, const LocalTensor<uint8_t>& sharedTmpBuffer,
          const T epsilon, const RmsNormTiling& tiling)
- Allocating the temporary space through the API framework:

      template <typename T, bool isBasicBlock = false>
      __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal,
          const LocalTensor<T>& gammaLocal, const T epsilon, const RmsNormTiling& tiling)
Because the internal implementation of this API involves complex computation, it needs extra temporary space to store intermediate variables. The temporary space can either be allocated by the API framework or passed in by the developer through the sharedTmpBuffer input parameter.
- If the API framework allocates the temporary space, you do not need to allocate it yourself, but you must reserve enough space for it.
- If the temporary space is passed through sharedTmpBuffer, that tensor is used as the temporary space and the API framework does not allocate any. You manage the sharedTmpBuffer space yourself and can reuse the buffer after the API call, which avoids repeated allocation and deallocation and improves flexibility and buffer utilization.
In the first case you must reserve the temporary space; in the second you must allocate the sharedTmpBuffer tensor. In both cases, obtain the required size (BufferSize) with the GetRmsNormMaxMinTmpSize API provided in RmsNorm Tiling, which returns the maximum and minimum temporary space sizes. The minimum size guarantees correct functionality; the maximum size improves performance.
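One way to act on the minimum/maximum pair is to take as much space as is available, capped at the maximum, and to fail if even the minimum does not fit. The helper below is a hypothetical sketch in plain C++, not part of the RmsNorm Tiling API: `ChooseTmpBufferSize` and its `freeBytes` argument are assumptions, and it assumes that any size between the minimum and maximum is acceptable and that a larger size never hurts performance.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper: given the [minSize, maxSize] range (in bytes) reported by
// GetRmsNormMaxMinTmpSize and the number of bytes still free in local memory,
// choose a temporary-buffer size. Returns 0 if even minSize does not fit.
uint32_t ChooseTmpBufferSize(uint32_t minSize, uint32_t maxSize, uint32_t freeBytes) {
    assert(minSize <= maxSize);
    if (freeBytes < minSize) {
        return 0;  // Not enough space for correct execution.
    }
    // More space can improve performance, so use what is free, up to maxSize.
    return std::min(maxSize, freeBytes);
}
```

A real allocation may additionally need to respect the address alignment rules referenced under Restrictions; this sketch does not round the result.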
Parameters
| Parameter | Description |
|---|---|
| T | Data type of the operand. The supported data types depend on the product: For the …; For the …; For the …. |
| isBasicBlock | If the shape information and tiling strategy of both srcLocal and dstLocal meet the basic block requirements, enable this parameter to improve performance. It is disabled by default. For basic blocks, the shapes of srcLocal and dstLocal must meet the following requirements: … |
| Parameter | Input/Output | Description |
|---|---|---|
| dstLocal | Output | Destination operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. The shape of dstLocal must be the same as that of the source operand srcLocal. |
| srcLocal | Input | Source operand. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Its shape is [B, S, H], and the length of the last axis (H) must be 32-byte aligned. |
| gammaLocal | Input | Scaling coefficient γ. The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT. Its length must equal the length of the last axis (H) of srcLocal and dstLocal; that is, its shape is [H]. |
| sharedTmpBuffer | Input | Temporary space, used only by the prototype that takes this parameter. The type is LocalTensor<uint8_t>, and the supported TPosition is VECIN, VECCALC, or VECOUT. It stores intermediate variables produced during the API's internal computation and is provided by the developer. For details about how to obtain the temporary space size (BufferSize), see RmsNorm Tiling. |
| epsilon | Input | Small constant added to prevent division by zero. Its data type must be the same as that of srcLocal and dstLocal. |
| tiling | Input | Tiling information required for the RmsNorm computation. For details about how to obtain it, see RmsNorm Tiling. |
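The shape rules in the parameter table can be mirrored as a host-side check before launching the kernel. The sketch below is a hypothetical plain C++ helper (`Shape3D` and `RmsNormShapesValid` are illustrative names, not library types); it checks only the three documented shape constraints (dstLocal matches srcLocal, gammaLocal has shape [H], H in bytes is a multiple of 32) and nothing about TPosition or basic-block requirements.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical [B, S, H] shape descriptor.
struct Shape3D {
    uint32_t b;
    uint32_t s;
    uint32_t h;
};

// Returns true if the shapes satisfy the documented constraints:
//  - dst has the same shape as src,
//  - gamma has length H,
//  - H * sizeof(T) is a multiple of 32 bytes.
bool RmsNormShapesValid(const Shape3D& src, const Shape3D& dst,
                        uint32_t gammaLen, std::size_t typeSize) {
    assert(typeSize > 0);
    const bool sameShape = src.b == dst.b && src.s == dst.s && src.h == dst.h;
    const bool gammaMatches = gammaLen == src.h;
    const bool hAligned = (static_cast<std::size_t>(src.h) * typeSize) % 32 == 0;
    return sameShape && gammaMatches && hAligned;
}
```

For a 2-byte type such as half, H = 16 passes the alignment check (32 bytes) while H = 10 (20 bytes) does not.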
Returns
None
Restrictions
- dstLocal and gammaLocal must not share (reuse) the same tensor space.
- Currently, only the ND format is supported.
- For details about the operand address alignment requirements, see General Address Alignment Restrictions.
- If the H axis in the original shape of srcLocal is not 32-byte aligned, pad the original input along the H axis to a 32-byte boundary. The API also writes results to the positions in dstLocal that correspond to the padded positions of srcLocal, overwriting any data there.
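The padding step can be sketched on the host as follows: round H up to the next multiple of 32 / sizeof(T) elements and copy each row into a wider buffer, filling the tail with a pad value. This is a hypothetical plain C++ sketch (`PaddedHLength` and `PadLastAxis` are illustrative names); it assumes zero padding and only shows the memory layout. Which H value the tiling uses for the mean is determined by RmsNorm Tiling and is not covered here, and the outputs at padded positions should be discarded.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kBlockBytes = 32;  // alignment unit in bytes

// Round h up to a whole number of 32-byte blocks, counted in elements.
std::size_t PaddedHLength(std::size_t h, std::size_t typeSize) {
    assert(typeSize > 0 && kBlockBytes % typeSize == 0);
    const std::size_t unit = kBlockBytes / typeSize;  // elements per block
    return (h + unit - 1) / unit * unit;
}

// Copy each row of a [rows, h] tensor into a [rows, paddedH] tensor,
// filling the padded tail of every row with padValue.
template <typename T>
std::vector<T> PadLastAxis(const std::vector<T>& src, std::size_t rows,
                           std::size_t h, T padValue = T{}) {
    assert(src.size() == rows * h);
    const std::size_t paddedH = PaddedHLength(h, sizeof(T));
    std::vector<T> dst(rows * paddedH, padValue);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t j = 0; j < h; ++j) {
            dst[r * paddedH + j] = src[r * h + j];
        }
    }
    return dst;
}
```

For float (4 bytes), one block holds 8 elements, so H = 10 is padded to 16.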
Example

    #include "kernel_operator.h"

    // Round an element count up to a whole number of 32-byte blocks.
    inline __aicore__ uint32_t AlignToBlock(const uint32_t inputValue, const uint32_t typeSize)
    {
        constexpr uint32_t ONE_BLK_SIZE = 32;
        uint32_t alignUnit = ONE_BLK_SIZE / typeSize;
        return (inputValue + alignUnit - 1) / alignUnit * alignUnit;
    }

    template <typename dataType, bool isBasicBlock = false>
    class KernelRmsNorm {
    public:
        __aicore__ inline KernelRmsNorm() {}

        // RmsNormCustomTiling is the operator's own tiling data structure (not shown);
        // it carries the RmsNormTiling object and the epsilon value.
        __aicore__ inline void Init(GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm,
            const RmsNormCustomTiling &customTiling)
        {
            tiling = customTiling.tiling;
            const uint32_t bLength = tiling.bLength;
            const uint32_t sLength = tiling.sLength;
            hLength = tiling.hLength;
            bshLength = bLength * sLength * hLength;
            constexpr uint32_t typeSize = sizeof(dataType);
            const uint32_t bsLength = AlignToBlock(bLength * sLength, typeSize);
            // Temporary buffer size used by this example. In production code, obtain
            // the size from GetRmsNormMaxMinTmpSize (see RmsNorm Tiling).
            const uint32_t tmpBufferSize = bshLength * 2 + bsLength;
            epsilon = customTiling.epsilon;

            // Bind the global memory buffers.
            inputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(inputGm), bshLength);
            gammaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(gammaGm), hLength);
            outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(outputGm), bshLength);

            // Allocate the local buffers (InitBuffer takes sizes in bytes).
            pipe.InitBuffer(inQueue, 1, bshLength * typeSize);
            pipe.InitBuffer(inQueueGamma, 1, hLength * typeSize);
            pipe.InitBuffer(outQueue, 1, bshLength * typeSize);
            pipe.InitBuffer(tmpQueue, 1, tmpBufferSize);
        }

        __aicore__ inline void Process()
        {
            CopyIn();
            Compute();
            CopyOut();
        }

    private:
        // Copy src [B, S, H] and gamma [H] from global memory to local memory.
        __aicore__ inline void CopyIn()
        {
            AscendC::LocalTensor<dataType> inputLocal = inQueue.AllocTensor<dataType>();
            AscendC::DataCopy(inputLocal, inputGlobal, bshLength);
            inQueue.EnQue(inputLocal);
            AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.AllocTensor<dataType>();
            AscendC::DataCopy(gammaLocal, gammaGlobal, hLength);
            inQueueGamma.EnQue(gammaLocal);
        }

        // Call RmsNorm, passing the temporary space through sharedTmpBuffer (stackBuffer).
        __aicore__ inline void Compute()
        {
            AscendC::LocalTensor<dataType> inputLocal = inQueue.DeQue<dataType>();
            AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.DeQue<dataType>();
            AscendC::LocalTensor<dataType> outputLocal = outQueue.AllocTensor<dataType>();
            AscendC::LocalTensor<uint8_t> stackBuffer = tmpQueue.AllocTensor<uint8_t>();
            AscendC::RmsNorm<dataType, isBasicBlock>(outputLocal, inputLocal, gammaLocal, stackBuffer, epsilon, tiling);
            inQueue.FreeTensor(inputLocal);
            inQueueGamma.FreeTensor(gammaLocal);
            tmpQueue.FreeTensor(stackBuffer);
            outQueue.EnQue(outputLocal);
        }

        // Copy the result from local memory back to global memory.
        __aicore__ inline void CopyOut()
        {
            AscendC::LocalTensor<dataType> outputLocal = outQueue.DeQue<dataType>();
            AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
            outQueue.FreeTensor(outputLocal);
        }

    private:
        AscendC::GlobalTensor<dataType> inputGlobal;
        AscendC::GlobalTensor<dataType> gammaGlobal;
        AscendC::GlobalTensor<dataType> outputGlobal;
        AscendC::TPipe pipe;
        AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
        AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueGamma;
        AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;
        AscendC::TQue<AscendC::TPosition::VECCALC, 1> tmpQueue;
        RmsNormTiling tiling;
        uint32_t hLength;
        dataType epsilon;
        uint32_t bshLength;
    };

    template <typename dataType, bool isBasicBlock = false>
    __aicore__ inline void kernel_rmsnorm_operator(GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, GM_ADDR tiling)
    {
        GET_TILING_DATA(customTilingData, tiling);
        KernelRmsNorm<dataType, isBasicBlock> op;
        op.Init(inputGm, gammaGm, outputGm, customTilingData);
        op.Process();
    }
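To see what the example's Init actually allocates for a given shape, its size arithmetic can be reproduced on the host. The sketch below is plain C++ (no `__aicore__`): `AlignToBlock` copies the example's helper, while `ExampleBufferSizes` and `ComputeExampleSizes` are hypothetical names. Note that `bshLength * 2 + bsLength` is the example's own choice of temporary size, not a documented formula; the documented way to size the buffer is GetRmsNormMaxMinTmpSize.

```cpp
#include <cassert>
#include <cstdint>

// Host-side copy of the example's AlignToBlock: round an element count up to
// a whole number of 32-byte blocks.
uint32_t AlignToBlock(uint32_t inputValue, uint32_t typeSize) {
    constexpr uint32_t ONE_BLK_SIZE = 32;
    assert(typeSize > 0 && typeSize <= ONE_BLK_SIZE);
    const uint32_t alignUnit = ONE_BLK_SIZE / typeSize;
    return (inputValue + alignUnit - 1) / alignUnit * alignUnit;
}

// Hypothetical container for the sizes computed in KernelRmsNorm::Init.
struct ExampleBufferSizes {
    uint32_t bshLength;      // number of elements in src/dst
    uint32_t bsLength;       // B*S rounded up to a 32-byte block
    uint32_t tmpBufferSize;  // value the example passes to InitBuffer for tmpQueue
};

// Reproduces the size computation in KernelRmsNorm::Init for given B, S, H.
ExampleBufferSizes ComputeExampleSizes(uint32_t b, uint32_t s, uint32_t h, uint32_t typeSize) {
    ExampleBufferSizes sizes{};
    sizes.bshLength = b * s * h;
    sizes.bsLength = AlignToBlock(b * s, typeSize);
    sizes.tmpBufferSize = sizes.bshLength * 2 + sizes.bsLength;
    return sizes;
}
```

For half data (2 bytes) with B = 2, S = 3, H = 16, this gives bshLength = 96, bsLength = AlignToBlock(6, 2) = 16, and tmpBufferSize = 208.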