使用自定义Pass修改Graph
改图功能主要有两种调用方式,用户可以在构建完Graph后,直接调用改图接口修改Graph,也可以将改图函数封装为自定义Pass,并通过REGISTER_CUSTOM_PASS注册宏进行改图Pass注册,通过把改图函数编译成动态库插件方式,注册的Pass可以在指定的阶段被调用。
功能介绍
1 2 3 4 5 6 7 8 |
#include "register_custom_pass.h" // 用户自定义改图函数 Status CustomPassFunc(GraphPtr &graph, CustomPassContext &custom_context) { // 此处定义图修改具体行为 return GRAPH_SUCCESS; } // 改图Pass注册,注册后的Pass在指定阶段被调用 REGISTER_CUSTOM_PASS("pass_name").CustomPassFn(CustomPassFunc).Stage(CustomPassStage::kBeforeInferShape); |
- register_custom_pass.h:存储在“CANN软件安装目录/latest/include/register/”目录下,包含该头文件,可使用Pass注册相关类,使用Pass注册相关接口。
- Status:成功返回ge::GRAPH_SUCCESS,返回其他全为失败。建议使用小于0的值作为返回的错误码,大于0的值可能会和框架已使用的错误码产生冲突。
- CustomPassFunc:自定义Pass的入口函数,详情请参见回调函数CustomPassFunc。
- graph:自定义Pass要修改的图,类型为GraphPtr。
- custom_context:CustomPassContext类对象,可参考CustomPassContext提供的方法。
- REGISTER_CUSTOM_PASS:注册自定义Pass,"pass_name"可任意命名,详情请参见REGISTER_CUSTOM_PASS。
- CustomPassFn:自定义Pass的执行函数,详情请参见CustomPassFn。
- Stage:指定Pass执行阶段,详情请参见Stage。
开发示例
此处以MatMul+Add融合为GEMM自定义Pass为例,详细介绍如何通过自定义Pass修改Graph,详细可以参见样例源码。
样例仓还提供了“如何通过自定义Pass把Tile+Concat图结构修改为Concat+Tile+Concat图结构”的详细说明,可以从完整样例参考获取链接。
- 包含的头文件。
1 2 3 4 5 6 7
#include <iostream> // 自定义Pass接口头文件 #include "register_custom_pass.h" // 新增算子头文件 #include "all_ops.h" // 如果使用Ascend C自定义了算子,需要包含如下头文件: #include "CANN软件安装目录/latest/opp/vendors/customize/op_proto/inc/op_proto.h"
其中:CANN软件安装目录:请修改为CANN软件包的实际安装路径。customize:请修改为Ascend C自定义算子实际工程名。
- 使用自定义Pass修改Graph。(如下代码仅为示例,不可执行;修改Graph时,只能使用CustomPassContext、Graph~GNode、PassReceiver、PassRegistrationData、REGISTER_CUSTOM_PASS、StreamPassContext中的接口)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
namespace { constexpr const char *kOpNameAdd = "add"; constexpr const char *kOpNameMatMul = "matmul"; constexpr const char *kOpNameGEMM = "gemm"; constexpr const char *kOpNameAlpha = "alpha"; constexpr const char *kOpNameBeta = "beta"; constexpr const char *kAttrNameTransposeA = "transpose_a"; constexpr const char *kAttrNameTransposeB = "transpose_b"; constexpr int32_t kIndex0 = 0; constexpr int32_t kIndex1 = 1; constexpr int32_t kIndex2 = 2; constexpr int32_t kIndex3 = 3; constexpr int32_t kIndex4 = 4; // 1.遍历所有节点,寻找MatMul和Add节点 bool FindNodes(GraphPtr &graph, GNode &src_node, GNode &dst_node) { auto all_nodes = graph->GetAllNodes(); bool find_src_node = false; bool find_dst_node = false; for (auto &node: all_nodes) { AscendString node_name; auto ret = node.GetName(node_name); if (node_name == kOpNameMatMul) { src_node = node; find_src_node = true; cout << "Find src node: MatMul." << endl; } else if (node_name == kOpNameAdd) { dst_node = node; find_dst_node = true; cout << "Find dst node: Add." << endl; } } return (find_src_node && find_dst_node); } // 2.判断MatMul和Add节点是否有连边关系 bool CheckNodesHaveEdge(GraphPtr &graph, const GNode &src_node, const GNode &dst_node) { for (auto &[out_node, _]: src_node.GetOutDataNodesAndPortIndexs(kIndex0)) { AscendString node_name; auto ret = out_node->GetName(node_name); if (node_name == kOpNameAdd) { return true; } } return false; } // 3.创建和添加GEMM节点 void CreateGEMMNode(GraphPtr &graph, const GNode &src_node, GNode &node_gemm) { bool transpose_a = false; bool transpose_b = false; src_node.GetAttr(kAttrNameTransposeA, transpose_a); src_node.GetAttr(kAttrNameTransposeB, transpose_b); constexpr float kValue1 = 1; TensorDesc alpha_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT); Tensor alpha_tensor(alpha_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float)); auto alpha = op::Const(kOpNameAlpha).set_attr_value(alpha_tensor); TensorDesc beta_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT); Tensor beta_tensor(beta_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float)); auto beta = op::Const(kOpNameBeta).set_attr_value(beta_tensor); auto gemm = op::GEMM(kOpNameGEMM); gemm.set_attr_transpose_a(transpose_a) .set_attr_transpose_b(transpose_b); gemm.update_input_desc_alpha(alpha_desc); gemm.update_input_desc_beta(beta_desc); auto node_alpha = graph->AddNodeByOp(alpha); auto node_beta = graph->AddNodeByOp(beta); node_gemm = graph->AddNodeByOp(gemm); auto ret = graph->AddDataEdge(node_alpha, kIndex0, node_gemm, kIndex3); ret = graph->AddDataEdge(node_beta, kIndex0, node_gemm, kIndex4); } // 4.添加新节点的输入输出 bool AddInputsAndOutputs(GraphPtr &graph, const GNode &src_node, const GNode &dst_node, GNode &node_gemm) { auto [a, a_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex0); auto [b, b_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex1); int32_t add_node_c_input_index = -1; for (size_t i = 0; i < dst_node.GetInputsSize(); ++i) { auto [in_node, _] = dst_node.GetInDataNodesAndPortIndexs(i); AscendString node_name; auto ret = in_node->GetName(node_name); if (node_name != kOpNameMatMul) { add_node_c_input_index = i; break; } } if (add_node_c_input_index == -1) { return false; } auto [c, c_output_index] = dst_node.GetInDataNodesAndPortIndexs(add_node_c_input_index); auto ret = graph->AddDataEdge(*a, a_output_index, node_gemm, kIndex0); if (ret != GRAPH_SUCCESS) { return false; } ret = graph->AddDataEdge(*b, b_output_index, node_gemm, kIndex1); ret = graph->AddDataEdge(*c, c_output_index, node_gemm, kIndex2); TensorDesc input_desc_a; ret = src_node.GetInputDesc(kIndex0, input_desc_a); ret = node_gemm.UpdateInputDesc(kIndex0, input_desc_a); TensorDesc input_desc_b; ret = src_node.GetInputDesc(kIndex1, input_desc_b); ret = node_gemm.UpdateInputDesc(kIndex1, input_desc_b); TensorDesc input_desc_c; ret = dst_node.GetInputDesc(add_node_c_input_index, input_desc_c); ret = node_gemm.UpdateInputDesc(kIndex2, input_desc_c); TensorDesc output_desc_y; ret = dst_node.GetOutputDesc(kIndex0, output_desc_y); ret = node_gemm.UpdateOutputDesc(kIndex0, output_desc_y); return true; } // 5.删除旧节点和其连边关系,连接新GEMM节点和输出节点 void RemoveOldNodesEdgesAndAddGemmOutput(GraphPtr &graph, GNode &src_node, GNode &dst_node, GNode &node_gemm) { vector<GNode> node_vec{src_node, dst_node}; for (auto &node: node_vec) { for (size_t i = 0; i < node.GetInputsSize(); ++i) { auto [in_node, in_id] = node.GetInDataNodesAndPortIndexs(i); if (in_node != nullptr) { auto ret = graph->RemoveEdge(*in_node, in_id, node, i); } } } for (auto &[out_node, out_id]: dst_node.GetOutDataNodesAndPortIndexs(kIndex0)) { if (out_node != nullptr) { auto ret = graph->RemoveEdge(dst_node, kIndex0, *out_node, out_id); ret = graph->AddDataEdge(node_gemm, kIndex0, *out_node, out_id); } } for (auto &node: node_vec) { auto ret = graph->RemoveNode(node); } } } // namespace // |o>----------------------------------- // |o> a b // |o> \ / a b c // |o> MatMul c ==> \ | / // |o> \ / GEMM // |o> Add // |o>----------------------------------- // 融合说明:本例识别上图中左边的MatMul+Add结构并通过图修改接口替换为右边的单个GEMM节点 // 改图接口返回值说明:本文件中的改图接口需要判断返回值, 基于可读性考虑除了Pass入口函数外,其他函数中的改图接口只接收返回值,但不增加返回值处理代码;如需判断返回值,可配合使用custom_context.SetErrorMessage("xxx")方法 graphStatus FuseMatMulAndAddPass(GraphPtr &graph, CustomPassContext &custom_context) { cout << "FuseMatMulAndAddPass begin." << endl; GNode src_node; GNode dst_node; // 1.遍历所有节点,寻找MatMul和Add节点 if (!FindNodes(graph, src_node, dst_node)) { cout << "Do not find MatMul or Add node." << endl; return GRAPH_SUCCESS; } // 2.判断MatMul和Add节点是否有连边关系 if (!CheckNodesHaveEdge(graph, src_node, dst_node)) { cout << "There is no edge between src and dst node." << endl; return GRAPH_SUCCESS; } // 3.创建和添加GEMM节点 GNode node_gemm; CreateGEMMNode(graph, src_node, node_gemm); // 4.添加新节点的输入输出 if (!AddInputsAndOutputs(graph, src_node, dst_node, node_gemm)) { custom_context.SetErrorMessage("Add inputs and outputs failed."); return -1; } // 5.删除旧节点和其连边关系,连接新GEMM节点和输出节点 RemoveOldNodesEdgesAndAddGemmOutput(graph, src_node, dst_node, node_gemm); cout << "FuseMatMulAndAddPass end." << endl; return GRAPH_SUCCESS; } REGISTER_CUSTOM_PASS("FuseMatMulAndAddPass").CustomPassFn(FuseMatMulAndAddPass).Stage(CustomPassStage::kBeforeInferShape);
如何使用自定义Pass
完成上述自定义Pass后,本节简单介绍如何把改图函数编译成动态库插件方式,以便注册的Pass在图编译的最开始被框架调用。详细使用说明请参见样例使用指导。
- 把开发示例中的改图函数编译成仅以".so"结尾的动态库文件。
- 把上述".so"动态库文件复制到${INSTALL_DIR}/opp/vendors/xxx/custom_fusion_passes/目录下。(支持设置软链接的方式;".so"文件对可执行用户,需要有可读权限)
多个"${INSTALL_DIR}/opp/vendors/xxx"目录按照文本序排序后遍历寻找"custom_fusion_passes/"子目录,单个子目录内的".so"按照文本序加载,非".so"结尾的文件在加载时跳过。
- ${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。若安装的Ascend-cann-toolkit软件包,以root安装举例,则安装后文件存储路径为:/usr/local/Ascend/ascend-toolkit/latest。
- xxx:有且仅有一层自定义目录。
- custom_fusion_passes:该目录下不能有子目录。
- 支持但不限于如下几种入口编译模型文件:
如果要查看上述自定义Pass有没有生效,在编译模型前,需要dump图进行查看:在执行之前设置DUMP_GE_GRAPH环境变量,然后使用如下入口编译模型:
- 使用ATC工具进行模型转换。ATC工具使用方法请参见《ATC离线模型编译工具用户指南》。
- 编译Graph为离线模型。
- 编译并运行Graph。
结果验证
设置了dump环境变量后,程序执行完毕,会在当前路径生成ge_onnx*.pbtxt等图文件,用户可以获取如下两张图,然后使用Netron等可视化软件查看:
- 指定Pass执行阶段在InferShape之前:
- ge_onnx_xxxxxxxx_RunCustomPassBegin.pbtxt:融合前的图
- ge_onnx_xxxxxxxx_RunCustomPassEnd.pbtxt:融合后的图
以MatMul+Add融合为GEMM自定义Pass为例,查看融合前的图结构为:
通过自定义Pass修改后的图结构如下所示,可以看出MatMul+Add结构已经替换为单个GEMM节点。
- 指定Pass执行阶段在InferShape之后:
- ge_onnx_xxxx_PrepareAfterInferFormatAndShape.pbtxt:融合前的图
- ge_onnx_xxxx_RunCustomPass_AfterInferShape.pbtxt:融合后的图
以BatchMatMulV2融合为Transpose+Mul+ReduceSum自定义Pass为例,查看融合前的图结构为:
通过自定义Pass修改后的图结构如下所示,可以看出BatchMatMulV2已经替换为Transpose+Mul+ReduceSum三个节点。