GroupNorm Tiling
Function
Obtains the tiling parameters required for GroupNorm kernel computation. To obtain the tiling parameters, perform the following two steps:
- Use the GetGroupNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes required for GroupNorm computation. To perform GroupNorm computation in the kernel, you need to reserve or allocate temporary space. GetGroupNormMaxMinTmpSize returns, on the host, the maximum and minimum sizes of that space. Select a size within this range as a tiling parameter and pass it to the kernel.
  - To ensure functional correctness, the temporary space reserved or allocated cannot be smaller than the minimum temporary space.
  - Within the range between the minimum and maximum, a larger temporary space can improve the computing performance of the API in the kernel to some extent. For better performance, reserve or allocate the temporary space based on the actual memory usage.
- Use the GetGroupNormNDTilingInfo API to obtain the tiling parameters required by the GroupNorm kernel API.
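The size selection described in the two steps above can be sketched as a small host-side helper. This is only a minimal sketch and not part of the Ascend C API: the function name ChooseGroupNormTmpSize and the ubFreeBytes parameter (the Unified Buffer space the caller still has available after its own allocations) are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper, not an Ascend C API.
// minValue / maxValue: the range reported by GetGroupNormMaxMinTmpSize.
// ubFreeBytes: Unified Buffer space still available, supplied by the caller (assumption).
// Returns false when even the minimum temporary space does not fit.
bool ChooseGroupNormTmpSize(uint32_t minValue, uint32_t maxValue,
                            uint32_t ubFreeBytes, uint32_t& chosen)
{
    assert(minValue <= maxValue);  // precondition on the reported range
    if (ubFreeBytes < minValue) {
        return false;              // GroupNorm cannot run correctly with less than minValue
    }
    // Space beyond maxValue is not used by the API, so cap the choice there.
    chosen = std::min(maxValue, ubFreeBytes);
    return true;
}
```

The chosen value would then be passed as stackBufferSize to GetGroupNormNDTilingInfo in place of the raw maximum.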
Below is the definition of the GroupNorm tiling structure. You do not need to understand the individual fields of this structure. You only need to pass it to the kernel, where the GroupNorm high-level APIs use it directly.
struct GroupNormTiling {
    uint32_t n = 0;
    uint32_t c = 0;
    uint32_t hw = 0;
    uint32_t g = 0;
    uint32_t d = 0;
    uint32_t hwAlignSize = 0;
    uint32_t dhwAlignSize = 0;
    uint32_t inputXSize = 0;
    uint32_t meanVarSize = 0;
    uint32_t numberOfTmpBuf = 0;
    uint32_t meanTmpTensorPos = 0;
    uint32_t meanTmpTensorSize = 0;
    uint32_t varianceTmpTensorPos = 0;
    uint32_t varianceTmpTensorSize = 0;
    uint32_t tmpBufSize = 0;
    uint32_t oneTmpSize = 0;
    uint32_t firstTmpStartPos = 0;
    uint32_t secondTmpStartPos = 0;
    uint32_t thirdTmpStartPos = 0;
    uint32_t loopRound = 0;
    uint32_t inputRoundSize = 0;
    uint32_t inputTailSize = 0;
    uint32_t inputTailPos = 0;
    uint32_t meanVarRoundSize = 0;
    uint32_t meanVarTailSize = 0;
    uint32_t meanVarTailPos = 0;
    uint32_t bshCurLength = 0;
    uint32_t bsCurLength = 0;
    float factor = 0;
    bool smallShape = 0;
};
Prototype
void GetGroupNormMaxMinTmpSize(const ge::Shape& srcShape, const uint32_t typeSize, const bool isReuseSource, const uint32_t groupNum, uint32_t& maxValue, uint32_t& minValue)

void GetGroupNormNDTilingInfo(const ge::Shape& srcShape, const uint32_t stackBufferSize, const uint32_t typeSize, const bool isReuseSource, const uint32_t groupNum, optiling::GroupNormTiling& tiling)

void GetGroupNormNDTilingInfo(const ge::Shape& srcShape, const uint32_t stackBufferSize, const uint32_t typeSize, const bool isReuseSource, const uint32_t groupNum, AscendC::tiling::GroupNormTiling& tiling)
Parameters
Parameters of GetGroupNormMaxMinTmpSize:

| Parameter | Input/Output | Description |
|---|---|---|
| srcShape | Input | Shape information [N, C, H, W] of the input data inputX. |
| typeSize | Input | Size of the data type of the input data inputX, in bytes. For example, if the input data type is half, set this parameter to 2. |
| isReuseSource | Input | Whether intermediate variables can reuse the input buffer. This parameter is reserved. Pass the default value false. |
| groupNum | Input | Number of groups in the C dimension. |
| maxValue | Output | Maximum size of the temporary space required by GroupNorm computation. Any space exceeding this value is not used by the API. Within the range between the minimum and maximum, a larger temporary space can improve the computing performance of the API in the kernel to some extent. For better performance, reserve or allocate the temporary space based on the actual memory usage. NOTE: maxValue is for reference only and may be larger than the remaining space of the Unified Buffer. In this case, select a proper temporary space size based on the remaining space of the Unified Buffer. |
| minValue | Output | Minimum size of the temporary space required by GroupNorm computation. To ensure functional correctness, the temporary space reserved or allocated for the API computation cannot be smaller than this value. |
Parameters of GetGroupNormNDTilingInfo:

| Parameter | Input/Output | Description |
|---|---|---|
| srcShape | Input | Shape information [N, C, H, W] of the input data inputX. |
| stackBufferSize | Input | Size of the space that can be used by the GroupNorm API, in bytes. |
| typeSize | Input | Size of the input data type, in bytes. For example, if the input data type is half, set this parameter to 2. |
| isReuseSource | Input | Whether the buffer space of inputX can be reused. |
| groupNum | Input | Number of groups in the C dimension. |
| tiling | Output | Tiling information of the input data. |
Returns
None
Restrictions
None
Example
The following example shows how to obtain the tiling parameters on the host and how to use them in the kernel. In this example, the shape of the input tensor is [2, 16, 8, 8], and the input data type is half.
- Add the GroupNormTiling structure as a field of the TilingData structure.
BEGIN_TILING_DATA_DEF(TilingData) // Register a tiling class and use the tiling name as the input parameter.
    TILING_DATA_FIELD_DEF(uint32_t, n);
    TILING_DATA_FIELD_DEF(uint32_t, c);
    TILING_DATA_FIELD_DEF(uint32_t, h);
    TILING_DATA_FIELD_DEF(uint32_t, w);
    TILING_DATA_FIELD_DEF(uint32_t, group);
    // Add other tiling fields.
    ...
    // Add the GroupNormTiling structure as a field. The field name matches its use on the host and in the kernel.
    TILING_DATA_FIELD_DEF_STRUCT(GroupNormTiling, groupNormTilingData);
END_TILING_DATA_DEF;
- The tiling implementation function first calls the GetGroupNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes required by the GroupNorm API to complete computation, sets an appropriate space size based on this range and the actual buffer usage, and then obtains the tiling parameters required by the GroupNorm kernel API based on the input shape and remaining size of computing space.
namespace optiling {
const uint32_t BLOCK_DIM = 8;
const uint32_t TILE_NUM = 8;

static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    TilingData tiling;
    uint32_t totalLength = context->GetInputTensor(0)->GetShapeSize();
    context->SetBlockDim(BLOCK_DIM);
    tiling.set_tileNum(TILE_NUM);
    // Set other tiling parameters.
    ...
    std::vector<int64_t> shapeVec = {2, 16, 8, 8}; // {n, c, h, w}
    ge::Shape srcShape(shapeVec);
    uint32_t groupNum = 4;
    uint32_t minSize = 0;
    uint32_t maxSize = 0;
    // This example is for reference only. Use the GetGroupNormMaxMinTmpSize API to obtain the maximum and
    // minimum temporary space sizes required by the GroupNorm API, then reserve or allocate the space
    // based on the actual buffer usage.
    AscendC::GetGroupNormMaxMinTmpSize(srcShape, sizeof(half), false, groupNum, maxSize, minSize);
    // Obtain the GroupNorm tiling parameters.
    AscendC::GetGroupNormNDTilingInfo(srcShape, maxSize, sizeof(half), false, groupNum, tiling.groupNormTilingData);
    ... // Other logic
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
    context->SetTilingKey(1);
    return ge::GRAPH_SUCCESS;
}
} // namespace optiling
- The kernel function calls GET_TILING_DATA to obtain tilingData, and then passes the GroupNorm tiling information in tilingData to the GroupNorm APIs for computation.
extern "C" __global__ __aicore__ void groupnorm_custom(GM_ADDR inputX_gm, GM_ADDR gamm_gm, GM_ADDR beta_gm, GM_ADDR output_gm, GM_ADDR outputMean_gm, GM_ADDR outputVariance_gm, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    KernelGroupNorm<half, false> op;
    op.Init(inputX_gm, gamm_gm, beta_gm, output_gm, outputMean_gm, outputVariance_gm, tilingData.groupNormTilingData);
    if (TILING_KEY_IS(1)) {
        op.Process();
    }
}
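When validating the kernel, it can help to compare its three outputs against a plain host-side computation. The following is a minimal sketch of the usual GroupNorm definition for an [N, C, H, W] input, and is not taken from the Ascend C implementation. The function name GroupNormReference, the use of float instead of half, the eps parameter, and the choice of population (biased) variance are all assumptions. Check them against the kernel you are testing, and compare with a tolerance suited to half precision.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not an Ascend C API.
// x is stored as [N, C, H, W]; gamma and beta hold C values each.
// mean and variance receive N * G values (one per sample and group).
// eps is an assumed stabilizing constant added to the variance.
std::vector<float> GroupNormReference(const std::vector<float>& x,
                                      const std::vector<float>& gamma,
                                      const std::vector<float>& beta,
                                      size_t N, size_t C, size_t H, size_t W,
                                      size_t G, float eps,
                                      std::vector<float>& mean,
                                      std::vector<float>& variance)
{
    assert(G > 0 && C % G == 0);  // assumed: groups split C evenly
    assert(x.size() == N * C * H * W && gamma.size() == C && beta.size() == C);
    const size_t hw = H * W;
    const size_t d = C / G;            // channels per group
    const size_t groupElems = d * hw;  // elements normalized together
    std::vector<float> y(x.size());
    mean.assign(N * G, 0.0f);
    variance.assign(N * G, 0.0f);
    for (size_t n = 0; n < N; ++n) {
        for (size_t g = 0; g < G; ++g) {
            // Channels of one group are contiguous in NCHW layout.
            const size_t base = (n * C + g * d) * hw;
            double sum = 0.0;
            for (size_t i = 0; i < groupElems; ++i) {
                sum += x[base + i];
            }
            const double mu = sum / static_cast<double>(groupElems);
            double sq = 0.0;
            for (size_t i = 0; i < groupElems; ++i) {
                const double diff = x[base + i] - mu;
                sq += diff * diff;
            }
            const double var = sq / static_cast<double>(groupElems);
            const double invStd = 1.0 / std::sqrt(var + eps);
            mean[n * G + g] = static_cast<float>(mu);
            variance[n * G + g] = static_cast<float>(var);
            for (size_t i = 0; i < groupElems; ++i) {
                const size_t c = g * d + i / hw;  // channel index for gamma/beta
                y[base + i] = static_cast<float>((x[base + i] - mu) * invStd * gamma[c] + beta[c]);
            }
        }
    }
    return y;
}
```

For the example in this section, the call would use N = 2, C = 16, H = 8, W = 8, and G = 4, so each of the 8 (sample, group) pairs is normalized over 4 * 64 = 256 elements.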