Setting the Final Pattern
This includes setting basic scope patterns and parallel scope patterns.
Overview
A scope that meets the requirements in Defining a Fusion Pattern may not be the final scope. For example, you may still need to filter the matched scopes and identify parallel scopes and nested scopes, as shown in Figure 1. In this case, you can set the final matching pattern by implementing LastMatchScopesAndOPs (see LastMatchScopesAndOPs) to narrow down the matched scopes and save the remaining scopes to ScopesResult objects.
Basic Scope Pattern
In this example, the expected type of the scope to be fused is kScopeTypeDecodeBboxV2. You do not need to set advanced patterns. Find the scope and save it to the results. For the types of the returned results, see Class ScopesResult.
/**
 * @brief Final matching step for the basic scope pattern.
 *
 * Walks every scope of the scope tree and stores each scope whose subtype equals
 * kScopeTypeDecodeBboxV2 as its own single-scope ScopesResult.
 *
 * @param scope_graph Scope graph whose scope tree is searched.
 * @param results     Output list; one ScopesResult is appended per matching scope.
 * @return SUCCESS if results is non-empty when the walk finishes; FAILED if scope_graph or
 *         its scope tree is null, if querying a scope's subtype or name fails, or if
 *         results is empty.
 */
Status DecodeBboxV2ScopeFusionPass::LastMatchScopesAndOPs(shared_ptr<ScopeGraph> &scope_graph,
                                                          std::vector<ScopesResult> &results) {
  OP_LOGI(kOpType, "LastMatchScopesAndOPs start.");

  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return FAILED;
  }

  const ScopeTree *tree = scope_graph->GetScopeTree();
  if (tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return FAILED;
  }

  const std::vector<Scope *> &all_scopes = tree->GetAllScopes();
  for (Scope *const candidate : all_scopes) {
    // Class ScopeTree guarantees the scope pointer is valid.
    AscendString subtype;
    if (candidate->SubType(subtype) != SUCCESS) {
      return FAILED;
    }

    // The name is queried for every scope, so a failed query aborts the pass
    // even when the scope would not have matched.
    AscendString scope_name;
    if (candidate->Name(scope_name) != SUCCESS) {
      return FAILED;
    }

    if (!(subtype == kScopeTypeDecodeBboxV2)) {
      continue;
    }

    OP_LOGI(kOpType, "DecodeBbox LastMatchScopesAndOPs match scope %s.", scope_name.GetString());

    // Each matching scope becomes a separate fusion result.
    std::vector<Scope *> matched_scopes{candidate};
    ScopesResult match;
    match.SetScopes(matched_scopes);
    results.push_back(match);
  }

  return results.empty() ? FAILED : SUCCESS;
}
Parallel Scope Pattern
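You can also define more complex parallel scope patterns. For example, find the scopes of the kScopeTypeBatchnorm and kScopeTypeMoments types, and determine whether the two scopes are at the same layer of the network. If they are at the same layer, perform fusion. In the following example, two scopes are treated as being at the same layer when the part of the batchnorm scope name that precedes "batchnorm" is identical to the part of the moments scope name that precedes "moments".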
/**
 * @brief LastMatch for multiple scopes.
 *
 * Collects the scopes whose subtype is kScopeTypeBatchnorm or kScopeTypeMoments. Only when
 * both groups contain the same number of scopes, each batchnorm scope is paired with the
 * first moments scope whose name prefix (the text before "moments") equals the batchnorm
 * scope's name prefix (the text before "batchnorm"). Every pair is appended to results as
 * one ScopesResult holding {batchnorm scope, moments scope}.
 *
 * @param scope_graph Scope graph whose scope tree is searched.
 * @param results     Output list; one ScopesResult is appended per matched pair.
 * @return domi::PARAM_INVALID if scope_graph or its scope tree is null. FAILED if querying
 *         a scope's subtype or name fails, if an examined scope name contains "resblock",
 *         or if results is empty when matching finishes. SUCCESS otherwise.
 */
Status ScopeLayerNormPass::LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph>& scope_graph,
                                                 std::vector<ScopesResult>& results) {
  if (scope_graph == nullptr) {
    OP_LOGE(kOpType, "Input params is nullptr.");
    return domi::PARAM_INVALID;
  }
  const ScopeTree* scope_tree = scope_graph->GetScopeTree();
  if (scope_tree == nullptr) {
    OP_LOGE(kOpType, "Scope tree is nullptr.");
    return domi::PARAM_INVALID;
  }

  // Split the candidate scopes by subtype.
  const std::vector<Scope*>& scopes = scope_tree->GetAllScopes();
  std::vector<Scope*> fusion_scopes_bn;
  std::vector<Scope*> fusion_scopes_m;
  for (Scope* const scope : scopes) {
    // Class ScopeTree guarantees scope is not empty.
    AscendString op_subtype;
    if (scope->SubType(op_subtype) != SUCCESS) {
      return FAILED;
    }
    if (op_subtype == kScopeTypeBatchnorm) {
      fusion_scopes_bn.push_back(scope);
    } else if (op_subtype == kScopeTypeMoments) {
      fusion_scopes_m.push_back(scope);
    }
  }

  // Pairing is attempted only when there are as many batchnorm scopes as moments scopes.
  if (fusion_scopes_bn.size() == fusion_scopes_m.size()) {
    for (Scope* const scope_bn : fusion_scopes_bn) {
      // The batchnorm-side values do not depend on the inner loop, so compute them once.
      // Both lists have equal, non-zero length here, so the inner loop always runs and
      // the early FAILED returns below are reached exactly when they were before.
      AscendString scope_bn_name;
      if (scope_bn->Name(scope_bn_name) != SUCCESS) {
        return FAILED;
      }
      std::string scope_bn_name_str;
      if (scope_bn_name.GetString() != nullptr) {
        scope_bn_name_str = scope_bn_name.GetString();
      }
      // Any scope name containing "resblock" aborts the whole match.
      // NOTE(review): the original locals were named is_biggan_*, which suggests these
      // scopes belong to BigGAN networks - confirm.
      if (scope_bn_name_str.find("resblock") != std::string::npos) {
        return FAILED;
      }
      const std::string::size_type pos_bn = scope_bn_name_str.find("batchnorm");

      for (Scope* const scope_m : fusion_scopes_m) {
        AscendString scope_m_name;
        if (scope_m->Name(scope_m_name) != SUCCESS) {
          return FAILED;
        }
        std::string scope_m_name_str;
        if (scope_m_name.GetString() != nullptr) {
          scope_m_name_str = scope_m_name.GetString();
        }
        if (scope_m_name_str.find("resblock") != std::string::npos) {
          return FAILED;
        }
        const std::string::size_type pos_m = scope_m_name_str.find("moments");

        // Keep find() results as size_type and test against npos instead of narrowing
        // them to int and comparing with -1.
        if (pos_bn == std::string::npos || pos_m == std::string::npos) {
          continue;
        }
        // The two scopes belong together when everything before "batchnorm" equals
        // everything before "moments".
        if (scope_bn_name_str.compare(0, pos_bn, scope_m_name_str, 0, pos_m) != 0) {
          continue;
        }

        // scope result
        std::vector<Scope*> result_scopes{scope_bn, scope_m};
        ScopesResult result;
        result.SetScopes(result_scopes);
        results.push_back(result);
        OP_LOGI(kOpType, "scope:%s, and scope:%s is connect.", scope_bn_name.GetString(), scope_m_name.GetString());
        // Only the first matching moments scope is paired with this batchnorm scope.
        break;
      }
    }
  }
  return results.empty() ? FAILED : SUCCESS;
}
