Operator Prototype Definition
Overview
- Register the operator prototype in the operator_name.h header file.
- Implement the verification function and shape inference function in the operator_name.cc file.
This section uses the Add operator in the TBE operator development sample project as an example to describe how to define an operator prototype.
Go to the op_proto directory, write the IR implementation files add.h and add.cc, and register the operator with the operator prototype library. During network execution, GE calls the verification API of the operator prototype library to verify the operator arguments. If verification passes, GE infers the output shape and dtype of each node by calling the inference function of the operator prototype library and statically allocates memory for the result tensors.
For details about how to define the prototype of an AI CPU operator, see the reshape_cust operator in the corresponding operator development sample project. The IR implementation files are reshape_cust.h and reshape_cust.cc.
add.h Implementation
MindStudio generates the operator registration code (a template file) into the add.h header file. Modify the generated code as required. The prototype definition of the Add operator is as follows:
```cpp
#ifndef GE_OPS_OP_PROTO_ADD_H_  // Header guard: skip the file if already included.
#define GE_OPS_OP_PROTO_ADD_H_

#include "graph/operator_reg.h"

namespace ge {
REG_OP(Add)  // Operator type name.
    .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
                           DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
                           DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}

#endif  // GE_OPS_OP_PROTO_ADD_H_
```
- Add in REG_OP(Add) is the operator type registered with the Ascend AI Processor. When mapping an operator from a third-party framework (TensorFlow, ONNX, or Caffe), this type must match the type passed to REGISTER_CUSTOM_OP("Add") in Operator Plugin Implementation.
- .INPUT and .OUTPUT indicate the names and data types of the input and output tensors of the operator. The input and output sequence must be consistent with the function parameter sequence in Operator Code Implementation (TBE DSL) as well as that in Operator Information Library Definition.
add.cc Implementation
You need to implement the InferShape and Verify functions in add.cc.
- The Verify function, that is, IMPLEMT_VERIFIER(Add, AddVerify) in the following sample code, is used to check whether the data types of the two inputs of the Add operator are the same.
- The InferShape function, that is, IMPLEMT_COMMON_INFERFUNC(AddInferShape) in the following sample code, is used to infer the output tensor description of the operator. In this way, memory can be statically allocated for all tensors during network execution, avoiding the overhead of dynamic memory allocation.
```cpp
#include "./add.h"  // IR registration header file
#include <string>   // std::string
#include <vector>   // std::vector

namespace ge {
bool InferShapeAndTypeAdd(Operator& op, const string& inputName1,
                          const string& inputName2, const string& outputName) {
  TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());

  DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType();
  Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat();

  // Swap the shapes if necessary so that dimsX holds the higher-rank one.
  ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape();
  ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape();
  std::vector<int64_t> dimsX = shapeX.GetDims();
  std::vector<int64_t> dimsY = shapeY.GetDims();
  if (dimsX.size() < dimsY.size()) {
    std::vector<int64_t> dimsTmp = dimsX;
    dimsX = dimsY;
    dimsY = dimsTmp;
  }

  // Pad the lower-rank shape with leading 1s.
  if (dimsX.size() != dimsY.size()) {
    int dec = dimsX.size() - dimsY.size();
    for (int i = 0; i < dec; i++) {
      dimsY.insert(dimsY.begin(), (int64_t)1);
    }
  }

  // Compute the output shape dimension by dimension (broadcast rule).
  std::vector<int64_t> dimVec;
  for (size_t i = 0; i < dimsX.size(); i++) {
    if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
      return false;  // The two shapes cannot be broadcast against each other.
    }
    int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
    dimVec.push_back(dims);
  }
  ge::Shape outputShape = ge::Shape(dimVec);

  vOutputDesc.SetShape(outputShape);
  vOutputDesc.SetDataType(inputDtype);
  vOutputDesc.SetFormat(inputFormat);
  op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);

  return true;
}

// ----------------Add-------------------
// Verify function: checks that the data types of the two inputs are the same.
IMPLEMT_VERIFIER(Add, AddVerify) {
  if (op.GetInputDescByName("x1").GetDataType() !=
      op.GetInputDescByName("x2").GetDataType()) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

// InferShape function: infers the description of the output tensor.
IMPLEMT_COMMON_INFERFUNC(AddInferShape) {
  if (InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}

// Register the InferShape function. The first parameter is the OpType.
COMMON_INFER_FUNC_REG(Add, AddInferShape);
// Register the Verify function. The first parameter is the OpType.
VERIFY_FUNC_REG(Add, AddVerify);
// ----------------Add-------------------
}  // namespace ge
```