atb_graph_op.cpp
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | #include "atb/atb_graph_op.h" #include "utils/utils.h" atb::Status CreateGraphOperation(atb::Operation **operation) { // 构图流程 // 图算子的输入a,b,c,d // 计算公式:(a+b) + (c+d) // 输入是4个参数,输出是1个参数,有3个add算子,中间产生的临时输出是2个 atb::GraphParam opGraph; opGraph.inTensorNum = 4; opGraph.outTensorNum = 1; opGraph.internalTensorNum = 2; opGraph.nodes.resize(3); enum InTensorId { // 定义各TensorID IN_TENSOR_A = 0, IN_TENSOR_B, IN_TENSOR_C, IN_TENSOR_D, ADD3_OUT, ADD1_OUT, ADD2_OUT }; size_t nodeId = 0; atb::Node &addNode = opGraph.nodes.at(nodeId++); atb::Node &addNode2 = opGraph.nodes.at(nodeId++); atb::Node &addNode3 = opGraph.nodes.at(nodeId++); atb::Operation *op = nullptr; atb::infer::ElewiseParam addParam; addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; auto status = atb::CreateOperation(addParam, &addNode.operation); CHECK_RET(status, "addParam CreateOperation failed. status: " + std::to_string(status)); addNode.inTensorIds = {IN_TENSOR_A, IN_TENSOR_B}; addNode.outTensorIds = {ADD1_OUT}; atb::infer::ElewiseParam addParam2; addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; status = atb::CreateOperation(addParam2, &addNode2.operation); CHECK_RET(status, "addParam2 CreateOperation failed. status: " + std::to_string(status)); addNode2.inTensorIds = {IN_TENSOR_C, IN_TENSOR_D}; addNode2.outTensorIds = {ADD2_OUT}; atb::infer::ElewiseParam addParam3; addParam3.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; status = CreateOperation(addParam3, &addNode3.operation); CHECK_RET(status, "addParam3 CreateOperation failed. status: " + std::to_string(status)); addNode3.inTensorIds = {ADD1_OUT, ADD2_OUT}; addNode3.outTensorIds = {ADD3_OUT}; status = atb::CreateOperation(opGraph, operation); CHECK_RET(status, "GraphParam CreateOperation failed. status: " + std::to_string(status)); return atb::NO_ERROR; } |
父主题: 用例源码