// test.cpp
#include <iostream>
#include "exe_graph/runtime/storage_shape.h"
#include "tiling/context/context_builder.h"
/**
 * Standalone driver: builds a TilingContext for an operator with two inputs and
 * one output (all DT_FLOAT, FORMAT_ND, shape {2, 32}), loads a tiling shared
 * library, looks up the tiling function registered for an OpType and runs it.
 *
 * Usage: test [tiling_so_path] [op_type]
 *   tiling_so_path  tiling library to load; defaults to the placeholder path below
 *   op_type         OpType whose tiling function is executed; defaults to "AddCustom"
 *
 * Returns 0 on success and -1 on any failure (a message is printed to stdout).
 */
int main(int argc, char *argv[])
{
    // NOTE(review): placeholder path -- replace with the real liboptiling.so
    // location, or pass it as the first command-line argument.
    const char *soPath = (argc > 1) ? argv[1] : "/your/path/to/so_path/liboptiling.so";
    // OpType used to look up the tiling function in the registry.
    const char *opType = (argc > 2) ? argv[2] : "AddCustom";

    gert::StorageShape x_shape = {{2, 32}, {2, 32}};
    gert::StorageShape y_shape = {{2, 32}, {2, 32}};
    gert::StorageShape z_shape = {{2, 32}, {2, 32}};

    // Backing storage for the tiling data and the workspace-size vector
    // (capacity 4096 each; presumably bytes / size_t entries -- confirm).
    auto param = gert::TilingData::CreateCap(4096);
    auto workspace_size_holder = gert::ContinuousVector::Create<size_t>(4096);
    auto ws_size = reinterpret_cast<gert::ContinuousVector *>(workspace_size_holder.get());

    auto holder = context_ascendc::ContextBuilder()
        .NodeIoNum(2, 1)
        .IrInstanceNum({1, 1})
        .AddInputTd(0, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND, x_shape)
        .AddInputTd(1, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND, y_shape)
        .AddOutputTd(0, ge::DT_FLOAT, ge::FORMAT_ND, ge::FORMAT_ND, z_shape)
        .TilingData(param.get())
        .Workspace(ws_size)
        .AddPlatformInfo("Ascendxxxyy")  // NOTE(review): placeholder SoC name -- set the real one
        .BuildTilingContext();
    // Guard against a failed build before dereferencing the holder.
    if (holder == nullptr) {
        std::cout << "Failed to build tiling context." << std::endl;
        return -1;
    }
    auto tilingContext = holder->GetContext<gert::TilingContext>();
    if (tilingContext == nullptr) {
        std::cout << "Failed to get tiling context." << std::endl;
        return -1;
    }

    context_ascendc::OpTilingRegistry tmpIns;
    // Load the tiling dynamic library.
    if (!tmpIns.LoadTilingLibrary(soPath)) {
        std::cout << "Failed to load tiling so" << std::endl;
        return -1;
    }

    // Fetch the tiling function registered for the given OpType.
    context_ascendc::TilingFunc tilingFunc = tmpIns.GetTilingFunc(opType);
    if (tilingFunc == nullptr) {
        std::cout << "Get tiling func failed." << std::endl;
        return -1;
    }

    // Execute the tiling function.
    ge::graphStatus ret = tilingFunc(tilingContext);
    if (ret != ge::GRAPH_SUCCESS) {
        std::cout << "Exec tiling func failed." << std::endl;
        return -1;
    }
    return 0;
}