RmsNorm
功能说明
实现对shape大小为[B,S,H]的输入数据的RmsNorm归一化,其计算公式如下:

其中,γ为缩放系数,ε为防除零的权重系数。
函数原型
- 通过sharedTmpBuffer入参传入临时空间1 2 template <typename T, bool isBasicBlock = false> __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<T>& gammaLocal, const LocalTensor<uint8_t>& sharedTmpBuffer, const T epsilon, const RmsNormTiling& tiling) 
- 接口框架申请临时空间1 2 template <typename T, bool isBasicBlock = false> __aicore__ inline void RmsNorm(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const LocalTensor<T>& gammaLocal, const T epsilon, const RmsNormTiling& tiling) 
由于该接口的内部实现中涉及复杂的计算,需要额外的临时空间来存储计算过程中的中间变量。临时空间支持接口框架申请和开发者通过sharedTmpBuffer入参传入两种方式。
- 接口框架申请临时空间,开发者无需申请,但是需要预留临时空间的大小。
- 通过sharedTmpBuffer入参传入,使用该tensor作为临时空间进行处理,接口框架不再申请。该方式开发者可以自行管理sharedTmpBuffer内存空间,并在接口调用完成后,复用该部分内存,内存不会反复申请释放,灵活性较高,内存利用率也较高。
接口框架申请的方式,开发者需要预留临时空间;通过sharedTmpBuffer传入的情况,开发者需要为tensor申请空间。临时空间大小BufferSize的获取方式如下:通过15.27.2-RmsNorm Tiling中提供的GetRmsNormMaxMinTmpSize接口获取所需最大和最小临时空间大小,最小空间可以保证功能正确,最大空间用于提升性能。
参数说明
| 参数名 | 描述 | 
|---|---|
| T | 操作数的数据类型。 | 
| isBasicBlock | srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下,可以使能该参数用于提升性能,默认不使能。基本块要求srcTensor和dstTensor的shape需要满足如下条件: 
 | 
| 参数名 | 输入/输出 | 描述 | 
|---|---|---|
| dstLocal | 输出 | 目的操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float dstLocal的shape和源操作数srcLocal需要保持一致。 | 
| srcLocal | 输入 | 源操作数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float shape为[B,S,H],尾轴H长度需要满足32字节对齐。 | 
| gammaLocal | 输入 | 缩放系数。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float shape需要与srcLocal和dstLocal的尾轴H长度相等,即shape为[H]。 | 
| sharedTmpBuffer | 输入 | 临时空间。 类型为LocalTensor,支持的TPosition为VECIN/VECCALC/VECOUT。 Atlas A2训练系列产品/Atlas 800I A2推理产品,支持的数据类型为:half/float Atlas推理系列产品AI Core,支持的数据类型为:half/float 接口内部复杂计算时用于存储中间变量,由开发者提供。 临时空间大小BufferSize的获取方式请参考RmsNorm Tiling。 | 
| epsilon | 输入 | 防除零的权重系数,数据类型需要与srcLocal/dstLocal保持一致。 | 
| tiling | 输入 | RmsNorm计算所需Tiling信息,Tiling信息的获取请参考RmsNorm Tiling。 | 
返回值
无
支持的型号
Atlas A2训练系列产品/Atlas 800I A2推理产品
Atlas推理系列产品AI Core
注意事项
- srcLocal和dstLocal的Tensor空间可以复用。
- 当前仅支持ND格式的输入,不支持其他格式。
- 操作数地址偏移对齐要求请参见通用约束。
调用示例
#include "kernel_operator.h"
inline __aicore__ uint32_t AlignToBlock(const uint32_t inputValue, const uint32_t typeSize)
{
    constexpr uint32_t ONE_BLK_SIZE = 32;
    uint32_t alignUnit = ONE_BLK_SIZE / typeSize;
    return (inputValue + alignUnit - 1) / alignUnit * alignUnit;
}
template <typename dataType, bool isBasicBlock = false>
class KernelRmsNorm {
public:
    __aicore__ inline KernelRmsNorm()
    {}
    __aicore__ inline void Init(
        GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, const RmsNormCustomTiling &customTiling)
    {
        tiling = customTiling.tiling;
        const uint32_t bLength = tiling.bLength;
        const uint32_t sLength = tiling.sLength;
        hLength = tiling.hLength;
        bshLength = bLength * sLength * hLength;
        constexpr uint32_t typeSize = sizeof(dataType);
        const uint32_t bsLength = AlignToBlock(bLength * sLength, typeSize);
        const uint32_t tmpBufferSize = bshLength * 2 + bsLength;
        epsilon = customTiling.epsilon;
        inputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(inputGm), bshLength);
        gammaGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(gammaGm), hLength);
        outputGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ dataType *>(outputGm), bshLength);
        pipe.InitBuffer(inQueue, 1, bshLength * typeSize);
        pipe.InitBuffer(inQueueGamma, 1, hLength * typeSize);
        pipe.InitBuffer(outQueue, 1, bshLength * typeSize);
        pipe.InitBuffer(tmpQueue, 1, tmpBufferSize);
    }
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }
private:
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<dataType> inputLocal = inQueue.AllocTensor<dataType>();
        AscendC::DataCopy(inputLocal, inputGlobal, bshLength);
        inQueue.EnQue(inputLocal);
        AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.AllocTensor<dataType>();
        AscendC::DataCopy(gammaLocal, gammaGlobal, hLength);
        inQueueGamma.EnQue(gammaLocal);
    }
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<dataType> inputLocal = inQueue.DeQue<dataType>();
        AscendC::LocalTensor<dataType> gammaLocal = inQueueGamma.DeQue<dataType>();
        AscendC::LocalTensor<dataType> outputLocal = outQueue.AllocTensor<dataType>();
        AscendC::LocalTensor<dataType> stackBuffer = tmpQueue.AllocTensor<uint8_t>();
        AscendC::RmsNorm<dataType, isBasicBlock>(outputLocal, inputLocal, gammaLocal, stackBuffer, epsilon, tiling);
        inQueue.FreeTensor(inputLocal);
        inQueueGamma.FreeTensor(gammaLocal);
        tmpQueue.FreeTensor(stackBuffer);
        outQueue.EnQue(outputLocal);
    }
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<dataType> outputLocal = outQueue.DeQue<dataType>();
        AscendC::DataCopy(outputGlobal, outputLocal, bshLength);
        outQueue.FreeTensor(outputLocal);
    }
private:
    AscendC::GlobalTensor<dataType> inputGlobal;
    AscendC::GlobalTensor<dataType> gammaGlobal;
    AscendC::GlobalTensor<dataType> outputGlobal;
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueue;
    AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::QuePosition::VECCALC, 1> tmpQueue;
    RmsNormTiling tiling;
    uint32_t hLength;
    dataType epsilon;
    uint32_t bshLength;
};
template <typename dataType, bool isBasicBlock = false>
__aicore__ inline void kernel_rmsnorm_operator(GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, GM_ADDR tiling)
{
    GET_TILING_DATA(customTilingData, tiling)
    KernelRmsNorm<dataType, isBasicBlock> op;
    op.Init(inputGm, gammaGm, outputGm, customTilingData);
    op.Process();
}