RmsNorm Tiling
Function Usage
Ascend C provides the RmsNorm Tiling API so that users can obtain the tiling parameters required for RmsNorm kernel computation.
To obtain the tiling parameter, perform the following two steps:
- Use the GetRmsNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes required for RmsNorm computation. RmsNorm computation in the kernel needs temporary space that developers must reserve or allocate. This host-side API reports the bounds of that space, so that developers can select a proper size within the range and pass it to the kernel as a tiling parameter.
  - To ensure correct results, the temporary space to be reserved or allocated cannot be less than the minimum temporary space.
  - Within the range between the minimum and maximum, a larger temporary space can improve the computing performance of the API in the kernel to some extent. To achieve better performance, reserve or allocate the space based on the actual buffer usage.
- Use the GetRmsNormTilingInfo API to obtain the tiling parameter required by the RmsNorm kernel API.
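The size selection described in the first step can be sketched as follows. `ChooseTmpBufSize` and `freeUbBytes` are hypothetical names introduced for illustration and are not part of Ascend C; the sketch only assumes that the caller knows how many bytes of the Unified Buffer it can still spare.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper (not an Ascend C API): picks a temporary buffer size
// from the [minValue, maxValue] range reported by GetRmsNormMaxMinTmpSize,
// capped by the Unified Buffer space the caller can still spare.
// Returns false when even the minimum does not fit.
bool ChooseTmpBufSize(uint32_t minValue, uint32_t maxValue,
                      uint32_t freeUbBytes, uint32_t& chosen)
{
    if (freeUbBytes < minValue) {
        return false;  // Below the minimum, the computation cannot run correctly.
    }
    // Space beyond maxValue is never used, so there is no point reserving it.
    chosen = std::min(maxValue, freeUbBytes);
    return true;
}
```

The chosen value is what would then be passed as stackBufferByteSize in the second step and stored in the tiling data for the kernel.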
The RmsNormTiling structure is defined as follows. You do not need to understand its individual fields. Pass the structure to the kernel and use it directly through the RmsNorm high-level API.
struct RmsNormTiling {
    uint32_t bLength = 0;
    uint32_t sLength = 0;
    uint32_t hLength = 0;
    uint32_t originalHLength = 0;
    float reciprocalOfHLength = 0;
    uint32_t mainBshLength = 0;
    uint32_t mainBsLength = 0;
    uint32_t mainBsLengthAlign = 0;
    uint32_t loopRound = 0;
    uint32_t inputTailPos = 0;
    uint32_t tailBshLength = 0;
    uint32_t tailBsLength = 0;
};
Prototype
bool GetRmsNormMaxMinTmpSize(const ge::Shape& srcShape, const uint32_t typeSize, uint32_t& maxValue, uint32_t& minValue, const bool isBasicBlock = false)

bool GetRmsNormTilingInfo(const ge::Shape& srcShape, const ge::Shape& originSrcShape, const uint32_t stackBufferByteSize, const uint32_t typeSize, optiling::RmsNormTiling& tiling, const bool isBasicBlock = false)
Parameters
Parameters of GetRmsNormMaxMinTmpSize

| Parameter | Input/Output | Description |
|---|---|---|
| srcShape | Input | Input shape. |
| typeSize | Input | Data type size of the operator inputs, in bytes. For example, if the input data type of the operator is half, set this parameter to 2. |
| maxValue | Output | Maximum size of the temporary space required by RmsNorm computation. Any space exceeding this value is not used by the API. Within the range between the minimum and maximum, a larger temporary space can improve the computing performance of the API in the kernel to some extent. To achieve better performance, reserve or allocate the space based on the actual buffer usage. If the maximum space size is 0, no temporary space is required. NOTE: maxValue is for reference only and may be larger than the remaining space of the Unified Buffer. In this case, select a proper temporary space size based on the remaining space of the Unified Buffer. |
| minValue | Output | Minimum size of the temporary space required by RmsNorm computation. To ensure correct results, the temporary space reserved or allocated for API computation cannot be smaller than this value. If the minimum space size is 0, no temporary space is required. |
| isBasicBlock | Input | Whether to enable basic block computation. The value must be consistent with that of the API in the kernel. The default value is false. |
Parameters of GetRmsNormTilingInfo

| Parameter | Input/Output | Description |
|---|---|---|
| srcShape | Input | Shape of the input tensor, with the H axis already aligned upwards to 32 bytes. Ensure that the B and S axes of srcShape are the same as those of originSrcShape. |
| originSrcShape | Input | Original input shape. |
| stackBufferByteSize | Input | Size of the remaining space that can be used for RmsNorm computation, in bytes. Use the GetRmsNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes, then select a proper size within this range and pass it as stackBufferByteSize. |
| typeSize | Input | Data type size of the operator inputs, in bytes. For example, if the input data type of the operator is half, set this parameter to 2. |
| tiling | Output | Tiling information required for RmsNorm computation. |
| isBasicBlock | Input | Whether to enable basic block computation. The value must be consistent with that of the API in the kernel. The default value is false. If basic block computation is enabled, ensure that the H axis of originSrcShape is also 32-byte aligned. |
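The 32-byte upward alignment of the H axis required for srcShape can be sketched as below. `AlignHLength` is a hypothetical helper written for illustration, not an Ascend C API, and it assumes that typeSize divides 32 evenly (which holds for 1-, 2-, and 4-byte element types).

```cpp
#include <cstdint>

constexpr uint32_t kAlignBytes = 32;

// Hypothetical helper (not an Ascend C API): returns the smallest H, in
// elements, that is >= originalH and whose byte length is a multiple of 32.
// Assumes typeSize is nonzero and divides 32 evenly.
uint32_t AlignHLength(uint32_t originalH, uint32_t typeSize)
{
    const uint32_t elemsPerBlock = kAlignBytes / typeSize;
    return (originalH + elemsPerBlock - 1) / elemsPerBlock * elemsPerBlock;
}
```

For example, with half inputs (typeSize 2) and an originSrcShape of {2, 16, 60}, this gives an H of 64, so srcShape would be {2, 16, 64} while B and S stay unchanged.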
Returns
- The return value of GetRmsNormMaxMinTmpSize is either true or false. true indicates that the maximum and minimum temporary space sizes required for RmsNorm internal computation are successfully obtained. false indicates that the sizes fail to be obtained. In this case, check whether the input shape meets the specified requirements.
- The return value of GetRmsNormTilingInfo is either true or false. true indicates that the tiling parameter values of RmsNorm are successfully obtained. false indicates that the values fail to be obtained. In this case, check whether the input stackBufferByteSize meets the minimum temporary space requirements. If isBasicBlock is enabled, check whether the input shape meets the basic block requirements.
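Because both APIs report failure only through their return value, a tiling function should stop as soon as either call returns false. The control flow can be sketched as follows. `MockGetMaxMinTmpSize` and `MockGetTilingInfo` are hypothetical stand-ins for the real Ascend C calls so that the sketch compiles on its own; the size bounds they compute are invented and do not reflect the real API.

```cpp
#include <cstdint>

enum class TilingStatus { kOk, kSizeQueryFailed, kTilingFailed };

// Stand-in for GetRmsNormMaxMinTmpSize. The bounds are invented for this sketch.
bool MockGetMaxMinTmpSize(uint32_t elemCount, uint32_t typeSize,
                          uint32_t& maxValue, uint32_t& minValue)
{
    if (elemCount == 0 || typeSize == 0) {
        return false;  // Models an input shape that does not meet the requirements.
    }
    minValue = 32 * typeSize;         // Invented lower bound.
    maxValue = elemCount * typeSize;  // Invented upper bound.
    if (maxValue < minValue) {
        maxValue = minValue;
    }
    return true;
}

// Stand-in for GetRmsNormTilingInfo. Only models the minimum-size check.
bool MockGetTilingInfo(uint32_t stackBufferByteSize, uint32_t minValue)
{
    return stackBufferByteSize >= minValue;
}

// Runs both steps and reports which one failed, if any.
TilingStatus RunTwoStepTiling(uint32_t elemCount, uint32_t typeSize,
                              uint32_t stackBufferByteSize)
{
    uint32_t maxValue = 0;
    uint32_t minValue = 0;
    if (!MockGetMaxMinTmpSize(elemCount, typeSize, maxValue, minValue)) {
        return TilingStatus::kSizeQueryFailed;  // Check the input shape.
    }
    if (!MockGetTilingInfo(stackBufferByteSize, minValue)) {
        return TilingStatus::kTilingFailed;  // Check stackBufferByteSize against minValue.
    }
    return TilingStatus::kOk;
}
```

In a real tiling function the two failure branches would typically return an error status to the framework instead of a custom enum.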
Example
- Add the RmsNormTiling structure as a field of the TilingData structure.
BEGIN_TILING_DATA_DEF(RmsNormCustomTilingData)  // Register a tiling class, using the tiling name as the input parameter.
    TILING_DATA_FIELD_DEF(uint32_t, totalLength);  // Tiling field for the total data volume to compute.
    TILING_DATA_FIELD_DEF(uint32_t, tileNum);      // Tiling field for the total number of data blocks to compute on each core.
    TILING_DATA_FIELD_DEF(uint32_t, tmpBufSize);   // Tiling field for the size of the temporary space.
    ...                                            // Add other tiling fields.
    TILING_DATA_FIELD_DEF_STRUCT(RmsNormTiling, rmsnormTilingData);  // Add the RmsNormTiling structure to the TilingData structure.
END_TILING_DATA_DEF;
- The tiling implementation function first calls the GetRmsNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes that the RmsNorm API needs, and selects a size within this range based on the actual buffer usage. It then calls the GetRmsNormTilingInfo API with the input shape and the selected space size to obtain the tiling parameters required by the RmsNorm kernel API.
namespace optiling {
const uint32_t BLOCK_DIM = 8;
const uint32_t TILE_NUM = 8;

static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    RmsNormCustomTilingData tiling;
    uint32_t totalLength = context->GetInputTensor(0)->GetShapeSize();
    context->SetBlockDim(BLOCK_DIM);
    tiling.set_totalLength(totalLength);
    tiling.set_tileNum(TILE_NUM);
    // Set other tiling parameters.
    ...
    std::vector<int64_t> shapeVec = {2, 16, 64};
    ge::Shape srcShape(shapeVec);
    std::vector<int64_t> oriShapeVec = {2, 16, 64};
    ge::Shape oriSrcShape(oriShapeVec);
    // Must be consistent with the setting used by the RmsNorm API in the kernel.
    const bool isBasicBlock = false;
    // This example is for reference only. It passes the minimum value returned by
    // GetRmsNormMaxMinTmpSize, which is enough to ensure correct results.
    // Developers can pass a larger size within the range as required.
    uint32_t minValue = 0;
    uint32_t maxValue = 0;
    if (!AscendC::GetRmsNormMaxMinTmpSize(srcShape, sizeof(half), maxValue, minValue, isBasicBlock)) {
        return ge::GRAPH_FAILED;
    }
    tiling.set_tmpBufSize(minValue);
    // Obtain the RmsNorm tiling parameters. The argument order follows the prototype:
    // the tiling output comes before isBasicBlock.
    if (!AscendC::GetRmsNormTilingInfo(srcShape, oriSrcShape, minValue, sizeof(half),
                                       tiling.rmsnormTilingData, isBasicBlock)) {
        return ge::GRAPH_FAILED;
    }
    ... // Other logic
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    context->SetTilingKey(1);
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling
- The kernel function calls GET_TILING_DATA to obtain TilingData, and then passes the RmsNorm tiling information in TilingData to the RmsNorm API for computation. For the complete kernel-side example, see RmsNorm.
extern "C" __global__ __aicore__ void rmsnorm_custom(GM_ADDR inputGm, GM_ADDR gammaGm, GM_ADDR outputGm, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelRmsNorm op;
    op.Init(inputGm, gammaGm, outputGm, tilingData.totalLength, tilingData.tileNum, tilingData.rmsnormTilingData);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}