Plugin Development

This section describes how to develop an operator plugin that maps ONNX framework operators to operators adapted to the Ascend AI Processor (CANN operators for short), so that Ascend C custom operators can be called from the ONNX framework.

After you create an operator project, the framework/onnx_plugin directory is generated in the operator project path to store the implementation file of the ONNX adaptation plugin. The following uses the custom CANN operator LeakyReluCustom as an example. The operator project directory is as follows:

LeakyReluCustom
├── build.sh             // Build script
├── cmake 
├── CMakeLists.txt       // Build script of the operator project
├── CMakePresets.json    // Build configuration options
├── framework            // Directory for storing the implementation file of the framework adaptation plugin
│   ├── onnx_plugin     // Directory for storing the implementation file of the ONNX adaptation plugin
│   │   ├── CMakeLists.txt    
│   │   ├── onnx_leaky_relu_custom_plugin.cc // Implementation file of the ONNX adaptation plugin
│   ├── CMakeLists.txt
├── op_host                      // Implementation file on the host
├── op_kernel                    // Implementation file on the kernel
└── scripts                      // Directory of scripts used for custom operator project packing
The following describes how to develop the implementation file (onnx_leaky_relu_custom_plugin.cc) of the ONNX adaptation plugin.
#include "register/register.h"
#include "graph/operator.h"
#include "json.hpp"
namespace domi {
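    // Callback that parses the attributes of the ONNX operator (op_src)
    // and sets them on the CANN operator (op_dest).
    Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest) {
        // ... parse and map attributes here ...
        return SUCCESS;
    }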
    REGISTER_CUSTOM_OP("OpType")
        .FrameworkType(ONNX) 
        .OriginOpType("OriginOpType")
        .ParseParamsByOperatorFn(ParseParamByOpFunc)   // Registers a function for parsing operator attributes.
        .ImplyType(ImplyType::TVM);  // Sets the implementation type of Ascend C operators to TVM.
}
  1. Include required header files.
    • register.h is stored in include/register/ under the CANN component directory. Inclusion of this header file enables calls to the operator registration APIs.
    • operator.h (optional) is stored in include/graph/ under the CANN software installation directory. Inclusion of this header file enables calls to the operator APIs, which can be used to obtain the operator information such as the inputs, outputs, and attributes.
    • json.hpp is used to parse ONNX parameter definitions, which are obtained as strings, into JSON. If this file is not provided in the sample project, download json.hpp, place it in any subdirectory under the project directory, and include it.
  2. Use the REGISTER_CUSTOM_OP macro to register the mapping between the CANN operators and the ONNX operators. The methods are as follows:
    • REGISTER_CUSTOM_OP: registers custom operators. OpType indicates the operator type name, which must be the same as the value of OpType in Operator Prototype Definition.
    • FrameworkType: specifies the framework type. ONNX indicates that the original framework is ONNX.
    • OriginOpType: specifies the operator type in the original framework. For example, if the custom operator OpTypeA corresponds to ONNX opset version 11 (opset_version=11), pass ai.onnx::11::OpTypeA. Supported ONNX opset versions range from 9 to 15.
    • ParseParamsByOperatorFn(ParseParamByOpFunc): registers a callback function that parses operator attributes to implement the mapping. You need to implement the ParseParamByOpFunc callback function. For details, see step 3.
    • ImplyType: specifies the operator implementation type. Set the implementation type of Ascend C operators to TVM.
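
As a rough illustration of the OriginOpType naming pattern described above, the following standalone sketch builds the ai.onnx::<opset_version>::<OpType> string and rejects opset versions outside the 9 to 15 range. The helper name MakeOnnxOriginOpType is hypothetical and is not part of the CANN API; in an actual plugin the string is normally written as a literal in the OriginOpType call.

```cpp
#include <string>

// Hypothetical helper (NOT a CANN API): builds the string passed to
// OriginOpType, following the "ai.onnx::<opset_version>::<OpType>" pattern.
// Returns an empty string when the opset version is outside the 9..15 range
// that this section lists as supported, or when the operator type is empty.
std::string MakeOnnxOriginOpType(int opset_version, const std::string& op_type) {
    const int kMinOpset = 9;
    const int kMaxOpset = 15;
    if (opset_version < kMinOpset || opset_version > kMaxOpset || op_type.empty()) {
        return std::string();
    }
    return "ai.onnx::" + std::to_string(opset_version) + "::" + op_type;
}
```

For example, MakeOnnxOriginOpType(11, "OpTypeA") yields ai.onnx::11::OpTypeA, matching the example given for OriginOpType, while an opset version of 8 or 16 yields an empty string.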
  3. Implement the ParseParamByOpFunc callback function. The function declaration is as follows:
    Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest)
    
    • ParseParamByOpFunc: function name, which is user-defined.
    • op_src: an Operator class object defined by the ONNX framework, including attributes of the operator in the ONNX model. The definition is obtained from the original ONNX model file.
    • op_dest: data structure of the CANN operator, storing operator information. For details about the class Operator, see Operator.

    You need to implement attribute parsing and mapping in the callback function as follows:

    In the original ONNX model, parameters of the repeated message type are obtained with the GetAttr(const char *name, ge::AscendString &attr_value) API. Convert the returned AscendString value to a string, and then parse that string as JSON to read the individual attribute fields.

    The implementation is as follows.

    // Assumes that the json type alias from json.hpp and the constant kTypeFloat
    // (the ONNX float attribute type) are declared elsewhere in this file.
    Status ParseParamLeakyReluAscend(const ge::Operator& op_src, ge::Operator& op_dest) {
        float negative_slope = 0.01f;  // Default used when "alpha" is not found.
        std::string negative_slope_str;
        ge::AscendString attrs_string;
        // Read the ONNX operator's attributes, stored under the attribute name "attribute", into an AscendString.
        if (ge::GRAPH_SUCCESS == op_src.GetAttr("attribute", attrs_string)) {
          // Parse the attribute string as JSON.
          json attrs = json::parse(attrs_string.GetString());
          for (json attr : attrs["attribute"]) {
            if (attr["name"] == "alpha" && attr["type"] == kTypeFloat) {
              negative_slope_str = attr["f"];  // Stored as a string because a JSON float loses precision.
              negative_slope = atof(negative_slope_str.c_str());
            }
          }
        }
        op_dest.SetAttr("negative_slope", negative_slope);
        return SUCCESS;
    }
    
    • The GetAttr and SetAttr APIs of the current version cannot parse fields of type double or uint64 in the original file.
    • During model conversion, the ATC tool does not strictly verify that attributes are obtained successfully. When implementing an operator plugin, add handling for possible GetAttr failures. For example, return a failure for a missing required attribute, or fall back to a default value for a missing optional attribute.
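
The failure handling suggested in the last note can be sketched as follows. This is a minimal, standalone illustration and not the CANN implementation: MockOperator, Status, ParseFloat, ReadFloatAttr, and ParseLeakyReluAttrs are hypothetical stand-ins that only mimic the shape of the GetAttr and SetAttr calls so that the example compiles without the CANN headers. As a simplification, the attribute value is read directly as a string rather than extracted from the JSON "attribute" field.

```cpp
#include <cstdlib>
#include <map>
#include <string>

// Stand-in status codes and operator type (assumptions, NOT the real ge API).
enum Status { SUCCESS = 0, FAILED = 1 };

struct MockOperator {
    std::map<std::string, std::string> str_attrs;  // attributes stored as text
    std::map<std::string, float> float_attrs;      // attributes set on the target

    // Returns SUCCESS and fills value if the attribute exists, FAILED otherwise.
    Status GetAttr(const char* name, std::string& value) const {
        auto it = str_attrs.find(name);
        if (it == str_attrs.end()) {
            return FAILED;
        }
        value = it->second;
        return SUCCESS;
    }
    void SetAttr(const char* name, float value) { float_attrs[name] = value; }
};

// Converts text to float; rejects empty input and trailing non-numeric text.
bool ParseFloat(const std::string& text, float& out) {
    if (text.empty()) {
        return false;
    }
    char* end = nullptr;
    const float value = std::strtof(text.c_str(), &end);
    if (end == text.c_str() || *end != '\0') {
        return false;
    }
    out = value;
    return true;
}

// Missing required attribute: FAILED. Missing optional attribute: default.
// A present but malformed value is always reported as FAILED.
Status ReadFloatAttr(const MockOperator& op, const char* name, bool required,
                     float default_value, float& out) {
    std::string text;
    if (op.GetAttr(name, text) != SUCCESS) {
        if (required) {
            return FAILED;
        }
        out = default_value;
        return SUCCESS;
    }
    return ParseFloat(text, out) ? SUCCESS : FAILED;
}

// Hypothetical callback: "alpha" is optional and defaults to 0.01.
Status ParseLeakyReluAttrs(const MockOperator& op_src, MockOperator& op_dest) {
    float negative_slope = 0.0f;
    if (ReadFloatAttr(op_src, "alpha", false, 0.01f, negative_slope) != SUCCESS) {
        return FAILED;
    }
    op_dest.SetAttr("negative_slope", negative_slope);
    return SUCCESS;
}
```

In this sketch, a source operator without "alpha" produces negative_slope = 0.01, while a malformed value such as "abc" makes the callback return FAILED instead of silently continuing with a wrong value.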