Tiling Implementation When Using High-Level APIs
- Define the tiling structure.
```cpp
#include "register/tilingdata_base.h"  // BEGIN_TILING_DATA_DEF and related macros
#include "tiling/tiling_api.h"         // TCubeTiling (Matmul high-level API tiling structure)

namespace optiling {
BEGIN_TILING_DATA_DEF(MyAddTilingData)  // Declare the name of the tiling structure.
  TILING_DATA_FIELD_DEF_STRUCT(TCubeTiling, cubeTilingData);  // Embed the tiling structure of the high-level API (Matmul).
  TILING_DATA_FIELD_DEF(uint32_t, field);  // An ordinary scalar field of the custom tiling structure.
END_TILING_DATA_DEF;
REGISTER_TILING_DATA_CLASS(MyAdd, MyAddTilingData)  // Register the tiling structure with the MyAdd operator.
}
```
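The macros above generate the host-side tiling class, and their expansion is not shown here. As a rough mental model only, the standalone sketch below mimics the idea with plain structs: a hypothetical stand-in for TCubeTiling (`FakeCubeTiling`, with just three fields instead of the real set) is embedded in a hypothetical `MyAddTilingDataSketch` next to a scalar field, and the whole structure is copied to and from a raw byte buffer, which is roughly how tiling data is handed from host to kernel as an opaque blob. It is not what the macros actually expand to.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for a high-level API tiling structure such as TCubeTiling.
// The real TCubeTiling has many more fields; only three are modeled here.
struct FakeCubeTiling {
    int32_t M;
    int32_t N;
    int32_t K;
};

// Hypothetical plain-struct analogue of MyAddTilingData: the high-level API
// structure is embedded as a member next to an ordinary scalar field.
struct MyAddTilingDataSketch {
    FakeCubeTiling cubeTilingData;
    uint32_t field;
};

// Copy the whole structure into a raw byte buffer (an opaque blob).
std::vector<uint8_t> Serialize(const MyAddTilingDataSketch& t) {
    std::vector<uint8_t> buf(sizeof(t));
    std::memcpy(buf.data(), &t, sizeof(t));
    return buf;
}

// Rebuild the structure from the byte buffer, as a kernel-side reader would.
// Returns a zero-initialized structure if the buffer is too small.
MyAddTilingDataSketch Deserialize(const std::vector<uint8_t>& buf) {
    MyAddTilingDataSketch t{};
    if (buf.size() < sizeof(t)) {
        return t;
    }
    std::memcpy(&t, buf.data(), sizeof(t));
    return t;
}
```

Because the embedded structure is just another member, both sides only need to agree on one layout; the high-level API fields travel inside the same buffer as the operator's own fields.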
- Use the tiling function of the high-level API to initialize the tiling structure.
```cpp
using namespace matmul_tiling;  // MultiCoreMatmulTiling, TPosition, CubeFormat

static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    int32_t M = 1024;
    int32_t N = 640;
    int32_t K = 256;
    int32_t baseM = 128;
    int32_t baseN = 128;

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    MultiCoreMatmulTiling cubeTiling(ascendcPlatform);
    cubeTiling.SetDim(2);  // Number of cores the Matmul is split across.
    cubeTiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    cubeTiling.SetBType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
    cubeTiling.SetCType(TPosition::LCM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling.SetBiasType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    cubeTiling.SetShape(M, N, K);              // Shape of the Matmul computation.
    cubeTiling.SetOrgShape(M, N, K);           // Original shape of the input/output matrices.
    cubeTiling.SetFixSplit(baseM, baseN, -1);  // Fix baseM and baseN; -1 lets the tiling choose baseK.
    cubeTiling.SetBias(true);
    cubeTiling.SetBufferSpace(-1, -1, -1);     // -1: use all of the available buffer space.

    MyAddTilingData tiling;
    // Fill the embedded TCubeTiling; GetTiling returns -1 on failure.
    if (cubeTiling.GetTiling(tiling.cubeTilingData) == -1) {
        return ge::GRAPH_FAILED;
    }
    // some code
    return ge::GRAPH_SUCCESS;
}
```
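To get a feel for the numbers in this example, the sketch below does the block arithmetic by hand: with M = 1024, N = 640 and baseM = baseN = 128, the output C divides into 8 * 5 = 40 base blocks, or 20 per core when spread evenly over the 2 cores requested by SetDim(2). The helpers `CeilDiv`, `CountBaseBlocks`, `BlocksPerCore` and `BaseBlockBytesFloat` are hypothetical and use a naive even split; the real MultiCoreMatmulTiling may partition the work differently.

```cpp
#include <cstdint>

// Ceiling division for positive integers.
constexpr int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical helper: number of baseM x baseN blocks needed to cover an M x N output.
constexpr int64_t CountBaseBlocks(int64_t M, int64_t N, int64_t baseM, int64_t baseN) {
    return CeilDiv(M, baseM) * CeilDiv(N, baseN);
}

// Hypothetical helper: blocks per core for a naive even split over usedCores cores
// (the last core may receive fewer). Not the MultiCoreMatmulTiling algorithm.
constexpr int64_t BlocksPerCore(int64_t totalBlocks, int64_t usedCores) {
    return CeilDiv(totalBlocks, usedCores);
}

// Hypothetical helper: bytes needed to hold one baseM x baseN block of C
// when C is DT_FLOAT (4 bytes per element).
constexpr int64_t BaseBlockBytesFloat(int64_t baseM, int64_t baseN) {
    return baseM * baseN * static_cast<int64_t>(sizeof(float));
}
```

Ceiling division matters when a dimension is not a multiple of the base size: M = 1000 still needs 8 row blocks of 128, the last one partially filled.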