BatchNorm Tiling

Function Usage

The BatchNorm Tiling API is used to obtain the tiling parameter required for BatchNorm kernel computation. To obtain the tiling parameter, perform the following two steps:

  1. Use the GetBatchNormMaxMinTmpSize API to obtain the maximum and minimum temporary space sizes required for BatchNorm computation.
    To perform BatchNorm computation in the kernel, developers need to reserve or allocate the temporary space. This API is used to obtain the maximum and minimum sizes of the temporary space to be reserved or allocated on the host. Developers can select a proper size within this range as the tiling parameter and pass it to the kernel.
    • To ensure correct functions, the temporary space to be reserved or allocated cannot be less than the minimum temporary space.
    • Between the minimum and the maximum, a larger temporary space generally improves the computing performance of the API in the kernel. For better performance, reserve or allocate as much of this range as the actual buffer usage allows.
  2. Use the GetBatchNormNDTilingInfo API to obtain the tiling parameter required by the BatchNorm kernel API.
    The BatchNormTiling structure is defined as follows. You do not need to understand its individual fields. Simply pass the structure to the kernel, where the BatchNorm high-level APIs use it directly.
    struct BatchNormTiling {
        uint32_t originalBLength = 0;
        uint32_t meanVarSize = 0;
        uint32_t meanTmpTensorPos = 0;
        uint32_t varianceTmpTensorPos = 0;
        uint32_t tmpBufSize = 0;
        uint32_t oneTmpSize = 0;
        uint32_t firstTmpStartPos = 0;
        uint32_t secondTmpStartPos = 0;
        uint32_t thirdTmpStartPos = 0;
        uint32_t loopRound = 0;
        uint32_t inputTailSize = 0;
        uint32_t inputTailPos = 0;
        uint32_t meanVarTailSize = 0;
        uint32_t meanVarTailPos = 0;
        uint32_t bshCurLength = 0;
        uint32_t shCurLength = 0;
        float firstDimValueBack = 0;
        uint32_t castHalfRepStride;
        uint32_t shCurLengthBlockNum;
        uint32_t castHalfOutRepStride;
        uint32_t basicLoop;
        uint32_t brcRepeatTimes;
        uint32_t oriBloop;
        uint32_t oriBTail;
        uint64_t oriBTmpLoopOffset;
        uint64_t oriBTmpTailOffset;
        uint64_t oriBOutLoopOffset;
        uint64_t oriBOutTailOffset;
        uint32_t reduceAddLoop;
        uint32_t reduceAddTail;
        uint64_t reduceAddTailOffset;
    };
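Steps 1 and 2 leave the choice of the temporary space size to the developer. The following is a minimal, self-contained sketch of one possible selection policy. It assumes the host knows how many Unified Buffer bytes are still free (`ubRemaining`, a hypothetical input), and `ChooseBatchNormTmpSize` is a hypothetical helper, not part of the Ascend C API. It uses as much of the [minValue, maxValue] range as fits and reports failure when even minValue does not fit.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper (not an Ascend C API): picks a temporary space size
// from the [minValue, maxValue] range reported by GetBatchNormMaxMinTmpSize.
// ubRemaining is the number of free Unified Buffer bytes, assumed known to the caller.
std::optional<uint32_t> ChooseBatchNormTmpSize(uint32_t minValue, uint32_t maxValue,
                                               uint32_t ubRemaining)
{
    // The range reported by the host API is expected to be well-formed.
    assert(minValue <= maxValue);
    // Less than minValue would break correctness, so there is no valid choice.
    if (ubRemaining < minValue) {
        return std::nullopt;
    }
    // Space beyond maxValue is never used, and maxValue may exceed the free
    // Unified Buffer space, so cap the size at whichever is smaller.
    return std::min(maxValue, ubRemaining);
}
```

With illustrative values minValue = 1024, maxValue = 8192, and 4096 free bytes, the sketch returns 4096. That value is what step 2 would then receive as stackBufferByteSize.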
    

Prototype

bool GetBatchNormMaxMinTmpSize(const ge::Shape& srcShape, const ge::Shape& originSrcShape, const uint32_t typeSize, const bool isReuseSource, uint32_t& maxValue, uint32_t& minValue, const bool isBasicBlock = false)

bool GetBatchNormNDTilingInfo(const ge::Shape& srcShape, const ge::Shape& originSrcShape, const uint32_t stackBufferByteSize, const uint32_t typeSize, const bool isReuseSource, optiling::BatchNormTiling& tilling, const bool isBasicBlock = false)
Table 1 GetBatchNormMaxMinTmpSize API parameters

Parameter

Input/Output

Description

srcShape

Input

Shape [B, S, H] of input data inputX. S*H must be 32-byte aligned.

originSrcShape

Input

Original shape [originB, originS, originH] of input data inputX.

typeSize

Input

Data type size of operator inputs. The unit is byte. For example, if the input data type of the operator is half, set this parameter to 2.

isReuseSource

Input

Whether intermediate variables can reuse the input buffer. This parameter is reserved. Pass the default value false.

maxValue

Output

Maximum size of the temporary space required by BatchNorm computation. Any space beyond maxValue is not used by the API. Within the min-max range, the larger the reserved or allocated space, the better the API computing performance. For better performance, reserve or allocate as much of this range as the actual buffer usage allows. If maxValue is 0, no temporary space is required.

NOTE:

maxValue is for reference only and may be larger than the remaining space of the Unified Buffer. In this case, select a proper temporary space size based on the remaining space of the Unified Buffer.

minValue

Output

Minimum size of the temporary space required by BatchNorm computation. To ensure correct functions, the size of the temporary space to be reserved or allocated during API computation cannot be less than the value of minValue. If the minimum space size is 0, no temporary space is required.

isBasicBlock

Input

Whether to enable the basic block. The value must match the setting used when calling the BatchNorm API in the kernel.
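The S*H alignment requirement on srcShape can be checked on the host before either API is called. The sketch below assumes that "S*H must be 32-byte aligned" means the byte length of one S*H slice, S × H × typeSize, is a multiple of 32. `IsShAligned` is a hypothetical helper, not an Ascend C API.

```cpp
#include <cassert>
#include <cstdint>

// Alignment unit, in bytes, assumed for the S*H slice of srcShape.
constexpr uint64_t kAlignBytes = 32;

// Hypothetical helper (not an Ascend C API). Returns true when one S*H slice
// of srcShape = [B, S, H] occupies a whole number of 32-byte blocks.
bool IsShAligned(uint64_t s, uint64_t h, uint32_t typeSize)
{
    // typeSize is the element size in bytes, for example 2 for half.
    assert(typeSize > 0);
    return (s * h * typeSize) % kAlignBytes == 0;
}
```

For the [16, 16, 16] half input used in the Example section, 16 × 16 × 2 = 512 bytes, which is 16 full 32-byte blocks, so the check passes. S = 3 and H = 5 with half gives 30 bytes and fails.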

Table 2 GetBatchNormNDTilingInfo API parameters

Parameter

Input/Output

Description

srcShape

Input

Shape [B, S, H] of input data inputX. S*H must be 32-byte aligned.

originSrcShape

Input

Original shape [originB, originS, originH] of input data inputX.

stackBufferByteSize

Input

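Size of the temporary space available to the BatchNorm API, in bytes. To ensure correct functions, this value cannot be less than the minValue obtained from GetBatchNormMaxMinTmpSize.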

typeSize

Input

Data type size of operator inputs, in bytes. For example, if the input data type of the operator is half, set this parameter to 2.

isReuseSource

Input

Whether intermediate variables can reuse the input buffer. This parameter is reserved. Pass the default value false.

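tilling

Output

Tiling information of the input data, returned as an optiling::BatchNormTiling structure to be passed to the kernel.

isBasicBlock

Input

Whether to enable the basic block. The value must match the setting used when calling the BatchNorm API in the kernel.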

Returns

  • The return value of GetBatchNormMaxMinTmpSize is either true or false. true indicates that the maximum and minimum temporary space sizes required for BatchNorm internal computation are successfully obtained. false indicates that the sizes fail to be obtained.
  • The return value of GetBatchNormNDTilingInfo is either true or false. true indicates that the tiling parameter values of BatchNorm are successfully obtained. false indicates that the values fail to be obtained.

Example

The following example describes the process of obtaining the tiling parameter on the host and the method of using the parameter in the kernel. In this example, the shape size of the input tensor is [16, 16, 16], and the input data type is half.

  1. Add the BatchNormTiling structure to the TilingData structure as a field.
    BEGIN_TILING_DATA_DEF(TilingData)               // Register a tiling class, using the tiling name as the input parameter.
      TILING_DATA_FIELD_DEF(uint32_t, tileNum);     // Add the tiling field that specifies the total number of data blocks to be computed on each core.
      TILING_DATA_FIELD_DEF(uint32_t, bLength);     // Add the tiling field that specifies the length of the b dimension of the input shape.
      TILING_DATA_FIELD_DEF(uint32_t, sLength);     // Add the tiling field that specifies the length of the s dimension of the input shape.
      TILING_DATA_FIELD_DEF(uint32_t, hLength);     // Add the tiling field that specifies the length of the h dimension of the input shape.
      TILING_DATA_FIELD_DEF(uint32_t, originalBLength);     // Add the tiling field that specifies the length of the original b dimension of the input shape.
      ...                                           // Add other tiling fields.
      TILING_DATA_FIELD_DEF_STRUCT(BatchNormTiling, batchNormTilingData); // Add the BatchNormTiling structure parameter to the TilingData structure.
    END_TILING_DATA_DEF;
    
  2. In the tiling implementation function, first call GetBatchNormMaxMinTmpSize to obtain the minimum and maximum temporary space sizes that the BatchNorm API needs, choose an appropriate size within this range based on the actual buffer usage, and then call GetBatchNormNDTilingInfo with the input shapes and the chosen space size to obtain the tiling parameter required by the BatchNorm kernel API.
    namespace optiling {
    const uint32_t BLOCK_DIM = 8;
    const uint32_t TILE_NUM = 8;
    static ge::graphStatus TilingFunc(gert::TilingContext* context)
    {
        TilingData tiling;
        uint32_t totalLength = context->GetInputTensor(0)->GetShapeSize();
        context->SetBlockDim(BLOCK_DIM);
        tiling.set_tileNum(TILE_NUM);
        // Set other tiling parameters.
        ... 
        std::vector<int64_t> shapeVec = {16, 16, 16};        // {b, s, h}
        std::vector<int64_t> originShapeVec = {15, 16, 16};  // {originB, originS, originH}
        ge::Shape srcShape(shapeVec);
        ge::Shape originSrcShape(originShapeVec);
        const uint32_t typeSize = 2;  // The input data type is half, which is 2 bytes.
        uint32_t minSize = 0;
        uint32_t maxSize = 0;
        // Obtain the maximum and minimum temporary space sizes required by BatchNorm.
        if (!AscendC::GetBatchNormMaxMinTmpSize(srcShape, originSrcShape, typeSize, false, maxSize, minSize, false)) {
            return ge::GRAPH_FAILED;
        }
        // This example is for reference only: it passes minSize, the smallest size that ensures correct functions.
        // Developers can pass any size in [minSize, maxSize] that fits the remaining Unified Buffer space.
        if (!AscendC::GetBatchNormNDTilingInfo(srcShape, originSrcShape, minSize, typeSize, false,
                                               tiling.batchNormTilingData, false)) {
            return ge::GRAPH_FAILED;
        }
        ... // Other logic
        tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
        context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
        context->SetTilingKey(1);
        return ge::GRAPH_SUCCESS;
    }
    } // namespace optiling
    
  3. In the kernel function, call GET_TILING_DATA to obtain TilingData, and then pass the BatchNormTiling information in TilingData to the BatchNorm API for computation. For the complete kernel-side example, see BatchNorm.
    extern "C" __global__ __aicore__ void func_custom(GM_ADDR inputX_gm, GM_ADDR gamm_gm, GM_ADDR beta_gm, GM_ADDR output_gm, GM_ADDR outputMean_gm, GM_ADDR outputVariance_gm, GM_ADDR tiling)
    {
        GET_TILING_DATA(tilingData, tiling);
        KernelBatchnorm<half, false, false> op;
        op.Init(inputX_gm, gamm_gm, beta_gm, output_gm, outputMean_gm, outputVariance_gm, tilingData.batchNormTilingData);
        if (TILING_KEY_IS(1)) {
            op.Process();
        }
    }