Operator Prototype Definition
The following uses the Add operator in the operator development sample project as an example to describe the operator prototype definition step of the TBE operator development workflow.
Go to the op_proto/ directory and write the IR implementation files add.h and add.cc to register the operator with the operator prototype library. During network execution, GE calls the verification API of the operator prototype library to verify the operator arguments. If the verification passes, GE infers the output shape and dtype of each node by calling the inference function of the operator prototype library and allocates static memory for the result tensors.
Implementing add.h
MindStudio generates the operator registration code into the add.h file. You can modify the code as required. The prototype definition of the Add operator is as follows.
#ifndef GE_OPS_OP_PROTO_ADD_H_
#define GE_OPS_OP_PROTO_ADD_H_

#include "graph/operator_reg.h"

namespace ge {
REG_OP(Add)
    .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8,
                           DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8,
                           DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8,
                           DT_UINT8, DT_DOUBLE, DT_COMPLEX128, DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}  // namespace ge
#endif  // GE_OPS_OP_PROTO_ADD_H_
- Add in REG_OP(Add) indicates the operator type registered with the Ascend AI Processor. The operator type in the third-party framework (TensorFlow, ONNX, or Caffe) must be the same as that in REGISTER_CUSTOM_OP("Add") in Operator Plugin Implementation (TensorFlow/Caffe/ONNX).
- .INPUT and .OUTPUT indicate the names and supported data types of the input and output tensors of the operator. The input and output order must be consistent with the function parameter order in Operator Code Implementation, as well as with that in Operator Information Library Definition.
Implementing add.cc
You need to implement the InferShape and Verify functions in add.cc.
- The Verify function, that is, IMPLEMT_VERIFIER(Add, AddVerify) in the following sample code, checks whether the data types of the two inputs of the Add operator are the same.
- The InferShape function, that is, IMPLEMT_COMMON_INFERFUNC(AddInferShape) in the following sample code, infers the output tensor description of the operator. In this way, memory can be statically allocated for all tensors during network execution, avoiding the overhead of dynamic memory allocation.
#include "./add.h"  // IR registration header file
#include <vector>   // Provides the std::vector container used for shape dimensions.
#include <string>   // Provides the std::string class used for tensor names.

namespace ge {
bool InferShapeAndTypeAdd(Operator& op, const string& inputName1, const string& inputName2,
                          const string& outputName)
{
    TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());
    DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType();
    Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat();

    // Exchange the shapes so that dimsX always has the larger rank.
    ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape();
    ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape();
    std::vector<int64_t> dimsX = shapeX.GetDims();
    std::vector<int64_t> dimsY = shapeY.GetDims();
    if (dimsX.size() < dimsY.size()) {
        std::vector<int64_t> dimsTmp = dimsX;
        dimsX = dimsY;
        dimsY = dimsTmp;
    }

    // Pad the smaller shape with leading 1s.
    if (dimsX.size() != dimsY.size()) {
        int dec = dimsX.size() - dimsY.size();
        for (int i = 0; i < dec; i++) {
            dimsY.insert(dimsY.begin(), (int64_t)1);
        }
    }

    // Set the output shape dimensions to the element-wise maximum.
    std::vector<int64_t> dimVec;
    for (size_t i = 0; i < dimsX.size(); i++) {
        if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
            return false;
        }
        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
    }
    ge::Shape outputShape = ge::Shape(dimVec);

    vOutputDesc.SetShape(outputShape);
    vOutputDesc.SetDataType(inputDtype);
    vOutputDesc.SetFormat(inputFormat);
    op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);
    return true;
}

//----------------Add-------------------
IMPLEMT_VERIFIER(Add, AddVerify)
{
    if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) {
        return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
}

// Inference function: obtains and sets the output tensor description.
IMPLEMT_COMMON_INFERFUNC(AddInferShape)
{
    if (InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
        return GRAPH_SUCCESS;
    }
    return GRAPH_FAILED;
}

// Register the inference function. Pass the OpType as the first argument.
COMMON_INFER_FUNC_REG(Add, AddInferShape);
// Register the verification function. Pass the OpType as the first argument.
VERIFY_FUNC_REG(Add, AddVerify);
//----------------Add-------------------
}  // namespace ge