GET_TPL_TILING_KEY
Function Usage
Automatically generates a TilingKey for tiling template programming. The API converts each template argument value into a binary field of its declared bit width. It then concatenates these fields in order and returns the result as a uint64_t value, which is the TilingKey.
To use this API, include the header file that contains the template argument declarations and the template argument selections. For details, see Tiling Template Programming.
Prototype
1 2 3 4 5 6 7 8 |
namespace AscendC { uint64_t EncodeTilingKey(TilingDeclareParams declareParams, TilingSelectParams selectParamsVec, std::vector<uint64_t> tilingParams); } #define GET_TPL_TILING_KEY(...) \ AscendC::EncodeTilingKey(g_tilingDeclareParams, g_tilingSelectParams, {__VA_ARGS__}) // GET_TPL_TILING_KEY calls the EncodeTilingKey API to generate a TilingKey. EncodeTilingKey is an internal API and can be ignored. |
Parameters
| Parameter | Input/Output | Description |
|---|---|---|
| ... | Input | Variable-length list of template argument values. Pass the values in the same order as the template arguments are declared in the header file that contains the template argument declarations and template argument selections. |
Returns
The generated TilingKey, as a uint64_t value.
Constraints
None
Example
#include "tiling_key_add_custom.h"
static ge::graphStatus TilingFunc(gert::TilingContext *context)
{
TilingData tiling;
uint32_t totalLength = context->GetInputShape(0)->GetOriginShape().GetShapeSize();
ge::DataType dtype_x = context->GetInputDesc(0)->GetDataType();
ge::DataType dtype_y = context->GetInputDesc(1)->GetDataType();
ge::DataType dtype_z = context->GetOutputDesc(1)->GetDataType();
uint32_t D_T_X = ADD_TPL_FP32, D_T_Y=ADD_TPL_FP32, D_T_Z=ADD_TPL_FP32, TILE_NUM=1, IS_SPLIT=0;
if(dtype_x == ge::DataType::DT_FLOAT){
D_T_X = ADD_TPL_FP32;
}else if(dtype_x == ge::DataType::DT_FLOAT16){
D_T_X = ADD_TPL_FP16;
}
if(dtype_y == ge::DataType::DT_FLOAT){
D_T_Y = ADD_TPL_FP32;
}else if(dtype_y == ge::DataType::DT_FLOAT16){
D_T_Y = ADD_TPL_FP16;
}
if(dtype_z == ge::DataType::DT_FLOAT){
D_T_Z = ADD_TPL_FP32;
}else if(dtype_z == ge::DataType::DT_FLOAT16){
D_T_Z = ADD_TPL_FP16;
}
if(totalLength< MIN_LENGTH_FOR_SPLIT){
IS_SPLIT = 0;
TILE_NUM = 1;
}else{
IS_SPLIT = 1;
TILE_NUM = DEFAULT_TILE_NUM;
}
context->SetBlockDim(BLOCK_DIM);
tiling.set_totalLength(totalLength);
tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());
const uint64_t tilingKey = GET_TPL_TILING_KEY(D_T_X, D_T_Y, D_T_Z, TILE_NUM, IS_SPLIT);
context->SetTilingKey(tilingKey);
size_t *currentWorkspace = context->GetWorkspaceSizes(1);
currentWorkspace[0] = 0;
return ge::GRAPH_SUCCESS;
}