#include <acl/acl.h> #include <atb/atb_infer.h>
int deviceId = 0; aclError status = aclrtSetDevice(deviceId);
// 以elewise大类中的Add算子为例,可通过以下方式构造对应参数: atb::infer::ElewiseParam param; param.elewiseType = atb::infer::ElewiseParam::ELEWISE_ADD;
atb::Operation *op = nullptr; atb::Status st = atb::CreateOperation(param, &op);
图算子有配置TensorId和配置TensorName组图两种创建和使用方式。图算子结构图解请参考图1,TensorId与TensorName对应关系配置如表1。
Tensor |
TensorId |
TensorName |
---|---|---|
a |
0 |
"a" |
b |
1 |
"b" |
c |
2 |
"c" |
output |
3 |
"output" |
a_add_b_output |
4 |
"a_add_b_output" |
与单算子的参数不同,图算子的参数包含图节点、输入tensor数、输出tensor数、中间tensor数等图相关的信息。
首先,根据设计的图算子结构,分别计算出图输入tensor(假设为x个),图输出tensor(假设为y个)以及图中间tensor(假设为z个)的个数。 图输入tensor的Id取值为[0, x - 1],图输出tensor的Id取值为[x, x + y - 1],图中间tensor的Id取值为[x + y, x + y + z - 1]。示例对应关系见表1Tensor与TensorId列。
然后,配置每一个节点的相关信息,包括创建好的单算子对象实例、输入tensor和输出tensor。该节点的输入和输出tensor在图里可能是图的输入tensor、输出tensor或中间tensor,用户需根据其所属的图Tensor类型,在合适的范围内取值。
实例中的op0和op1创建过程可参考单算子的创建。
atb::GraphParam graphParam; graphParam.inTensorNum = 3; // 指定该图的输入tensor数量 graphParam.outTensorNum = 1; // 指定该图的输出tensor数量 graphParam.internalTensorNum = 1; // 指定该图的中间tensor数量 graphParam.nodes.resize(2); // 指定该图中的节点数量,即包含的单算子数量 graphParam.nodes[0].operation = op0; // 指定该图中的节点0的单算子对象实例 graphParam.nodes[0].inTensorIds = {0, 1}; // 指定该图中的节点0需要的输入tensor所对应的id graphParam.nodes[0].outTensorIds = {4}; // 指定该图中的节点0输出的输出tensor所对应的id graphParam.nodes[1].operation = op1; // 指定该图中的节点1的单算子对象实例 graphParam.nodes[1].inTensorIds = {4, 2}; // 指定该图中的节点1需要的输入tensor所对应的id graphParam.nodes[1].outTensorIds = {3}; // 指定该图中的节点1输出的输出tensor所对应的id
atb::Operation *op = nullptr; atb::Status st = atb::CreateOperation(graphParam, &op);
atb::GraphOpBuilder* graphOpBuilder; CreateGraphOpBuilder(&graphOpBuilder);
// lambda函数,通过图算子的输入tensorDesc推导输出tensorDesc,包括DataType、Format、Shape等 atb::InferShapeFunc inferShapeFunc = [=](const atb::SVector<atb::TensorDesc> &inTensorDescs, atb::SVector<atb::TensorDesc> &outTensorDescs) { outTensorDescs.at(0) = inTensorDescs.at(0); return atb::NO_ERROR; }; graphOpBuilder->Init("DemoGraphOperation", inferShapeFunc, {"a", "b", "c"}, {"output"});
构图时可通过定义lambda函数对Tensor进行reshape,需保证reshape前后的shape大小一致。
op0等单算子的创建过程可参考单算子的创建。
graphOpBuilder->AddOperation(op0, {"a", "b"}, {"a_add_b_output"}); graphOpBuilder->AddOperation(op1, {"a_add_b_output", "c"}, {"output"});
atb::Operation *op = graphOpBuilder->Build(); // 使用时需判断op是否为空指针 DestroyGraphOpBuilder(graphOpBuilder); // 销毁图算子构造器
在使用插件机制时,用户需要自行管理因编写的插件代码导致的安全性或系统不可控行为。请确保编写的代码可靠并遵循相关安全规范和最佳实践。
// Tensor构造方法 atb::Tensor a; a.desc.dtype = ACL_FLOAT16; // 配置Tensor数据类型 a.desc.format = ACL_FORMAT_ND; // 配置Tensor格式 a.desc.shape.dimNum = 2; // 配置Tensor维度数 a.desc.shape.dims[0] = 3; // 配置Tensor第0维大小 a.desc.shape.dims[1] = 3; // 配置Tensor第1维大小 a.dataSize = Utils::GetTensorSize(a); // 获取Tensor内存大小 status = aclrtMalloc(&a.deviceData, a.dataSize, ACL_MEM_MALLOC_HUGE_FIRST); // 申请device内存 // 按上述方法构造所有输入和输出tensor,存入VariantPack atb::VariantPack variantPack; variantPack.inTensors = { a, ... }; variantPack.outTensors = { output, ... };
atb::Context *context = nullptr; st = atb::CreateContext(&context); aclrtStream stream = nullptr; status = aclrtCreateStream(&stream); context->SetExecuteStream(stream);
uint64_t workspaceSize = 0; st = op->Setup(variantPack, workspaceSize, context);
void *workspace = nullptr; status = aclrtMalloc(&workspace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
当workspace大小为0时,无需执行该步骤,否则会报错。
st = op->Execute(variantPack, (uint8_t *)workspace, workspaceSize, context);
status = aclrtDestroyStream(stream); // 销毁stream status = aclrtFree(workspace); // 销毁workspace st = atb::DestroyOperation(op); // 销毁op对象 st = atb::DestroyContext(context); // 销毁context // 下面代码为释放Tensor的示例代码,实际使用时需释放VariantPack中的所有Tensor status = aclrtFree(tensor.deviceData); tensor.deviceData = nullptr; tensor.dataSize = 0;
# g++编译demo工程,demo.cpp为demo对应的源码文件 g++ -I "${ATB_HOME_PATH}/include" -I "${ASCEND_HOME_PATH}/include" -L "${ATB_HOME_PATH}/lib" -L "${ASCEND_HOME_PATH}/lib64" demo.cpp -l atb -l ascendcl -o demo ./demo # 运行可执行文件
当abi=0时,需在g++命令中添加编译选项-D_GLIBCXX_USE_CXX11_ABI=0。abi的说明请参考nottoctopics/zh-cn_topic_0000002289324910.html#ZH-CN_TOPIC_0000002289324910__Software-cannToolKit-cannNNAE步骤3说明部分。