
Defining Fusion Rules

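In the fusion rule implementation file (for example, decode_bbox_v2_scope_fusion_pass.cc), define the scope fusion rules. The system uses these rules to search all scopes in the graph and find the ones that match.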

The main steps are:

  1. Include the header file.
    #include "decode_bbox_v2_scope_fusion_pass.h"
  2. Define the constants used by the fusion rules.
    namespace ge {
        namespace {
            const char *const kScopeType = "DecodeBboxV2FusionOp";      // Type of the fused result
            const char *const kScopeTypeDecodeBboxV2 = "DecodeBboxV2";  // Target scope type
            const char *const kOpType = "DecodeBboxV2";                 // Used for logging
        }  // namespace
    ...
    }  // namespace ge

    All other functions defined in the fusion rule implementation file are also placed inside namespace ge.

  3. Override DefinePatterns to define custom fusion rules that are used to match scopes across the whole graph. Three matching modes are currently supported:
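    • Matching on the number of operators of a given type in the scope, or on a multiple of that number. The key interface is NodeOpTypeFeature.
      For example, to require that the scope contains exactly 2 Exp operators: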
      decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));        // Exp num is 2
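    • Matching on the value of an attribute of operators of a given type in the scope. The key interface is NodeAttrFeature.
      For example, to require that the num_split attribute of the Split operator equals 4: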
      basic_lstm_cell_sigmoid->AddNodeAttrFeature(NodeAttrFeature("Split", "num_split", ge::DT_INT32, 4));
    • Matching on features of the scope itself or of its child scopes. The key interface is ScopeFeature. Examples:

      Match on the scope name, for example, requiring the scope's LastName to be while.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "while"));

      Match on the type and number of child scopes, for example, requiring the scope to contain 1 child scope of type kLstmCellReluType.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, ""));

      Match on child scope names, for example, requiring the LastName of a child scope to contain the string "while".

      p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "", "while"));

      Combine the features above into one rule, for example, requiring the scope's LastName to be while and the scope to contain 1 child scope of type kLstmCellReluType.

      p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, "while"));
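The ScopeFeature variants above differ only in which fields are set. As a rough illustration, the following self-contained C++ sketch models how such a feature could be checked against a scope. It is not the GE implementation: `ScopeInfo`, `ScopeFeatureModel`, and `Matches` are hypothetical names, and the rules that an empty string disables a check, that the child-type check counts only direct children, and that the child-name check is a substring test are assumptions inferred from the examples above.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a scope in the graph (not a GE type).
struct ScopeInfo {
  std::string type;       // Scope type, e.g. the LSTM cell type assigned earlier
  std::string last_name;  // Last segment of the scope name, e.g. "while"
  std::vector<ScopeInfo> children;  // Direct child scopes (C++17 allows this)
};

// Hypothetical model of ScopeFeature(sub_type, num, scope_name, sub_scope_mask).
// Assumption: an empty string means "do not check this field".
struct ScopeFeatureModel {
  std::string sub_type;
  int num = 0;
  std::string scope_name;
  std::string sub_scope_mask;

  bool Matches(const ScopeInfo &scope) const {
    // The scope's own LastName must equal scope_name.
    if (!scope_name.empty() && scope.last_name != scope_name) {
      return false;
    }
    // Assumption: exactly `num` direct children must have type sub_type.
    if (!sub_type.empty()) {
      int count = 0;
      for (const ScopeInfo &child : scope.children) {
        if (child.type == sub_type) {
          ++count;
        }
      }
      if (count != num) {
        return false;
      }
    }
    // Assumption: at least one child's LastName must contain sub_scope_mask.
    if (!sub_scope_mask.empty()) {
      bool found = false;
      for (const ScopeInfo &child : scope.children) {
        if (child.last_name.find(sub_scope_mask) != std::string::npos) {
          found = true;
          break;
        }
      }
      if (!found) {
        return false;
      }
    }
    return true;
  }
};
```

Under this model, the combined rule matches a scope named while that has exactly one direct child of the LSTM cell type, and rejects a scope with the same child but a different LastName.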

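    The following example uses NodeOpTypeFeature to define the fusion rule: the scope must contain 2 Exp operators, 4 Mul operators, 4 Sub operators, a multiple of 2 RealDiv operators, 2 Unpack operators, 1 Pack operator, and 3 Transpose operators, and must not contain any Softmax operator.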

        // The DecodeBboxV2ScopeFusionPass subclass implements the custom fusion rules
        std::vector<ScopeFusionPatterns> DecodeBboxV2ScopeFusionPass::DefinePatterns() {
            std::vector<ScopeFusionPatterns> patterns_list;
            ScopeFusionPatterns pattern;
            GenScopePatterns(pattern);
            patterns_list.push_back(pattern);
            return patterns_list;
        }

        // Write the fusion rule based on the key features of the scope
        void DecodeBboxV2ScopeFusionPass::GenScopePatterns(ScopeFusionPatterns &patterns) {
            std::vector<ScopePattern *> batch;
            ScopePattern *decode_bbox_v2_pattern = new (std::nothrow) ScopePattern();
            if (decode_bbox_v2_pattern == nullptr) {
                OP_LOGE(kOpType, "Alloc an object failed.");
                return;
            }
            decode_bbox_v2_pattern->SetSubType(kScopeTypeDecodeBboxV2);                          // Set the scope type
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0));        // Exp num is 2
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Mul", 4, 0));        // Mul num is 4
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Sub", 4, 0));        // Sub num is 4
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("RealDiv", 0, 2));    // RealDiv num is 2*n
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Unpack", 2, 0));     // Unpack num is 2
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Pack", 1, 0));       // Pack num is 1
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Transpose", 3, 0));  // Transpose num is 3
            decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Softmax", -1, 0));   // doesn't have Softmax

            OP_LOGI(kOpType, "Add GenScopePatterns DecodeBboxV2.");
            batch.push_back(decode_bbox_v2_pattern);
            patterns.push_back(batch);
        }

    For more complex scopes, the three matching modes above can be combined to define the scope fusion rule.
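The DecodeBboxV2 rule is a conjunction of NodeOpTypeFeature checks. The following C++ sketch evaluates that rule against a table of op-type counts so the effect of each (num, step) pair is easy to see. `OpTypeFeatureModel`, `MatchesAll`, and `DecodeBboxV2Features` are hypothetical names, not GE APIs. Reading num == -1 as "absent", step > 0 as "a nonzero multiple of step", and requiring every feature to hold are assumptions based on the comments in the example; consult the ScopePattern API reference for the exact behavior.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical model of NodeOpTypeFeature(op_type, num, step); not the GE class.
// Assumed semantics, inferred from the comments in GenScopePatterns:
//   num == -1 -> the scope must contain no op of this type
//   step > 0  -> the count must be a nonzero multiple of step
//   otherwise -> the count must equal num exactly
struct OpTypeFeatureModel {
  std::string op_type;
  int num;
  int step;

  bool Matches(const std::map<std::string, int> &op_counts) const {
    auto it = op_counts.find(op_type);
    const int count = (it == op_counts.end()) ? 0 : it->second;
    if (num == -1) {
      return count == 0;
    }
    if (step > 0) {
      return count > 0 && count % step == 0;
    }
    return count == num;
  }
};

// Assumption: a scope matches a pattern only if every feature matches (AND).
bool MatchesAll(const std::vector<OpTypeFeatureModel> &features,
                const std::map<std::string, int> &op_counts) {
  for (const OpTypeFeatureModel &feature : features) {
    if (!feature.Matches(op_counts)) {
      return false;
    }
  }
  return true;
}

// The DecodeBboxV2 features from GenScopePatterns, expressed in the model.
std::vector<OpTypeFeatureModel> DecodeBboxV2Features() {
  return {
      {"Exp", 2, 0},       {"Mul", 4, 0},    {"Sub", 4, 0},
      {"RealDiv", 0, 2},   {"Unpack", 2, 0}, {"Pack", 1, 0},
      {"Transpose", 3, 0}, {"Softmax", -1, 0},
  };
}
```

Under this model, a scope with 4 RealDiv operators still matches, while 3 RealDiv operators or any Softmax operator causes a mismatch; op types not listed in the rule are ignored. When combining matching modes, an attribute or scope-level check could be AND-ed into MatchesAll in the same way.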