Setting the Fusion Result
Set the fusion result in GenerateFusionResult, including the name, type, inputs, outputs, and description of the fused operator. The result is stored in fusion_rlt, a pointer to a FusionScopesResult object. For details about this class, see Class FusionScopesResult.
- Set the inputs of the fused operator using InsertInputs. For example:
fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
  - The first parameter is the name of the small operator in the scope (without the scope name prefix) whose inputs become inputs of the fused operator.
  - The second parameter is a vector that maps input indexes. Each position in the vector is an input index of the small operator, and the value at that position is the corresponding input index of the fused operator. If an input of the small operator is not used by the fused operator, set the value at its position to the placeholder kFusionDisableIndex.
Table 1 Samples

| No. | Sample Code | Description |
| --- | --- | --- |
| 1 | fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex}); | Input 0 of transpose is used as input 0 of the fused operator. Input 1 of transpose is not used and is represented by the placeholder kFusionDisableIndex. |
| 2 | fusion_rlt->InsertInputs("transpose", {1, kFusionDisableIndex}); | Input 0 of transpose is used as input 1 of the fused operator. Input 1 of transpose is not used and is represented by the placeholder kFusionDisableIndex. |
| 3 | fusion_rlt->InsertInputs("transpose", {kFusionDisableIndex, 0}); | Input 1 of transpose is used as input 0 of the fused operator. Input 0 of transpose is not used and is represented by the placeholder kFusionDisableIndex. |
- Set the outputs of the fused operator using InsertOutputs. Its parameters work the same way as those of InsertInputs. For example:
// Set the output of the fused operator. Use output 0 of transpose_1 as output 0 of the fused operator.
fusion_rlt->InsertOutputs("transpose_1", {0});
- Set the type of the fused operator using SetType. For example:
// Set the type of the fused operator.
fusion_rlt->SetType(kScopeType);
Note that this type must be the same as the OriginOpType value registered in the plugin of the fused operator, for example:
REGISTER_CUSTOM_OP("DecodeBboxV2")
    .FrameworkType(TENSORFLOW)  // The original framework is TensorFlow.
    .OriginOpType("DecodeBboxV2FusionOp")  // Type of the operator in the original framework. It must match the type set with SetType in GenerateFusionResult.
    .FusionParseParamsFn(DecodeBboxV2ParseParams)  // Registers the function that parses the attributes of the fused operator.
    .ImplyType(ImplyType::TVM);  // Implementation mode of the operator. ImplyType::TVM indicates a TBE operator.
If the scopes do not meet the requirements, set the type to kScopeInvalidType so that no fusion is performed. For example:
if (scopes.size() != 1) {
  fusion_rlt->SetType(kScopeInvalidType);
  return;
}
- Set the name of the fused operator using SetName. The name must be globally unique, so instead of choosing a name yourself, derive it from the scope name. For example:
// Set the name of the fused operator.
AscendString scope_name;
Status ret = scopes[0]->Name(scope_name);
if (ret != SUCCESS) {
  return;
}
std::string scope_name_str;
if (scope_name.GetString() != nullptr) {
  scope_name_str = scope_name.GetString();
}
// Strip the trailing '/' from the scope name.
fusion_rlt->SetName(scope_name_str.substr(0, scope_name_str.length() - 1).c_str());
- Set the description of the fused operator using SetDescription. For example:
// Set the description of the fused operator.
fusion_rlt->SetDescription("");
The complete implementation of GenerateFusionResult in the DecodeBboxV2 sample is as follows:

void CustomScopeDecodeBboxV2Pass::GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) {
  if (fusion_rlt == nullptr) {
    return;
  }
  if (scopes.size() != 1) {
    fusion_rlt->SetType(kScopeInvalidType);  // If the scope does not meet the requirements, set type to kScopeInvalidType.
    return;
  }
  // Set the input of the fused operator. Use input 0 of the transpose operator as input 0 of the fused operator. Input 1 of the transpose operator is not used.
  fusion_rlt->InsertInputs("transpose", {0, kFusionDisableIndex});
  // Set the input of the fused operator. Use input 0 of get_center_coordinates_and_sizes/transpose as input 1 of the fused operator. Input 1 of get_center_coordinates_and_sizes/transpose is not used.
  fusion_rlt->InsertInputs("get_center_coordinates_and_sizes/transpose", {1, kFusionDisableIndex});
  // Set the output of the fused operator. Use output 0 of transpose_1 as output 0 of the fused operator.
  fusion_rlt->InsertOutputs("transpose_1", {0});
  // Set the type of the fused operator.
  fusion_rlt->SetType(kScopeType);
  // Set the name of the fused operator.
  AscendString scope_name;
  Status ret = scopes[0]->Name(scope_name);
  if (ret != SUCCESS) {
    return;
  }
  std::string scope_name_str;
  if (scope_name.GetString() != nullptr) {
    scope_name_str = scope_name.GetString();
  }
  // Strip the trailing '/' from the scope name.
  fusion_rlt->SetName(scope_name_str.substr(0, scope_name_str.length() - 1).c_str());
  // Set the description of the fused operator.
  fusion_rlt->SetDescription("");
  OP_LOGI(kOpType, "Set fusion result successfully.");
  return;
}