TilingData Structure Registration

Function Usage

Registers a defined TilingData structure and binds it to a custom operator. For details, see Example.

Prototype

REGISTER_TILING_DATA_CLASS(op_type, class_name)

#define REGISTER_TILING_DATA_CLASS(op_type, class_name)                                                      \
  class op_type##class_name##Helper {                                                                        \
  public:                                                                                                    \
    op_type##class_name##Helper() {                                                                          \
      CTilingDataClassFactory::RegisterTilingData(#op_type, op_type##class_name##Helper::CreateTilingDataInstance); \
    }                                                                                                        \
    static std::shared_ptr<TilingDef> CreateTilingDataInstance() {                                           \
      return std::make_shared<class_name>();                                                                 \
    }                                                                                                        \
  };                                                                                                         \
  op_type##class_name##Helper g_tilingdata_##op_type##class_name##helper;
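The expansion above relies on a static-registration pattern: defining a global helper object makes its constructor run during static initialization, before main, so the creator function is already registered under the stringized op_type when the framework looks it up. The following is a minimal, self-contained sketch of that pattern. TilingDef and CTilingDataClassFactory here are hypothetical stand-ins (the real ones come from register/tilingdata_base.h and have richer interfaces), and the Create and Name methods are invented for illustration only.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for the framework's TilingDef base class.
class TilingDef {
 public:
  virtual ~TilingDef() = default;
  // Hypothetical accessor, used only to show which class was created.
  virtual std::string Name() const = 0;
};

// Hypothetical stand-in for CTilingDataClassFactory: maps an operator name
// to a function that creates the registered tiling structure.
class CTilingDataClassFactory {
 public:
  using Creator = std::shared_ptr<TilingDef> (*)();

  static void RegisterTilingData(const std::string& op_type, Creator creator) {
    Registry()[op_type] = creator;
  }

  // Hypothetical lookup helper: returns nullptr if op_type is not registered.
  static std::shared_ptr<TilingDef> Create(const std::string& op_type) {
    auto it = Registry().find(op_type);
    if (it == Registry().end()) return nullptr;
    return it->second();
  }

 private:
  // Function-local static: constructed on first use, so registering from a
  // global object's constructor is safe regardless of initialization order.
  static std::map<std::string, Creator>& Registry() {
    static std::map<std::string, Creator> registry;
    return registry;
  }
};

// Same shape as the documented macro.
#define REGISTER_TILING_DATA_CLASS(op_type, class_name)                     \
  class op_type##class_name##Helper {                                       \
   public:                                                                  \
    op_type##class_name##Helper() {                                         \
      CTilingDataClassFactory::RegisterTilingData(                          \
          #op_type, op_type##class_name##Helper::CreateTilingDataInstance); \
    }                                                                       \
    static std::shared_ptr<TilingDef> CreateTilingDataInstance() {          \
      return std::make_shared<class_name>();                                \
    }                                                                       \
  };                                                                        \
  op_type##class_name##Helper g_tilingdata_##op_type##class_name##helper;

// Hypothetical tiling structure; real ones are defined with BEGIN_TILING_DATA_DEF.
class AddCustomTilingData : public TilingDef {
 public:
  std::string Name() const override { return "AddCustomTilingData"; }
};

// Defines a helper class plus a global instance whose constructor registers
// "AddCustom" -> AddCustomTilingData before main() runs.
REGISTER_TILING_DATA_CLASS(AddCustom, AddCustomTilingData)
```

In this sketch, CTilingDataClassFactory::Create("AddCustom") returns an AddCustomTilingData instance, while an unregistered name returns nullptr. Keeping the registry in a function-local static ensures the map exists even when another global's constructor registers into it first.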

Parameters

Table 1 Parameters

Parameter     Input/Output   Description
op_type       Input          Name of the operator to which the tiling structure is bound.
class_name    Input          Name of the tiling structure. It must follow C++ variable naming rules.

Constraints

  • Include the header file register/tilingdata_base.h before using this macro.
  • Registering an intermediate structure or a TilingKey-specific structure requires a special op_type naming rule. For details, see Example.
  • When registering TilingKey-specific structures for an operator, also register the operator's default structure under op_type itself.
  • Tiling structures are registered globally, so each structure name must be globally unique. If different operators register different tiling structures under the same name, the behavior is undefined.

Example

  • Register an operator tiling structure.
    #include "register/tilingdata_base.h"
    
    // Define the TilingData class.
    namespace optiling {
    BEGIN_TILING_DATA_DEF(AddCustomTilingData)    // Begin defining a tiling class; the argument is the class name.
      TILING_DATA_FIELD_DEF(uint32_t, blkDim);    // Tiling field: number of cores used.
      TILING_DATA_FIELD_DEF(uint32_t, totalSize); // Tiling field: total data size.
      TILING_DATA_FIELD_DEF(uint32_t, splitTile); // Tiling field: size of each data block a core processes.
    END_TILING_DATA_DEF;                          // Definition ends.
    // Bind the TilingData class to the AddCustom operator.
    REGISTER_TILING_DATA_CLASS(AddCustom, AddCustomTilingData)
    }  // namespace optiling
  • Register an intermediate structure. In nested-structure scenarios, a structure nested inside another is called an intermediate structure. Because only one tiling structure can be registered per operator name, the framework can detect an intermediate structure only if it is registered through REGISTER_TILING_DATA_CLASS under a virtual operator name (structure name + Op), as follows:
    BEGIN_TILING_DATA_DEF(Matmul)
      TILING_DATA_FIELD_DEF(uint16_t, mmVar);
      TILING_DATA_FIELD_DEF_ARR(uint16_t, 3, mmArr);
    END_TILING_DATA_DEF;
    // Register the intermediate structure. The first argument is always <class_name>Op and the second is class_name.
    // For example, for the structure Matmul the first argument is MatmulOp and the second is Matmul.
    REGISTER_TILING_DATA_CLASS(MatmulOp, Matmul)
  • Register different tiling structures for different tiling_key values.
    The first argument of REGISTER_TILING_DATA_CLASS is ${op_type} + '_' + tiling_key. If no tiling structure is registered for a given tiling_key, the default structure (registered under op_type) is used. In the following example, the tiling structure is AddStructSample1 when tiling_key is 1, and AddStruct when tiling_key is not specified or has any other value.
    
    // Take op_type as Add as an example. The default tiling structure is registered as follows:
    BEGIN_TILING_DATA_DEF(AddStruct)   
      TILING_DATA_FIELD_DEF(uint16_t, mmVar);   
      TILING_DATA_FIELD_DEF_ARR(uint16_t, 3, mmArr); 
    END_TILING_DATA_DEF; 
    REGISTER_TILING_DATA_CLASS(Add, AddStruct) 
    
    // If TilingKey is set to 1, the structure is registered as follows:
    BEGIN_TILING_DATA_DEF(AddStructSample1) 
      TILING_DATA_FIELD_DEF(uint16_t, mmVar);   
      TILING_DATA_FIELD_DEF_ARR(uint16_t, 3, mmArr); 
    END_TILING_DATA_DEF; 
    REGISTER_TILING_DATA_CLASS(Add_1, AddStructSample1)
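The naming rules in the examples above can be modeled as a small lookup routine. The sketch below is hypothetical: this document does not describe the framework's actual lookup code, and IntermediateKey, TilingKeySpecificKey, and ResolveTilingStruct are invented names. A plain std::map stands in for the registry, keyed by the first argument of REGISTER_TILING_DATA_CLASS and mapping to the structure name; tiling_key is assumed to be a uint64_t. It shows the documented fallback: the TilingKey-specific key is tried first, then the default key op_type.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical: key under which an intermediate (nested) structure is registered.
std::string IntermediateKey(const std::string& class_name) {
  return class_name + "Op";
}

// Hypothetical: key for a TilingKey-specific structure, op_type + '_' + tiling_key.
std::string TilingKeySpecificKey(const std::string& op_type, uint64_t tiling_key) {
  return op_type + "_" + std::to_string(tiling_key);
}

// Hypothetical resolver: prefer the TilingKey-specific registration, otherwise
// fall back to the default structure registered under op_type.
std::optional<std::string> ResolveTilingStruct(
    const std::map<std::string, std::string>& registry,
    const std::string& op_type, std::optional<uint64_t> tiling_key) {
  if (tiling_key.has_value()) {
    auto specific = registry.find(TilingKeySpecificKey(op_type, *tiling_key));
    if (specific != registry.end()) return specific->second;
  }
  auto fallback = registry.find(op_type);
  if (fallback != registry.end()) return fallback->second;
  return std::nullopt;  // No default registered: violates the constraint above.
}
```

With the registrations from the examples (Add → AddStruct, Add_1 → AddStructSample1, MatmulOp → Matmul), resolving Add with tiling_key 1 yields AddStructSample1, while tiling_key 2 or no tiling_key yields AddStruct. If the default Add entry were missing, a tiling_key without its own registration would resolve to nothing, which is why the default structure must be registered.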