RmsNorm

Function Usage

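Normalizes input data of shape [B, S, H] using RmsNorm. Each row of length H (the last axis) is normalized independently, using the following formula:

    $$\mathrm{RmsNorm}(x_i) = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{H}\sum_{j=1}^{H} x_j^{2} + \epsilon}$$

γ is the scaling coefficient, and ε is a small constant that prevents division by zero.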
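The formula can be cross-checked with a plain C++ host-side reference, for example when validating kernel outputs. This is a minimal sketch, not part of the Ascend C API. The function name RmsNormReference is hypothetical. The sketch computes in float and assumes the [B, S, H] input is stored as B*S contiguous rows of length H.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference of the RmsNorm formula (not the Ascend C API).
// src is a row-major [rows, h] buffer where rows = B * S; gamma has length h.
std::vector<float> RmsNormReference(const std::vector<float>& src,
                                    const std::vector<float>& gamma,
                                    std::size_t h, float epsilon)
{
    assert(h > 0 && src.size() % h == 0 && gamma.size() == h);
    std::vector<float> dst(src.size());
    const std::size_t rows = src.size() / h;
    for (std::size_t r = 0; r < rows; ++r) {
        const float* x = &src[r * h];
        // Sum of squares over the last (H) axis of this row.
        float sumSq = 0.0f;
        for (std::size_t j = 0; j < h; ++j) {
            sumSq += x[j] * x[j];
        }
        // RMS(x) = sqrt(mean of squares + epsilon).
        const float rms = std::sqrt(sumSq / static_cast<float>(h) + epsilon);
        // Normalize each element and apply the per-column scaling coefficient gamma.
        for (std::size_t j = 0; j < h; ++j) {
            dst[r * h + j] = x[j] / rms * gamma[j];
        }
    }
    return dst;
}
```

For an all-zero row, ε keeps the denominator non-zero, so the output stays zero instead of becoming 0/0 (NaN).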

Prototype

  • Pass the temporary space through the sharedTmpBuffer input parameter.
    template <typename T, bool isBasicBlock = false>
    __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<T>& gammaLocal, const LocalTensor<uint8_t>& sharedTmpBuffer, const T epsilon, const RmsNormTiling& tiling)
    
  • Allocate the temporary space through the API framework.
    template <typename T, bool isBasicBlock = false>
    __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<T>& gammaLocal, const T epsilon, const RmsNormTiling& tiling)
    

Because the internal computation of this API is complex, additional temporary space is required to store intermediate results. The temporary space can be allocated by the API framework, or the developer can pass it in through the sharedTmpBuffer parameter.

  • When the API framework allocates the temporary space, developers do not need to allocate it themselves but must reserve enough space for it.
  • When the temporary space is passed through sharedTmpBuffer, that tensor serves as the temporary space and the API framework allocates none. Developers manage the sharedTmpBuffer space themselves and can reuse the buffer after the API call. This avoids repeated allocation and deallocation and improves flexibility and buffer utilization.

If the API framework is used, developers must reserve the temporary space. If sharedTmpBuffer is used, developers must allocate space for that tensor. In both cases, obtain the temporary space size (BufferSize) with the GetRmsNormMaxMinTmpSize API described in RmsNorm Tiling, which returns the maximum and minimum required sizes. The minimum size guarantees correct results, and the maximum size gives the best performance.
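The min/max rule above can be applied when budgeting the buffer. The following is a minimal sketch under stated assumptions. ChooseTmpBufferSize is a hypothetical helper, not an Ascend C API. minSize and maxSize are assumed to be the values already returned by GetRmsNormMaxMinTmpSize. The sketch also assumes that any size between the two bounds is accepted.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not an Ascend C API): pick a temporary buffer size in bytes.
// minSize and maxSize are assumed to come from GetRmsNormMaxMinTmpSize; availableBytes
// is the space the caller can spare. Returns 0 when even the minimum size does not fit.
uint32_t ChooseTmpBufferSize(uint32_t minSize, uint32_t maxSize, uint32_t availableBytes)
{
    assert(minSize <= maxSize);
    if (availableBytes < minSize) {
        return 0;  // correct results cannot be guaranteed
    }
    // Assumption: any size in [minSize, maxSize] is accepted; larger sizes may run faster.
    return availableBytes < maxSize ? availableBytes : maxSize;
}
```

Capping at the maximum avoids wasting space that brings no further performance gain.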

Parameters

Table 1 Parameters in the template

Parameter

Description

T

Data type of the operand.

isBasicBlock

If the shapes and tiling policy of srcLocal and dstLocal meet the basic block requirements, enable this parameter to improve performance. It is disabled by default. For basic blocks, the shapes of srcLocal and dstLocal must meet the following requirements:

  • The length of the last axis (H) is a multiple of 64 and less than 2048.
  • The product of the non-last axes (B*S) is a multiple of 8.
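The shape rules for basic blocks can be checked on the host before choosing isBasicBlock = true. This is a minimal sketch. MeetsBasicBlockShape is a hypothetical name. The function covers only the two shape rules listed above, not the tiling-policy condition.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: checks the basic-block shape rules for a [B, S, H] tensor.
// It does not check the tiling policy, which must also qualify before isBasicBlock is enabled.
bool MeetsBasicBlockShape(uint32_t b, uint32_t s, uint32_t h)
{
    const bool lastAxisOk = h > 0 && h % 64 == 0 && h < 2048;  // H: multiple of 64, below 2048
    const bool otherAxesOk = (b * s) % 8 == 0;                 // B*S: multiple of 8
    return lastAxisOk && otherAxesOk;
}
```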

Table 2 API parameters

Parameter

Input/Output

Description

dstLocal

Output

Destination operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape of dstLocal must be the same as that of the source operand srcLocal.

srcLocal

Input

Source operand.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

The shape is [B, S, H], and the last axis (H) must be 32-byte aligned; that is, H * sizeof(T) must be a multiple of 32 bytes.
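The alignment rule reduces to a byte-count check on one row. This is a minimal sketch, and IsLastAxisBlockAligned is a hypothetical name. typeSize is the element size in bytes, for example 2 for a 16-bit type or 4 for float.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: true when one H-axis row (h elements of typeSize bytes)
// fills a whole number of 32-byte blocks.
bool IsLastAxisBlockAligned(uint32_t h, std::size_t typeSize)
{
    constexpr std::size_t kBlockBytes = 32;
    return (static_cast<std::size_t>(h) * typeSize) % kBlockBytes == 0;
}
```

Zero-padding H to reach alignment is not equivalent to the unpadded computation. The 1/H mean in the formula would then be taken over the padded length.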

gammaLocal

Input

Scaling coefficient.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Its shape must be [H]; that is, its length must equal the length of the last axis (H) of srcLocal and dstLocal.
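The shape constraints on dstLocal, srcLocal, and gammaLocal can be validated together on the host. This is a minimal sketch. ShapesAreCompatible is a hypothetical name, and shapes are represented here as plain vectors of axis lengths.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: dst must have the same [B, S, H] shape as src, and gamma must be [H].
bool ShapesAreCompatible(const std::vector<uint32_t>& srcShape,
                         const std::vector<uint32_t>& dstShape,
                         const std::vector<uint32_t>& gammaShape)
{
    if (srcShape.size() != 3 || dstShape != srcShape) {
        return false;
    }
    return gammaShape.size() == 1 && gammaShape[0] == srcShape[2];
}
```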

sharedTmpBuffer

Input

Temporary space.

The type is LocalTensor, and the supported TPosition is VECIN, VECCALC, or VECOUT.

Stores the intermediate variables generated during the API's internal computation. The buffer is provided by the developer.

For details about how to obtain the temporary space size (BufferSize), see RmsNorm Tiling.

epsilon

Input

A small constant added to prevent division by zero. Its data type must be the same as that of srcLocal and dstLocal.

tiling

Input

Tiling information required for RmsNorm computation. For details about how to obtain the tiling information, see RmsNorm Tiling.

Returns

None

Availability

Precautions

  • srcLocal and dstLocal can share (reuse) the same tensor space.
  • Currently, only the ND format is supported.
  • For details about the alignment requirements of the operand address offset, see General Restrictions.

Example

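#include "kernel_operator.h"

// Rounds inputValue (an element count) up to a whole number of 32-byte blocks.
inline __aicore__ uint32_t AlignToBlock(const uint32_t inputValue, const uint32_t typeSize)
{
    constexpr uint32_t ONE_BLK_SIZE = 32;          // size of one data block in bytes
    uint32_t alignUnit = ONE_BLK_SIZE / typeSize;  // elements per block
    return (inputValue + alignUnit - 1) / alignUnit * alignUnit;
}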

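// RmsNormCustomTiling is this operator's own tiling structure, defined on the host side and not
// shown here. This example assumes it has an RmsNormTiling member named tiling and an epsilon member.
template <typename dataType, bool isBasicBlock = false>
class KernelRmsNorm {
public:
    __aicore__ inline KernelRmsNorm()
    {}
    __aicore__ inline void Init(
        GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, const RmsNormCustomTiling &customTiling)
    {
        tiling = customTiling.tiling;
        // [B, S, H] shape of the input and output tensors, taken from the RmsNorm tiling data.
        const uint32_t bLength = tiling.bLength;
        const uint32_t sLength = tiling.sLength;
        hLength = tiling.hLength;
        bshLength = bLength * sLength * hLength;
        constexpr uint32_t typeSize = sizeof(dataType);
        const uint32_t bsLength = AlignToBlock(bLength * sLength, typeSize);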
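        // tmpBufferSize is an element count: two [B, S, H]-sized areas plus one block-aligned [B*S] area.
        // Assumption: the intermediates are held in float, so the count is converted to bytes in
        // InitBuffer below. In a production operator, obtain the size with GetRmsNormMaxMinTmpSize
        // on the host and pass it through the tiling data instead.
        const uint32_t tmpBufferSize = bshLength * 2 + bsLength;
        epsilon = customTiling.epsilon;
        inputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(inputGm), bshLength);
        gammaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(gammaGm), hLength);
        outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(outputGm), bshLength);
        // InitBuffer sizes are in bytes.
        pipe.InitBuffer(inQueue, 1, bshLength * typeSize);
        pipe.InitBuffer(inQueueGamma, 1, hLength * typeSize);
        pipe.InitBuffer(outQueue, 1, bshLength * typeSize);
        pipe.InitBuffer(tmpQueue, 1, tmpBufferSize * sizeof(float));
    }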
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<dataType> inputLocal = inQueue.AllocTensor<dataType>();
        AscendC::DataCopy(inputLocal, inputGlobal, bshLength);
        inQueue.EnQue(inputLocal);
        AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.AllocTensor<dataType>();
        AscendC::DataCopy(gammaLocal, gammaGlobal, hLength);
        inQueueGamma.EnQue(gammaLocal);
    }

    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<dataType> inputLocal = inQueue.DeQue<dataType>();
        AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.DeQue<dataType>();
        AscendC::LocalTensor<dataType> outputLocal = outQueue.AllocTensor<dataType>();
        // sharedTmpBuffer is a LocalTensor<uint8_t>, so the temporary buffer is allocated as bytes.
        AscendC::LocalTensor<uint8_t> stackBuffer = tmpQueue.AllocTensor<uint8_t>();
        // Normalize each H-length row of inputLocal and scale it by gammaLocal.
        AscendC::RmsNorm<dataType, isBasicBlock>(outputLocal, inputLocal, gammaLocal, stackBuffer, epsilon, tiling);
        // Release the inputs and the temporary buffer, then hand the result to CopyOut.
        inQueue.FreeTensor(inputLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        tmpQueue.FreeTensor(stackBuffer);
        outQueue.EnQue(outputLocal);
    }

    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<dataType> outputLocal = outQueue.DeQue<dataType>();
        AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
        outQueue.FreeTensor(outputLocal);
    }

private:
    AscendC::GlobalTensor<dataType> inputGlobal;
    AscendC::GlobalTensor<dataType> gammaGlobal;
    AscendC::GlobalTensor<dataType> outputGlobal;
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::QuePosition::VECCALC, 1> tmpQueue;
    RmsNormTiling tiling;
    uint32_t hLength;
    dataType epsilon;
    uint32_t bshLength;
};

// Kernel-side entry: reads the tiling data, then runs the CopyIn -> Compute -> CopyOut pipeline.
template <typename dataType, bool isBasicBlock = false>
__aicore__ inline void kernel_rmsnorm_operator(GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, GM_ADDR tiling)
{
    GET_TILING_DATA(customTilingData, tiling);
    KernelRmsNorm<dataType, isBasicBlock> op;
    op.Init(inputGm, gammaGm, outputGm, customTilingData);
    op.Process();
}