Tiling Template Programming

As described in "TilingKey programming", a TilingKey is usually a long number with no self-evident meaning, which makes it hard to remember and understand.

When an operator uses many TilingKeys, developers must manage the kernel implementations by key value, which makes the code hard to maintain and use. Template programming reduces this dependency: instead of hand-assigning TilingKey values, you declare named template arguments, and the TilingKey is generated from them automatically. The procedure is as follows:

  1. In the op_kernel folder of the custom operator project, add a header file that contains the template argument declaration and template argument selection. In this example, the header file is named tiling_key_add_custom.h.
    • Include the template header file ascendc/host_api/tiling/template_argument.h.
    • Declare the template arguments with ASCENDC_TPL_ARGS_DECL, and list the supported argument combinations with ASCENDC_TPL_ARGS_SEL groups inside ASCENDC_TPL_SEL (template argument selection). For details about these APIs, see Template Argument Declaration.
    #include "ascendc/host_api/tiling/template_argument.h"
    
    // Template argument declaration
    ASCENDC_TPL_ARGS_DECL(AddCustomTemplate, // OpType of the operator
        // DataType argument: data type of input x (float32 or float16).
        // ASCENDC_TPL_INPUT(0) binds it to input 0 on the kernel side.
        ASCENDC_TPL_DATATYPE_DECL(D_T_X, C_DT_FLOAT, C_DT_FLOAT16, ASCENDC_TPL_INPUT(0)),
        // DataType argument: data type of input y (float32 or float16).
        // ASCENDC_TPL_INPUT(1) binds it to input 1 on the kernel side.
        ASCENDC_TPL_DATATYPE_DECL(D_T_Y, C_DT_FLOAT, C_DT_FLOAT16, ASCENDC_TPL_INPUT(1)),
        // DataType argument: data type of output z (float32 or float16).
        // ASCENDC_TPL_OUTPUT(0) binds it to output 0 on the kernel side.
        ASCENDC_TPL_DATATYPE_DECL(D_T_Z, C_DT_FLOAT, C_DT_FLOAT16, ASCENDC_TPL_OUTPUT(0)),
        // Custom UINT argument: number of tiles. ASCENDC_TPL_8_BW sets an 8-bit encoding width, indicating
        // that the argument's values do not exceed the 8-bit range. ASCENDC_TPL_UI_MIX expresses the values in
        // mixed mode: "2" announces two closed ranges, {0, 2} and {3, 5}, followed by the individually listed
        // values 10, 12, 13, 9, and 8. The resulting value set is {0, 1, 2, 3, 4, 5, 10, 12, 13, 9, 8}.
        ASCENDC_TPL_UINT_DECL(TILE_NUM, ASCENDC_TPL_8_BW, ASCENDC_TPL_UI_MIX, 2, 0, 2, 3, 5, 10, 12, 13, 9, 8),
        // Custom bool argument: whether to split (1: split; 0: do not split).
        ASCENDC_TPL_BOOL_DECL(IS_SPLIT, 0, 1),
    );
    
    // Template argument selection: the argument combinations that are actually supported.
    // Used to check whether the TilingKey is valid when GET_TPL_TILING_KEY is called to obtain it.
    ASCENDC_TPL_SEL(
        ASCENDC_TPL_ARGS_SEL(
        // Kernel type selection. It is not declared in ASCENDC_TPL_ARGS_DECL; it is configured directly here
        // and must be identical in every ASCENDC_TPL_ARGS_SEL. If it is omitted, the kernel type is derived automatically.
        ASCENDC_TPL_KERNEL_TYPE_SEL(ASCENDC_TPL_AIV_ONLY),
        ASCENDC_TPL_DATATYPE_SEL(D_T_X, C_DT_FLOAT16),
        ASCENDC_TPL_DATATYPE_SEL(D_T_Y, C_DT_FLOAT16),
        ASCENDC_TPL_DATATYPE_SEL(D_T_Z, C_DT_FLOAT16),
        ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8),
        ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1)
        ),
        ASCENDC_TPL_ARGS_SEL(
        ASCENDC_TPL_KERNEL_TYPE_SEL(ASCENDC_TPL_AIV_ONLY),
        ASCENDC_TPL_DATATYPE_SEL(D_T_X, C_DT_FLOAT),
        ASCENDC_TPL_DATATYPE_SEL(D_T_Y, C_DT_FLOAT),
        ASCENDC_TPL_DATATYPE_SEL(D_T_Z, C_DT_FLOAT),
        ASCENDC_TPL_UINT_SEL(TILE_NUM, ASCENDC_TPL_UI_LIST, 1, 8),
        ASCENDC_TPL_BOOL_SEL(IS_SPLIT, 0, 1)
        ),
    );
    
  2. On the host, call ASCENDC_TPL_SEL_PARAM to generate and set the TilingKey automatically.
    • In the host implementation file, include the header file that defines the template argument declaration and template argument selection in Step 1.
    • Call ASCENDC_TPL_SEL_PARAM to generate and set the TilingKey. Pass the tiling context followed by the concrete value of each template argument, in the same order in which the arguments are declared in the Step 1 header file.
    #include "tiling_key_add_custom.h"
    // TilingData, MIN_LENGTH_FOR_SPLIT, DEFAULT_TILE_NUM, and BLOCK_DIM are defined elsewhere in the sample.
    static ge::graphStatus TilingFunc(gert::TilingContext *context)
    {
        TilingData tiling;
        uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize();
        // Data types of inputs x and y and of output z (the operator has a single output, index 0).
        ge::DataType dtype_x = context->GetInputDesc(0)->GetDataType();
        ge::DataType dtype_y = context->GetInputDesc(1)->GetDataType();
        ge::DataType dtype_z = context->GetOutputDesc(0)->GetDataType();
        // Template argument values, named after the arguments declared in Step 1.
        uint32_t D_T_X = static_cast<uint32_t>(dtype_x);
        uint32_t D_T_Y = static_cast<uint32_t>(dtype_y);
        uint32_t D_T_Z = static_cast<uint32_t>(dtype_z);
        uint32_t TILE_NUM = 1;
        uint32_t IS_SPLIT = 0;
        if (totalLength < MIN_LENGTH_FOR_SPLIT) {
            // Small input: process it as a single tile without splitting.
            IS_SPLIT = 0;
            TILE_NUM = 1;
        } else {
            IS_SPLIT = 1;
            TILE_NUM = DEFAULT_TILE_NUM;
        }
        context->SetBlockDim(BLOCK_DIM);
        tiling.set_totalLength(totalLength);
        tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
        context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
        // Generate and set the TilingKey; the argument order must match ASCENDC_TPL_ARGS_DECL.
        ASCENDC_TPL_SEL_PARAM(context, D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT);
        // Workspace size is set to 0.
        size_t *currentWorkspace = context->GetWorkspaceSizes(1);
        currentWorkspace[0] = 0;
        return ge::GRAPH_SUCCESS;
    }
    
  3. Implement the kernel.
    • In the kernel implementation file, include the header file from Step 1 that defines the template argument declaration and template argument selection.
    • Add a template parameter list to the kernel function so that it receives the template arguments. The parameters must be in the same order as in the Step 1 header file.
    • Select the kernel implementation by branching on the template arguments.
    #include "tiling_key_add_custom.h"
    ...
    ...
    // Template parameters follow the order in ASCENDC_TPL_ARGS_DECL.
    template<typename D_T_X, typename D_T_Y, typename D_T_Z, int TILE_NUM, int IS_SPLIT>
    __global__ __aicore__ void add_custom_template(GM_ADDR x, GM_ADDR y, GM_ADDR z, GM_ADDR workspace, GM_ADDR tiling)
    {
        GET_TILING_DATA(tiling_data, tiling);
        KernelAdd<D_T_X, D_T_Y, D_T_Z> op;
        op.Init(x, y, z, tiling_data.totalLength, TILE_NUM);
        // The branches are resolved at compile time, so each instantiation contains only one path.
        if constexpr (std::is_same_v<D_T_X, float> && std::is_same_v<D_T_Y, float> && std::is_same_v<D_T_Z, float>) {
            op.Process1();
        } else if constexpr (std::is_same_v<D_T_X, half> && std::is_same_v<D_T_Y, half> && std::is_same_v<D_T_Z, half>) {
            if constexpr (IS_SPLIT == 0) {
                op.Process1();  // float16, not split
            } else if constexpr (IS_SPLIT == 1) {
                op.Process2();  // float16, split
            }
        }
    }
    

For the complete example, see the tiling template programming sample.