定义融合规则类
在融合规则头文件（例如decode_bbox_v2_scope_fusion_pass.h）中定义Scope融合规则类。该类继承自ScopeBasePass类，并重写其DefinePatterns、PassName、LastMatchScopesAndOPs和GenerateFusionResult方法。
#ifndef FRAMEWORK_TF_SCOPE_FUSION_PASS_DECODE_BBOX_V2_PASS_H_ // 条件编译
#define FRAMEWORK_TF_SCOPE_FUSION_PASS_DECODE_BBOX_V2_PASS_H_ // 宏定义
#include <string>
#include <vector>
#include "register/scope/scope_fusion_pass_register.h"
namespace ge {

/// Scope fusion rule for the DecodeBboxV2 pattern.
///
/// Derives from ScopeBasePass and overrides its pattern-definition, naming,
/// final-matching and result-generation hooks. Only the declarations live in
/// this header; the definitions are expected in the matching source file.
class DecodeBboxV2ScopeFusionPass : public ScopeBasePass {
 protected:
  /// Override of ScopeBasePass::DefinePatterns(); presumably returns the
  /// scope fusion patterns this pass looks for (definition not shown here).
  std::vector<ScopeFusionPatterns> DefinePatterns() override;

  /// Override of ScopeBasePass::PassName(); returns this pass's name.
  std::string PassName() override;

  /// Override of ScopeBasePass::LastMatchScopesAndOPs(). Receives the scope
  /// graph and a results vector by non-const reference.
  /// NOTE(review): presumably the last matching step that appends matched
  /// scopes to @p results; confirm against the definition.
  Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
                               std::vector<ScopesResult> &results) override;

  /// Override of ScopeBasePass::GenerateFusionResult(). Takes the matched
  /// @p scopes and an output pointer @p fusion_rlt to be filled in.
  void GenerateFusionResult(const std::vector<Scope *> &scopes,
                            FusionScopesResult *fusion_rlt) override;

 private:
  /// Internal helper operating on @p patterns by reference.
  /// NOTE(review): presumably builds the patterns used by DefinePatterns();
  /// verify against the definition.
  void GenScopePatterns(ScopeFusionPatterns &patterns);
};

}  // namespace ge
#endif // FRAMEWORK_TF_SCOPE_FUSION_PASS_DECODE_BBOX_V2_PASS_H_ 结束条件编译
父主题: Scope融合规则实现