注册一个tiling的结构用于保存所需tiling参数。
参数 |
输入/输出 |
说明 |
---|---|---|
class_name |
输入 |
用户定义tiling结构体名,与c++变量命名要求一致 |
参数 |
输入/输出 |
说明 |
---|---|---|
data_type |
输入 |
字段的数据类型 |
field_name |
输入 |
字段名,与c++变量命名要求一致 |
参数 |
输入/输出 |
说明 |
---|---|---|
arr_type |
输入 |
数组元素数据类型 |
arr_size |
输入 |
数组元素个数 |
field_name |
输入 |
字段名,与c++变量命名要求一致 |
参数 |
输入/输出 |
说明 |
---|---|---|
struct_type |
输入 |
结构体类型 |
field_name |
输入 |
字段名,与c++变量命名要求一致 |
#include "register/tilingdata_base.h" // 定义tilingdata类 namespace optiling { BEGIN_TILING_DATA_DEF(Matmul) TILING_DATA_FIELD_DEF(uint16_t, mmVar); TILING_DATA_FIELD_DEF_ARR(uint16_t, 3, mmArr); END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(MatmulOp, Matmul) //注册中间结构体 BEGIN_TILING_DATA_DEF(AddCustomTilingData) // 注册一个tiling类,以tiling的名字作为入参 TILING_DATA_FIELD_DEF(uint32_t, blkDim); // 添加tiling变量类型字段,参与计算核数 TILING_DATA_FIELD_DEF(uint32_t, totalSize); // 添加tiling变量类型字段,总计算数据量-输入shape大小 TILING_DATA_FIELD_DEF(uint32_t, splitTile); // 添加tiling变量类型字段,每个core处理的数据分块计算 TILING_DATA_FIELD_DEF_ARR(uint16_t, 3, arrSample); // 添加tiling数组类型字段 TILING_DATA_FIELD_DEF_STRUCT(Matmul, mm); // 添加tiling结构体类型字段 END_TILING_DATA_DEF; // 定义结束 // 注册算子tilingdata类到对应的AddCustom算子 REGISTER_TILING_DATA_CLASS(AddCustom, AddCustomTilingData) } // host侧设置参数值和使用tiling参数 static void TilingTik2AddInit(Tik2AddTilingData *tiling, uint32_t blockDim) { // 设置参数值 tiling->set_blkDim(blockDim); // 置值通用数据类型变量blockDim uint16_t arr[] = {10,2,8,2,3,4,5,2,1,2,4,4,5,}; tiling->set_arrSample(arr); // 置值通用数据类型数组变量arrSample,仅会复制arr数据的前三个数据,与TILING_DATA_FIELD_DEF_ARR中arr_size一致 tiling->mm.set_mmVar(1); // 置值嵌套结构体通用数据类型变量mmVar tiling->mm.set_mmArr(arr); // 置值嵌套结构体通用数据类型数组mmArr // 使用参数值 uint32_t useBlockDim = tiling->get_blkDim(); // 获取通用数据类型变量blockDim uint32_t* arrPoint = tiling->get_arrSample(); // 获取通用数据类型数组变量arrSample useBlockDim = tiling->mm.get_mmVar(); // 获取嵌套结构体通用数据类型变量mmVar arrPoint = tiling->mm.get_mmArr(); // 获取嵌套结构体通用数据类型数组mmArr }