Operator Plugin Implementation (TensorFlow/ONNX)

You need to develop an operator plugin that parses an operator from a third-party framework (TensorFlow or ONNX) and maps it to an operator adapted to the Ascend AI Processor.

TensorFlow

MindStudio automatically generates the plugin code of the ReshapeCust operator in the framework/tf_plugin/tensorflow_reshape_cust_plugin.cc file.

  • Include the header file.
    // Include the compiler/include/register/register.h file in the Ascend-CANN-Toolkit to use the operator registration class and call the operator registration APIs.
    #include "register/register.h"
    
  • Register the plugin.
    using namespace ge;    // Add it manually.
    namespace domi {
    // register op info to GE
    REGISTER_CUSTOM_OP("ReshapeCust")
        .FrameworkType(TENSORFLOW)   // type: CAFFE, TENSORFLOW
        .OriginOpType("ReshapeCust")      // operator type in the TensorFlow model
        .ParseParamsByOperatorFn(AutoMappingByOpFn)  // no ';' here: the chain continues
        .ImplyType(ImplyType::AI_CPU);    // Add it manually.
    }  // namespace domi
    
    • REGISTER_CUSTOM_OP: operator type registered with GE. According to the operator analysis, the operator type is ReshapeCust.
    • FrameworkType: framework type. The source framework type is TensorFlow.
    • OriginOpType: operator type in the TensorFlow framework.
    • ParseParamsByOperatorFn: registers the callback function that parses the operator parameters. The AutoMappingByOpFn function parses the parameters automatically.
    • ImplyType: implementation type of the operator. ImplyType::AI_CPU indicates that the operator is an AI CPU operator. Add it manually.

ONNX

MindStudio automatically generates the plugin code of the operator in the framework/onnx_plugin/xxx_plugin.cc file.

  • Include the header file.
    // Include the Ascend-CANN-Toolkit installation directory/ascend-toolkit/latest/compiler/include/register/register.h file to use the operator registration class and call the operator registration APIs.
    #include "register/register.h"
    
  • Register the plugin.
    using namespace ge;    // Add it manually.
    namespace domi {
    // ONNX parameter parsing callback
    Status ParseParamAdd(const ge::Operator& op_src, ge::Operator& op_dest) {
        // TODO: Implement the parameter parsing by referring to the ONNX Operator Development Guide.
        return SUCCESS;
    }
    
    // register op info to GE
    REGISTER_CUSTOM_OP("Add")
        .FrameworkType(ONNX)   // Original framework type
        .OriginOpType("")      // Operator type in the original framework; set it manually
        .ParseParamsByOperatorFn(ParseParamAdd)  // Register the callback that parses operator parameters
        .ImplyType(ImplyType::TVM);   // Add it manually.
    }  // namespace domi
    
    • REGISTER_CUSTOM_OP: registers a custom operator. Add is the operator type registered with GE. The value cannot conflict with existing operator names and must be the same as that registered in OpType.
    • FrameworkType: framework type. ONNX indicates that the original framework is ONNX.
    • OriginOpType: type of the operator in the original framework, which you must set manually. For example, if custom operator Add corresponds to the original framework type ai.onnx::11::Add of the ONNX OPP whose opset_version is 11, set this parameter to OriginOpType("ai.onnx::11::Add").
    • ParseParamsByOperatorFn(ParseParamAdd): registers a function for parsing operator attributes. You need to implement the ParseParamAdd callback function.

      The ParseParamAdd callback function is declared as follows:

      Status ParseParamAdd(const ge::Operator& op_src, ge::Operator& op_dest)
      • ParseParamAdd: function name, which is user-defined and must be unique.
      • op_src: a ge::Operator object holding the attributes of the operator in the ONNX model, as obtained from the original ONNX model file.
      • op_dest: CANN operator data structure, which stores the operator information.
    • ImplyType: operator implementation type. ImplyType::TVM indicates that the operator is a TBE operator. Add it manually.