Defining a Fusion Pattern
In the fusion pattern implementation file (for example, decode_bbox_v2_scope_fusion_pass.cc), define the scope fusion patterns that the system uses to find matching scopes.
The major steps are as follows:
- Include the header file.
#include "decode_bbox_v2_scope_fusion_pass.h"
- Define constants used in the fusion pattern.
namespace ge {
namespace {
const char *const kScopeType = "DecodeBboxV2FusionOp";      // Type of the fusion result.
const char *const kScopeTypeDecodeBboxV2 = "DecodeBboxV2";  // Type of the target scope.
const char *const kOpType = "DecodeBboxV2";                 // Operator name used in log messages.
}  // namespace
...
}  // namespace ge
All other functions defined in the fusion pattern implementation file are also in the ge namespace.
- Define the fusion patterns in DefinePatterns. A pattern is used to find the scopes in the scope graph that meet its conditions. Currently, the following three feature types are supported:
- NodeOpTypeFeature: Scopes are matched based on the number of operators of a given type in the scope, either an exact count or a count that is a multiple of a given number.
For example, the scope must contain two Exp operators.
decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));  // Exp num is 2
- NodeAttrFeature: Scopes are matched based on the value of an attribute of an operator of a certain type in the scope.
For example, the num_split attribute of the Split operator must be set to 4.
basic_lstm_cell_sigmoid->AddNodeAttrFeature(NodeAttrFeature("Split", "num_split", ge::DT_INT32, 4));
- ScopeFeature: Scopes are matched based on the features of the scope itself or of its sub-scopes.
Match the scope based on its name. For example, the LastName of the scope must be while.
p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "while"));
Match the scope based on the type and number of its sub-scopes. For example, the scope must contain one sub-scope of type kLstmCellReluType.
p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, ""));
Match the scope based on the name of the sub-scope. For example, the LastName of the sub-scope must contain the string while.
p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "", "while"));
These conditions can also be combined in a single feature. For example, the LastName of the scope is while, and the scope contains one sub-scope of type kLstmCellReluType.
p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, "while"));
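To make the ScopeFeature cases above concrete, the following is a self-contained C++ sketch of how such a feature could be evaluated against a scope. It does not use the GE headers: MockScope and MockScopeFeature are hypothetical stand-ins written for illustration. The rules they encode (the scope LastName must equal the suffix, exactly num direct sub-scopes must have the given type, and at least one direct sub-scope LastName must contain the mask string) are assumptions based on the descriptions above, not the actual ScopeFeature implementation.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// HYPOTHETICAL stand-in for a scope; this is not the ge::Scope API.
struct MockScope {
  std::string last_name;            // last component of the scope name, e.g. "while"
  std::string sub_type;             // type already assigned to this scope, e.g. via SetSubType()
  std::vector<MockScope> children;  // direct sub-scopes (assumption: only direct children are checked)
};

// HYPOTHETICAL model of the ScopeFeature(sub_type, num, suffix, sub_scope_mask)
// arguments used above; an empty string disables the corresponding check.
class MockScopeFeature {
 public:
  MockScopeFeature(std::string sub_type, int num, std::string suffix,
                   std::string sub_scope_mask = "")
      : sub_type_(std::move(sub_type)), num_(num), suffix_(std::move(suffix)),
        sub_scope_mask_(std::move(sub_scope_mask)) {
    assert(num_ >= 0);  // a negative sub-scope count is meaningless in this model
  }

  bool Match(const MockScope &scope) const {
    // The LastName of the scope itself must equal suffix_.
    if (!suffix_.empty() && scope.last_name != suffix_) return false;
    // Exactly num_ direct sub-scopes must have type sub_type_.
    if (!sub_type_.empty()) {
      int count = 0;
      for (const MockScope &child : scope.children) {
        if (child.sub_type == sub_type_) ++count;
      }
      if (count != num_) return false;
    }
    // At least one direct sub-scope LastName must contain sub_scope_mask_.
    if (!sub_scope_mask_.empty()) {
      bool found = false;
      for (const MockScope &child : scope.children) {
        if (child.last_name.find(sub_scope_mask_) != std::string::npos) {
          found = true;
          break;
        }
      }
      if (!found) return false;
    }
    return true;
  }

 private:
  std::string sub_type_;
  int num_;
  std::string suffix_;
  std::string sub_scope_mask_;
};
```

In this model, a scope named while whose only child has the (hypothetical) type string "LstmCellRelu" satisfies the combined condition ("LstmCellRelu", 1, "while"), while the sub-scope name check only succeeds one level up, on a parent scope whose child is named while.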
In the following example, NodeOpTypeFeature is used to define the fusion pattern. The scope must contain two Exp operators, four Mul operators, four Sub operators, a number of RealDiv operators that is a multiple of 2, two Unpack operators, one Pack operator, and three Transpose operators, and must not contain any Softmax operator.
// Implement the custom fusion pattern in the DecodeBboxV2ScopeFusionPass subclass.
std::vector<ScopeFusionPatterns> DecodeBboxV2ScopeFusionPass::DefinePatterns() {
  std::vector<ScopeFusionPatterns> patterns_list;
  ScopeFusionPatterns pattern;
  GenScopePatterns(pattern);
  patterns_list.push_back(pattern);
  return patterns_list;
}

// Write a fusion pattern based on the key features of the scope.
void DecodeBboxV2ScopeFusionPass::GenScopePatterns(ScopeFusionPatterns &patterns) {
  std::vector<ScopePattern *> batch;
  ScopePattern *decode_bbox_v2_pattern = new (std::nothrow) ScopePattern();
  if (decode_bbox_v2_pattern == nullptr) {
    OP_LOGE(kOpType, "Alloc an object failed.");
    return;
  }
  decode_bbox_v2_pattern->SetSubType(kScopeTypeDecodeBboxV2);  // Set the scope type.
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));        // Exp num is 2
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Mul", 4, 0));        // Mul num is 4
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Sub", 4, 0));        // Sub num is 4
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("RealDiv", 0, 2));    // RealDiv num is 2*n
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Unpack", 2, 0));     // Unpack num is 2
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Pack", 1, 0));       // Pack num is 1
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Transpose", 3, 0));  // Transpose num is 3
  decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Softmax", -1, 0));   // doesn't have Softmax
  OP_LOGI(kOpType, "Add GenScopePatterns DecodeBboxV2.");
  batch.push_back(decode_bbox_v2_pattern);
  patterns.push_back(batch);
}
For a complex scope, you can combine the preceding three matching methods to define a scope fusion pattern.
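The NodeOpTypeFeature rules used in GenScopePatterns can be illustrated the same way. The following standalone C++ sketch counts operators by type and checks the DecodeBboxV2 features from the example against that histogram. OpCounts, MockNodeOpTypeFeature, MatchesAll, and DecodeBboxV2Features are hypothetical names. The interpretation of the arguments (an exact count when step is 0, a nonzero multiple of step otherwise, and -1 meaning the operator must be absent) and the assumption that all features of one ScopePattern must hold at the same time are inferred from the comments in the example, not taken from the GE source.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// HYPOTHETICAL: op-type histogram of one scope, e.g. {"Exp", 2}; not a GE type.
using OpCounts = std::map<std::string, int>;

// HYPOTHETICAL model of NodeOpTypeFeature(op_type, num, step), as assumed here:
//   step > 0              : the count is a nonzero multiple of step (num ignored)
//   step == 0, num == -1  : the scope contains no op of op_type
//   step == 0, num >= 0   : the scope contains exactly num ops of op_type
struct MockNodeOpTypeFeature {
  std::string op_type;
  int num;
  int step;

  bool Match(const OpCounts &counts) const {
    assert(step >= 0);  // negative steps are not meaningful in this model
    auto it = counts.find(op_type);
    const int count = (it == counts.end()) ? 0 : it->second;
    if (step > 0) return count > 0 && count % step == 0;
    if (num == -1) return count == 0;
    return count == num;
  }
};

// Assumption: a pattern matches only when every one of its features matches.
bool MatchesAll(const std::vector<MockNodeOpTypeFeature> &features,
                const OpCounts &counts) {
  for (const MockNodeOpTypeFeature &feature : features) {
    if (!feature.Match(counts)) return false;
  }
  return true;
}

// The DecodeBboxV2 features from GenScopePatterns, re-expressed with the mock type.
std::vector<MockNodeOpTypeFeature> DecodeBboxV2Features() {
  return {{"Exp", 2, 0},       {"Mul", 4, 0},    {"Sub", 4, 0},
          {"RealDiv", 0, 2},   {"Unpack", 2, 0}, {"Pack", 1, 0},
          {"Transpose", 3, 0}, {"Softmax", -1, 0}};
}
```

With this model, a scope containing 2 Exp, 4 Mul, 4 Sub, 4 RealDiv, 2 Unpack, 1 Pack, and 3 Transpose operators and no Softmax satisfies all eight features; adding a single Softmax, or changing the RealDiv count to 3, makes the match fail.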