定义融合规则
在融合规则实现文件(例如decode_bbox_v2_scope_fusion_pass.cc)中,定义Scope的融合规则,便于系统基于此规则在全图的Scope中匹配,找到符合相应规则的Scope。
主要过程为:
- 包含头文件。
#include "decode_bbox_v2_scope_fusion_pass.h"
- 定义融合规则中使用到的常量。
namespace ge { namespace { const char *const kScopeType = "DecodeBboxV2FusionOp"; // 定义融合结果类型 const char *const kScopeTypeDecodeBboxV2 = "DecodeBboxV2"; // 定义目标Scope类型 const char *const kOpType = "DecodeBboxV2"; // 用于日志打印 } // namespace ... } // namespace ge
融合规则实现文件中定义的其他所有函数均在namespace ge命名空间内。
- 通过DefinePatterns自定义融合规则,用于在全图Scope中匹配到符合相应规则的Scope,目前支持以下三种模式:
- 基于scope中某一类型算子的个数或者个数的倍数匹配,关键接口为NodeOpTypeFeature。
例如要求Scope内含有2个Exp算子:
decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0)); // Exp num is 2
- 基于scope中某一类型算子的某一属性的值匹配,关键接口为NodeAttrFeature。
例如要求Split算子的num_split属性值必须为4:
basic_lstm_cell_sigmoid->AddNodeAttrFeature(NodeAttrFeature("Split", "num_split", ge::DT_INT32, 4));
- 基于scope自身或者其子scope的特征匹配,关键接口为ScopeFeature。示例:
根据Scope的名称匹配,例如要求Scope的LastName为while。
p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "while"));
根据其子Scope的类型和个数匹配,例如要求Scope内包含1个子scope,子scope类型为kLstmCellReluType。
p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, ""));
根据其子Scope的名称匹配,例如要求子scope的LastName中包含字符串“while”。
p_lstm_relu_while->AddScopeFeature(ScopeFeature("", 0, "", "while"));
将以上特征结合起来设定匹配规则,例如要求Scope的LastName为while,且包含1个子scope,子scope类型为kLstmCellReluType。
p_lstm_relu_while->AddScopeFeature(ScopeFeature(kLstmCellReluType, 1, "while"));
下面示例中,通过NodeOpTypeFeature定义融合规则,要求Scope内包括2个Exp算子/4个Mul算子/4个Sub算子/2的倍数个RealDiv算子/2个Unpack算子/1个Pack算子/3个Transpose算子/不能包括Softmax算子。
// DecodeBboxV2ScopeFusionPass子类实现自定义融合规则 std::vector <ScopeFusionPatterns> DecodeBboxV2ScopeFusionPass::DefinePatterns() { std::vector <ScopeFusionPatterns> patterns_list; ScopeFusionPatterns pattern; GenScopePatterns(pattern); patterns_list.push_back(pattern); return patterns_list; } // 根据Scope关键特征,编写融合规则 void DecodeBboxV2ScopeFusionPass::GenScopePatterns(ScopeFusionPatterns &patterns) { std::vector < ScopePattern * > batch; ScopePattern *decode_bbox_v2_pattern = new(std::nothrow) ScopePattern(); if (decode_bbox_v2_pattern == nullptr) { OP_LOGE(kOpType, "Alloc an object failed."); return; } decode_bbox_v2_pattern->SetSubType(kScopeTypeDecodeBboxV2); // 设置Scope类型 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Exp", 2, 0)); // Exp num is 2 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Mul", 4, 0)); // Mul num is 4 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Sub", 4, 0)); // Sub num is 4 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("RealDiv", 0, 2)); // RealDiv num is 2*n decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Unpack", 2, 0)); // Unpack num is 2 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Pack", 1, 0)); // Pack num is 1 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Transpose", 3, 0)); // Transpose num is 3 decode_bbox_v2_pattern->AddNodeOpTypeFeature(NodeOpTypeFeature("Softmax", -1, 0)); // doesn't have Softmax OP_LOGI(kOpType, "Add GenScopePatterns DecodeBboxV2."); batch.push_back(decode_bbox_v2_pattern); patterns.push_back(batch); }
对于比较复杂的Scope,可以结合使用上面三种匹配方式,来定义Scope融合规则。
- 基于scope中某一类型算子的个数或者个数的倍数匹配,关键接口为NodeOpTypeFeature。
父主题: Scope融合规则实现