Defining a Fusion Pattern

In the fusion pattern implementation file (for example, decode_bbox_v2_scope_fusion_pass.cc), define the scope fusion patterns that the system uses to find matching scopes.

The major steps are described as follows:

  1. Include the header file.
    #include "decode_bbox_v2_scope_fusion_pass.h"
  2. Define constants used in the fusion pattern.
    namespace ge {
        namespace {
            const char *const kScopeType = "DecodeBboxV2FusionOp";      // Define the fusion result type.
            const char *const kScopeTypeDecodeBboxV2 = "DecodeBboxV2";  // Define the target scope type.
            const char *const kOpType = "DecodeBboxV2";                 // Operator type name used in log messages.
        }  // namespace
    ...
    }  // namespace ge

    All other functions defined in the fusion pattern implementation file are also in the ge namespace.

  3. Define a fusion pattern by referring to DefinePatterns. The pattern is used to find the scopes in the scope graph that match it. Currently, the following three APIs are supported:
    • NodeOpTypeFeature: Scopes are matched based on the number of operators of a given type in the scope, expressed either as an exact count or as a multiple of a given number.
      For example, the scope must contain two Exp operators.
      decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));        // Exp num is 2
    • NodeAttrFeature: Scopes are matched based on the value of an attribute of an operator of a certain type in the scope.
      For example, the num_split attribute of the Split operator must be set to 4.
      basic_lstm_cell_sigmoid->AddNodeAttrFeature(NodeAttrFeature("Split", "num_split", ge::DT_INT32, 4));
    • ScopeFeature: Scopes are matched based on the features of the scope or its sub-scopes. Examples:

      Match the scope based on the scope name. For example, the LastName of the scope must be while.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "while"));

      Match the scope based on the type and number of sub-scopes. For example, the scope must contain one sub-scope, and the sub-scope type is kLstmCellReluType.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, ""));

      Match the scope based on the name of the sub-scope. For example, the LastName of the sub-scope must contain the string while.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "", "while"));

      Combine the preceding features in a single ScopeFeature. For example, the LastName of the scope is while, the scope contains one sub-scope, and the sub-scope type is kLstmCellReluType.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, "while"));

    In the following example, NodeOpTypeFeature is used to define the fusion pattern. The scope must contain two Exp operators, four Mul operators, four Sub operators, a number of RealDiv operators that is a multiple of two, two Unpack operators, one Pack operator, and three Transpose operators, and must not contain any Softmax operator.

    // Implement the custom fusion pattern using the DecodeBboxV2ScopeFusionPass subclass.
        std::vector<ScopeFusionPatterns> DecodeBboxV2ScopeFusionPass::DefinePatterns() {
            std::vector<ScopeFusionPatterns> patterns_list;
            ScopeFusionPatterns pattern;
            GenScopePatterns(pattern);
            patterns_list.push_back(pattern);
            return patterns_list;
        }
    // Write a fusion pattern based on the key features of the scope.
        void DecodeBboxV2ScopeFusionPass::GenScopePatterns(ScopeFusionPatterns &patterns) {
            std::vector<ScopePattern *> batch;
            ScopePattern *decode_bbox_v2_pattern = new(std::nothrow) ScopePattern();
            if (decode_bbox_v2_pattern == nullptr) {
                OP_LOGE(kOpType, "Alloc an object failed.");
                return;
            }
            decode_bbox_v2_pattern->SetSubType(kScopeTypeDecodeBboxV2);                          // Set the scope type.
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));        // Exp num is 2
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Mul", 4, 0));        // Mul num is 4
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Sub", 4, 0));        // Sub num is 4
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("RealDiv", 0, 2));    // RealDiv num is 2*n
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Unpack", 2, 0));     // Unpack num is 2
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Pack", 1, 0));       // Pack num is 1
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Transpose", 3, 0));  // Transpose num is 3
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Softmax", -1, 0));   // Softmax must be absent
    
            OP_LOGI(kOpType, "Add GenScopePatterns DecodeBboxV2.");
            batch.push_back(decode_bbox_v2_pattern);
            patterns.push_back(batch);
        }
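
The three count forms used in this pattern (an exact count, a multiple, and absence) can be modelled in a few lines. This is only a sketch of the rule as the comments above describe it, not the GE implementation. OpCountRule, MatchesRule, MatchesAll, and DecodeBboxV2Rules are hypothetical names, the scope is reduced to a map from operator type to operator count, and treating a count of zero as not being a valid multiple is an assumption.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for one NodeOpTypeFeature(type, num, step) entry.
struct OpCountRule {
  std::string op_type;
  int num;
  int step;
};

// Assumed rule, inferred from the comments in the example above:
//   num == -1  -> the operator type must be absent from the scope
//   step == 0  -> the count must equal num
//   step != 0  -> the count must be positive and count % step must equal num
//                 (with num == 0 this means a positive multiple of step)
bool MatchesRule(const OpCountRule &rule,
                 const std::map<std::string, int> &op_counts) {
  const auto it = op_counts.find(rule.op_type);
  const int count = (it == op_counts.end()) ? 0 : it->second;
  if (rule.num == -1) {
    return count == 0;
  }
  if (rule.step == 0) {
    return count == rule.num;
  }
  return count > 0 && count % rule.step == rule.num;
}

// A scope matches only if every rule holds.
bool MatchesAll(const std::vector<OpCountRule> &rules,
                const std::map<std::string, int> &op_counts) {
  for (const auto &rule : rules) {
    if (!MatchesRule(rule, op_counts)) {
      return false;
    }
  }
  return true;
}

// The features added in GenScopePatterns, written as data.
std::vector<OpCountRule> DecodeBboxV2Rules() {
  return {{"Exp", 2, 0},    {"Mul", 4, 0},  {"Sub", 4, 0},
          {"RealDiv", 0, 2}, {"Unpack", 2, 0}, {"Pack", 1, 0},
          {"Transpose", 3, 0}, {"Softmax", -1, 0}};
}
```

Under these assumptions, a scope with two Exp, four Mul, four Sub, four RealDiv, two Unpack, one Pack, and three Transpose operators passes MatchesAll(DecodeBboxV2Rules(), counts). Adding a single Softmax operator, or changing the RealDiv count to three, makes it fail.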

    For a complex scope, you can combine the preceding three matching methods to define a scope fusion pattern.
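
As an illustration of such a combination, the sketch below adds one feature of each kind to a single pattern. So that it compiles without the GE headers, it uses stand-in types in a hypothetical sketch namespace that only record the requested features. They mirror the call shapes shown in this section and are not the GE classes. The scope type MyCombinedScope, the function BuildCombinedPattern, and the chosen feature values are arbitrary and for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

namespace sketch {
// Stand-ins only: they record arguments and perform no matching.
enum DataType { DT_INT32 };

struct NodeOpTypeFeature {
  NodeOpTypeFeature(std::string t, int n, int s)
      : op_type(std::move(t)), num(n), step(s) {}
  std::string op_type;
  int num;
  int step;
};

struct NodeAttrFeature {
  NodeAttrFeature(std::string t, std::string a, DataType d, int v)
      : op_type(std::move(t)), attr_name(std::move(a)), data_type(d), value(v) {}
  std::string op_type;
  std::string attr_name;
  DataType data_type;
  int value;
};

struct ScopeFeature {
  ScopeFeature(std::string t, int n, std::string s)
      : sub_type(std::move(t)), num(n), suffix(std::move(s)) {}
  std::string sub_type;
  int num;
  std::string suffix;
};

class ScopePattern {
 public:
  void SetSubType(const std::string &sub_type) { sub_type_ = sub_type; }
  void AddNodeOpTypeFeature(NodeOpTypeFeature f) { op_features_.push_back(std::move(f)); }
  void AddNodeAttrFeature(NodeAttrFeature f) { attr_features_.push_back(std::move(f)); }
  void AddScopeFeature(ScopeFeature f) { scope_features_.push_back(std::move(f)); }
  std::size_t FeatureCount() const {
    return op_features_.size() + attr_features_.size() + scope_features_.size();
  }

 private:
  std::string sub_type_;
  std::vector<NodeOpTypeFeature> op_features_;
  std::vector<NodeAttrFeature> attr_features_;
  std::vector<ScopeFeature> scope_features_;
};

// One pattern that carries all three kinds of feature.
ScopePattern BuildCombinedPattern() {
  ScopePattern pattern;
  pattern.SetSubType("MyCombinedScope");                                            // hypothetical scope type
  pattern.AddNodeOpTypeFeature(NodeOpTypeFeature("Mul", 4, 0));                     // exactly four Mul operators
  pattern.AddNodeAttrFeature(NodeAttrFeature("Split", "num_split", DT_INT32, 4));   // Split with num_split set to 4
  pattern.AddScopeFeature(ScopeFeature("", 0, "while"));                            // LastName of the scope is while
  return pattern;
}
}  // namespace sketch
```

In a real pass, the same three Add calls would be made on the ScopePattern object allocated in GenScopePatterns before it is pushed into the batch.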