Operator Prototype Definition
The following uses the reshape_cust operator in the operator development sample project as an example to describe the AI CPU operator development workflow.
Go to the op_proto directory of the operator project and implement the IR definition files reshape_cust.h and reshape_cust.cc to register the operator with the operator prototype library. During network execution, GE calls the verification API of the operator prototype library to verify operator arguments. If the verification passes, GE infers the output shape and data type of each node by calling the inference function registered in the operator prototype library, and allocates static memory for the result tensors.
reshape_cust.h
MindStudio generates the operator registration code into the reshape_cust.h file. You can modify the code as required. The prototype definition of the ReshapeCust operator is as follows.
#ifndef GE_OP_INTERP_RESHAPE_CUST_H
#define GE_OP_INTERP_RESHAPE_CUST_H
#include "graph/operator_reg.h"
namespace ge {
REG_OP(ReshapeCust)
.INPUT(tensor, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}))
.INPUT(shape, TensorType({DT_INT32, DT_INT64}))
.OUTPUT(output, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8,
DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32}))
.OP_END_FACTORY_REG(ReshapeCust)
}
#endif // GE_OP_INTERP_RESHAPE_CUST_H
- ReshapeCust in REG_OP(ReshapeCust) indicates the type of the operator registered with the Ascend AI Processor. The operator type must be the same as that in REGISTER_CUSTOM_OP("ReshapeCust") in Operator Plugin Implementation (TensorFlow).
- .INPUT and .OUTPUT indicate the names and data types of the input and output tensors of the operator. The input and output order must be consistent with the function parameter order in Operator Code Implementation, as well as the order in Operator Information Library Definition.
reshape_cust.cc Implementation
The key task of the prototype definition is inferring the output shape. The ReshapeCust operator infers its output shape as follows: obtain the input tensor and the target shape, then check whether the element count of the input tensor matches the element count implied by the target shape. If they match, the target shape is set as the output shape.
#include "reshape_cust.h"  // IR registration header file
#include <vector>          // Provides the vector container used for shape dimensions.
#include <string>          // Provides the string class of the C++ standard library.
#include <iostream>        // Provides the input and output stream APIs.

namespace {
// Widen the raw target shape values to int64_t.
template <typename T>
std::vector<int64_t> AsInt64(const T *data, int64_t dataSize) {
  std::vector<int64_t> ret(dataSize);
  for (int64_t i = 0; i < dataSize; ++i) {
    ret[i] = data[i];
  }
  return ret;
}

// Obtain the element count based on the shape.
int64_t GetElementNum(const std::vector<int64_t> &shape) {
  int64_t ret = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    ret *= shape[i];
  }
  return ret;
}
}  // namespace

namespace ge {
IMPLEMT_COMMON_INFERFUNC(ReshapeCustInferShape) {
  TensorDesc tensordesc_tensor = op.GetInputDescByName("tensor");
  TensorDesc tensordesc_shape = op.GetInputDescByName("shape");
  TensorDesc tensordesc_output = op.GetOutputDescByName("output");
  Tensor shape_tensor;
  // Obtain the value of the target shape.
  if (op.GetInputConstData("shape", shape_tensor) == GRAPH_SUCCESS) {
    DataType shape_type = tensordesc_shape.GetDataType();
    std::vector<int64_t> shape_values;
    if (shape_type == DT_INT32) {
      auto shape_data = reinterpret_cast<const int32_t *>(shape_tensor.GetData());
      shape_values = AsInt64<int32_t>(shape_data, shape_tensor.GetSize() / sizeof(int32_t));
    } else {
      auto shape_data = reinterpret_cast<const int64_t *>(shape_tensor.GetData());
      shape_values = AsInt64<int64_t>(shape_data, shape_tensor.GetSize() / sizeof(int64_t));
    }
    // Check whether the target shape is valid.
    std::vector<int64_t> input_shape = tensordesc_tensor.GetShape().GetDims();
    int64_t input_element_num = GetElementNum(input_shape);
    int64_t shape_element_num = GetElementNum(shape_values);
    if (input_element_num != shape_element_num) {
      return GRAPH_FAILED;
    }
    // Set the shape of the output tensor.
    tensordesc_output.SetShape(Shape(shape_values));
    tensordesc_output.SetOriginShape(Shape(shape_values));
  }
  tensordesc_output.SetDataType(tensordesc_tensor.GetDataType());
  std::vector<std::pair<int64_t, int64_t>> range;
  auto status = op.GetInputDesc("tensor").GetShapeRange(range);
  if (status != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  tensordesc_output.SetShapeRange(range);
  (void)op.UpdateOutputDesc("output", tensordesc_output);
  return GRAPH_SUCCESS;
}

COMMON_INFER_FUNC_REG(ReshapeCust, ReshapeCustInferShape);
}  // namespace ge