Operator Prototype Definition

The following uses the Add operator in the operator development sample project as an example to describe the TBE operator development workflow.

Go to the op_proto/ directory, write the IR implementation files add.h and add.cc, and register the operator with the operator prototype library. During network execution, GE calls the verification API of the operator prototype library to verify operator arguments. If the verification passes, GE infers the output shape and dtype of each node by calling the inference function of the operator prototype library and allocates static memory for the result tensor.

Implementing add.h

MindStudio generates the operator registration code into the add.h file. You can modify the generated code as required. The prototype definition of the Add operator is as follows.

#ifndef GE_OPS_OP_PROTO_ADD_H_
#define GE_OPS_OP_PROTO_ADD_H_
#include "graph/operator_reg.h"
namespace ge {
REG_OP(Add)
    .INPUT(x1,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .INPUT(x2,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OUTPUT(y,
        TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
            DT_COMPLEX64, DT_STRING}))
    .OP_END_FACTORY_REG(Add)
}

#endif  // GE_OPS_OP_PROTO_ADD_H_
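The same macro pattern extends to other operators. As a purely illustrative fragment (a hypothetical single-input operator, not part of the sample project), a unary elementwise operator could be registered like this:

```cpp
// Hypothetical example only: registering a single-input elementwise
// operator with the same REG_OP macro pattern shown above.
#include "graph/operator_reg.h"
namespace ge {
REG_OP(Abs)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(Abs)
}
```

This fragment depends on the GE header and is not standalone; it only illustrates the registration syntax.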

Implementing add.cc

You need to implement the InferShape and Verify functions in add.cc.

  • The Verify function, that is, IMPLEMT_VERIFIER(Add, AddVerify) in the following sample code, is used to check whether the data types of the two inputs of the Add operator are the same.
  • The InferShape function, that is, IMPLEMT_COMMON_INFERFUNC(AddInferShape) in the following sample code, is used to infer the output tensor description of the operator. With the output shapes known in advance, memory can be statically allocated for all tensors during network execution, avoiding the overhead of dynamic memory allocation.
The implementation code of the add.cc file is as follows.
#include "./add.h"         // IR registration header file
#include <vector>        // For std::vector, used to hold shape dimensions.
#include <string>        // For std::string, used for input/output tensor names.

namespace ge {
bool InferShapeAndTypeAdd(Operator& op, const string& inputName1,
                          const string& inputName2, const string& outputName)
{
    TensorDesc vOutputDesc = op.GetOutputDescByName(outputName.c_str());

    DataType inputDtype = op.GetInputDescByName(inputName1.c_str()).GetDataType();
    Format inputFormat = op.GetInputDescByName(inputName1.c_str()).GetFormat();
    // Retrieve both input shapes; ensure dimsX holds the longer one.
    ge::Shape shapeX = op.GetInputDescByName(inputName1.c_str()).GetShape();
    ge::Shape shapeY = op.GetInputDescByName(inputName2.c_str()).GetShape();
    std::vector<int64_t> dimsX = shapeX.GetDims();
    std::vector<int64_t> dimsY = shapeY.GetDims();
    if (dimsX.size() < dimsY.size()) {
        std::vector<int64_t> dimsTmp = dimsX;
        dimsX = dimsY;
        dimsY = dimsTmp;
    }

    // Left-pad the shorter shape with 1s so that both ranks match.
    if (dimsX.size() != dimsY.size()) {
        int dec = dimsX.size() - dimsY.size();
        for (int i = 0; i < dec; i++) {
            dimsY.insert(dimsY.begin(), (int64_t)1);
        }
    }

    // Set the output shape dimension.
    std::vector<int64_t> dimVec;
    for (size_t i = 0; i < dimsX.size(); i++) {
        if ((dimsX[i] != dimsY[i]) && (dimsX[i] != 1) && (dimsY[i] != 1)) {
            return false;
        }

        int64_t dims = dimsX[i] > dimsY[i] ? dimsX[i] : dimsY[i];
        dimVec.push_back(dims);
    }
    ge::Shape outputShape = ge::Shape(dimVec);

    vOutputDesc.SetShape(outputShape);
    vOutputDesc.SetDataType(inputDtype);
    vOutputDesc.SetFormat(inputFormat);
    op.UpdateOutputDesc(outputName.c_str(), vOutputDesc);

    return true;
}

//----------------Add-------------------
IMPLEMT_VERIFIER(Add, AddVerify)
{
    if (op.GetInputDescByName("x1").GetDataType() != op.GetInputDescByName("x2").GetDataType()) {
        return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
}


// InferShape function, which deduces the output tensor description.
IMPLEMT_COMMON_INFERFUNC(AddInferShape)
{
    if (InferShapeAndTypeAdd(op, "x1", "x2", "y")) {
        return GRAPH_SUCCESS;
    }
    return GRAPH_FAILED;
}


// Register the inference function. Pass the OpType as the first argument.
COMMON_INFER_FUNC_REG(Add, AddInferShape);


// Register the verification function. Pass the OpType as the first argument.
VERIFY_FUNC_REG(Add, AddVerify);
//----------------Add-------------------
}