Implementing the Plugin of a Fused Operator
This topic describes how to implement a fused operator plugin, which maps the base operators in the original framework to a fused operator adapted to the Ascend AI Processor and registers the operator information with GE.
The corresponding functions are implemented in the implementation file of the fused operator plugin (for example, decode_bbox_v2_scope_fusion_plugin.cc).
REGISTER_CUSTOM_OP("DecodeBboxV2")
    .FrameworkType(TENSORFLOW)  // The original framework is TensorFlow.
    .OriginOpType("DecodeBboxV2FusionOp")  // Operator type in the original framework, which must match the value passed to SetType in GenerateFusionResult.
    .FusionParseParamsFn(DecodeBboxV2ParseParams)  // Registers the function for parsing the attributes of the fused operator.
    .ImplyType(ImplyType::TVM);  // Implementation mode of the operator. ImplyType::TVM indicates a TBE operator.
For details about the REGISTER_CUSTOM_OP macro, ParseParamsByOperatorFn, FusionParseParamsFn (Overload), and other APIs, see "GE Namespace > Class OpRegistrationData" in Basic Data Structures and APIs. This section describes only how registering a scope fusion operator with the Parser differs from registering a common operator.
Unlike a common operator, a fused operator is registered by calling FusionParseParamsFn (Overload) rather than ParseParamsByOperatorFn, because the two parse functions take different input parameters.
- The callback function prototype of the registration function ParseParamsByOperatorFn for common operators is as follows:
using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
The input is a ge::Operator object constructed from the operator definition in the original framework.
- The callback function prototype of the registration function FusionParseParamsFn (Overload) for scope fusion operators is as follows:
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
The input is a vector of ge::Operator objects, one for each operator inside the scope. The output is the fused operator object, which stores the fused operator information.
You define and implement the callback function to map the attributes of the small operators in the original model to the attributes of the fused operator, and fill the result into the fused operator object.
Status FusionParseParamByOpFunc(const std::vector<ge::Operator> &op_src, ge::Operator &op_dest);
The following example implementation of FusionParseParamsFn derives the scales attribute of the fused operator from the small operators in the original model.
// Read a single float value from the "value" tensor of a Const node.
Status ParseFloatFromConstNode(const ge::Operator *node, float &value) {
  if (node == nullptr) {
    return FAILED;
  }
  ge::Tensor tensor;
  auto ret = node->GetAttr("value", tensor);
  if (ret != ge::GRAPH_SUCCESS) {
    AscendString op_name;
    ret = node->GetName(op_name);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    OP_LOGE(op_name.GetString(), "Failed to get value from %s", op_name.GetString());
    return FAILED;
  }
  uint8_t *data_addr = tensor.GetData();
  if (data_addr == nullptr) {  // Guard against an empty tensor buffer.
    return FAILED;
  }
  value = *(reinterpret_cast<float *>(data_addr));
  return SUCCESS;
}
// Define the callback function.
Status DecodeBboxV2ParseParams(const std::vector<ge::Operator> &inside_nodes, ge::Operator &op_dest) {
  std::map<std::string, std::string> scales_const_name_map;
  std::map<std::string, const ge::Operator *> node_map;
  for (const auto &node : inside_nodes) {
    ge::AscendString op_type;
    ge::graphStatus ret = node.GetOpType(op_type);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    ge::AscendString op_name;
    ret = node.GetName(op_name);
    std::string str_op_name;
    if (op_name.GetString() != nullptr) {
      str_op_name = op_name.GetString();
    }
    if (op_type == kBoxesDiv) {
      if (node.GetInputsSize() < kRealDivInputSize) {
        OP_LOGE(op_name.GetString(), "Input size of %s is invalid, which is %zu.", kBoxesDiv, node.GetInputsSize());
        return FAILED;
      }
      ge::AscendString input_unpack_name0;
      ret = node.GetInputDesc(0).GetName(input_unpack_name0);
      std::string str_input_unpack_name0;
      if (input_unpack_name0.GetString() != nullptr) {
        str_input_unpack_name0 = input_unpack_name0.GetString();
      }
      ge::AscendString input_unpack_name1;
      ret = node.GetInputDesc(1).GetName(input_unpack_name1);
      std::string str_input_unpack_name1;
      if (input_unpack_name1.GetString() != nullptr) {
        str_input_unpack_name1 = input_unpack_name1.GetString();
      }
      // A RealDiv node whose first input comes from the boxes Unpack node
      // divides by a scale constant, which is its second input.
      if (str_input_unpack_name0.find(kBoxesUnpack) != std::string::npos) {
        scales_const_name_map.insert({str_op_name, str_input_unpack_name1});
      }
    }
    node_map[str_op_name] = &node;
  }
  std::vector<float> scales_list = {1.0, 1.0, 1.0, 1.0};
  if (scales_const_name_map.size() != kScaleSize) {
    ge::AscendString op_name;
    ge::graphStatus ret = op_dest.GetName(op_name);
    if (ret != ge::GRAPH_SUCCESS) {
      return FAILED;
    }
    OP_LOGI(op_name.GetString(), "Boxes do not need to be scaled.");
  } else {
    size_t i = 0;
    for (const auto &name_pair : scales_const_name_map) {
      float scale_value = 1.0;
      auto ret = ParseFloatFromConstNode(node_map[name_pair.second], scale_value);
      if (ret != SUCCESS) {
        return ret;
      }
      scales_list[i++] = scale_value;
    }
  }
  op_dest.SetAttr("scales", scales_list);
  return SUCCESS;
}