Implementing Pass Based on the Graph Modification API

This section describes how to modify a graph with a custom pass that is built on the graph modification API.

A graph can be modified in either of the following ways:

  • Graph modification API: You can call the graph modification API to modify the graph structure directly and in real time, as described in Modifying a Graph Using the Graph Modification API. This approach is simple, but the modification cannot be applied as a pass later in graph optimization.
  • Custom pass (recommended): You can encapsulate the graph modification functions as a custom pass and register it using the REGISTER_CUSTOM_PASS macro. The registered pass can be compiled into a dynamic library plugin, which the framework calls in a specified phase of the graph optimization process. This approach provides modular, reusable, and configurable graph transformations.

This section describes how to develop, register, and integrate a custom pass to optimize graphs efficiently and flexibly.

Overview

A graph modification function can be regarded as a custom pass. You register it under a pass name by using the REGISTER_CUSTOM_PASS macro. After the graph modification function is compiled into a dynamic library plugin, the framework calls the registered pass in the specified phase. The sample code is as follows:
#include "register_custom_pass.h"
using namespace ge;  // GraphPtr, CustomPassContext, and status codes such as GRAPH_SUCCESS are in namespace ge.
// Define the custom graph modification function.
Status CustomPassFunc(GraphPtr &graph, CustomPassContext &custom_context) {
    // Define the graph modification behavior.
    return GRAPH_SUCCESS;
}
// Register the graph modification pass. The registered pass is called in the specified phase.
REGISTER_CUSTOM_PASS("pass_name").CustomPassFn(CustomPassFunc).Stage(CustomPassStage::kBeforeInferShape);

  • register_custom_pass.h: a header file stored in the /cann/include/register/ directory of the CANN installation directory. Include it to use the classes and APIs for pass registration.
  • Status: operation status. ge::GRAPH_SUCCESS is returned on success, and other values are returned on failure. Use a value less than 0 as the error code, because a value greater than 0 may conflict with error codes used by the framework.
  • CustomPassFunc: entry point function of the custom pass. For details, see Callback Function CustomPassFunc.
  • graph: graph to be modified using a custom pass, which is of the GraphPtr type.
  • custom_context: object of class CustomPassContext. For details, see the methods provided in CustomPassContext.
  • REGISTER_CUSTOM_PASS: a macro used to register a custom pass. pass_name can be set to any name. For details, see REGISTER_CUSTOM_PASS.
  • CustomPassFn: execution function of the custom pass. For details, see CustomPassFn.
  • Stage: execution phase of the pass. For details, see Stage.
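To make the registration flow concrete, the following standalone C++ sketch models the chained registration call and the status convention described above. It is a toy under stated assumptions, not the CANN implementation of REGISTER_CUSTOM_PASS: every name in it (ToyPassBuilder, ToyRegistry, ToyStage, RunStage) is hypothetical, and ToyStatus is a plain signed int. It only shows how a builder whose setters return *this lets CustomPassFn and Stage be chained in one statement, and how a runner can stop at the first negative status.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the real CANN types (not the actual API).
enum class ToyStage { kBeforeInferShape, kAfterInferShape };
struct ToyGraph {
    std::vector<std::string> nodes;
};
using ToyStatus = int;  // signed in this toy so negative error codes are representable
constexpr ToyStatus kToySuccess = 0;
using ToyPassFn = std::function<ToyStatus(ToyGraph &)>;

struct ToyPassRecord {
    std::string name;
    ToyPassFn fn;
    ToyStage stage;
};

// Function-local static avoids static-initialization-order problems.
std::vector<ToyPassRecord> &ToyRegistry() {
    static std::vector<ToyPassRecord> registry;
    return registry;
}

// The constructor adds a registry entry; each setter updates it and returns
// *this, so calls chain like REGISTER_CUSTOM_PASS(...).CustomPassFn(...).Stage(...).
class ToyPassBuilder {
public:
    explicit ToyPassBuilder(const std::string &name) : index_(ToyRegistry().size()) {
        ToyRegistry().push_back(ToyPassRecord{name, nullptr, ToyStage::kBeforeInferShape});
    }
    ToyPassBuilder &CustomPassFn(ToyPassFn fn) {
        ToyRegistry()[index_].fn = std::move(fn);
        return *this;
    }
    ToyPassBuilder &Stage(ToyStage stage) {
        ToyRegistry()[index_].stage = stage;
        return *this;
    }

private:
    std::size_t index_;
};

// Example passes: one marks the graph, one fails with a negative status.
ToyStatus MarkGraphPass(ToyGraph &graph) {
    graph.nodes.push_back("marked");
    return kToySuccess;
}
ToyStatus FailingPass(ToyGraph &) { return -1; }

// Namespace-scope registration objects are constructed before main() runs.
static ToyPassBuilder g_mark_pass =
    ToyPassBuilder("MarkGraphPass").CustomPassFn(MarkGraphPass).Stage(ToyStage::kBeforeInferShape);
static ToyPassBuilder g_failing_pass =
    ToyPassBuilder("FailingPass").CustomPassFn(FailingPass).Stage(ToyStage::kAfterInferShape);

// Runs the passes registered for `stage` in registration order and returns
// the first non-success status, if any.
ToyStatus RunStage(ToyStage stage, ToyGraph &graph) {
    for (const auto &record : ToyRegistry()) {
        if (record.stage != stage || !record.fn) {
            continue;
        }
        const ToyStatus status = record.fn(graph);
        if (status != kToySuccess) {
            return status;
        }
    }
    return kToySuccess;
}
```

In this toy, RunStage(ToyStage::kBeforeInferShape, graph) runs only MarkGraphPass, while the kAfterInferShape stage returns -1 from FailingPass.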

If graph modification requires replacing an operator with an operator of a different function that CANN does not support, develop that operator as a custom operator:

Use Ascend C to develop the custom operator. For details, see Ascend C Operator Development Guide.

After the operator is developed, you can use it in the custom pass to modify the graph.

Example

The following example shows how to modify a graph by using a custom pass that fuses the MatMul+Add structure into a single GEMM node. For the complete code, see the sample source code.

The sample repository also describes in detail how to convert the Tile+Concat graph structure into the Concat+Tile+Concat graph structure using a custom pass. You can obtain the link from Sample Reference.
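The fusion preserves results only if GEMM computes the same value as MatMul followed by Add. Assuming GEMM follows the common BLAS-style definition y = alpha · op(a) · op(b) + beta · c, where op applies the transpose_a/transpose_b attributes, the sample's choice of alpha = beta = 1 (both Const nodes use kValue1) reduces GEMM to MatMul+Add. The standalone C++ sketch below checks this numerically on small row-major matrices. It calls no CANN API; Mat, MatMul, Add, and Gemm are hypothetical helpers, and Add assumes c has the same shape as the MatMul output (no broadcasting).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix used only for this check.
struct Mat {
    std::size_t rows;
    std::size_t cols;
    std::vector<float> data;
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Applies a transpose flag: returns m unchanged, or its transpose.
Mat MaybeTranspose(const Mat &m, bool transpose) {
    if (!transpose) {
        return m;
    }
    Mat t{m.cols, m.rows, std::vector<float>(m.data.size())};
    for (std::size_t r = 0; r < m.rows; ++r) {
        for (std::size_t c = 0; c < m.cols; ++c) {
            t.data[c * t.cols + r] = m.at(r, c);
        }
    }
    return t;
}

// MatMul node: op(a) x op(b).
Mat MatMul(const Mat &a, const Mat &b, bool transpose_a, bool transpose_b) {
    const Mat x = MaybeTranspose(a, transpose_a);
    const Mat y = MaybeTranspose(b, transpose_b);
    assert(x.cols == y.rows);
    Mat out{x.rows, y.cols, std::vector<float>(x.rows * y.cols, 0.0f)};
    for (std::size_t i = 0; i < x.rows; ++i) {
        for (std::size_t j = 0; j < y.cols; ++j) {
            float sum = 0.0f;
            for (std::size_t k = 0; k < x.cols; ++k) {
                sum += x.at(i, k) * y.at(k, j);
            }
            out.data[i * out.cols + j] = sum;
        }
    }
    return out;
}

// Add node: element-wise x + c (same shape assumed, no broadcasting).
Mat Add(const Mat &x, const Mat &c) {
    assert(x.rows == c.rows && x.cols == c.cols);
    Mat out = x;
    for (std::size_t i = 0; i < out.data.size(); ++i) {
        out.data[i] += c.data[i];
    }
    return out;
}

// GEMM node (assumed definition): alpha * op(a) x op(b) + beta * c.
Mat Gemm(const Mat &a, const Mat &b, const Mat &c, float alpha, float beta,
         bool transpose_a, bool transpose_b) {
    Mat out = MatMul(a, b, transpose_a, transpose_b);
    assert(out.rows == c.rows && out.cols == c.cols);
    for (std::size_t i = 0; i < out.data.size(); ++i) {
        out.data[i] = alpha * out.data[i] + beta * c.data[i];
    }
    return out;
}
```

For a = [[1, 2, 3], [4, 5, 6]], b = [[7, 8], [9, 10], [11, 12]], and c filled with ones, both paths give [[59, 65], [140, 155]]. Passing a different alpha, such as 2, makes the results differ, which is why the sample fixes both scalars to 1.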

  1. Include the header files.
    #include <iostream>
    #include <vector>
    // Include the header file of the custom pass API.
    #include "register_custom_pass.h"
    // Include the header file of the new operator.
    #include "all_ops.h"
    // The sample code below uses unqualified names such as cout, vector, GNode, and op::GEMM.
    using namespace std;
    using namespace ge;

    If an Ascend C custom operator is used, also include the following header file:
    #include "${INSTALL_DIR}/opp/vendors/<customize>/op_proto/inc/op_proto.h"
    

    Replace ${INSTALL_DIR} with the CANN component directory. For example, if the installation is performed by the root user, the default path is /usr/local/Ascend/cann. Replace <customize> with the actual project name of the Ascend C custom operator.

  2. Modify the graph using a custom pass. (The following code is only an example and cannot be executed as is. When modifying a graph, you can use only the APIs listed in CustomPassContext, Graph, GNode, PassReceiver, PassRegistrationData, REGISTER_CUSTOM_PASS, and StreamPassContext.)
    namespace {
    constexpr const char *kOpNameAdd = "add";
    constexpr const char *kOpNameMatMul = "matmul";
    constexpr const char *kOpNameGEMM = "gemm";
    constexpr const char *kOpNameAlpha = "alpha";
    constexpr const char *kOpNameBeta = "beta";
    constexpr const char *kAttrNameTransposeA = "transpose_a";
    constexpr const char *kAttrNameTransposeB = "transpose_b";
    constexpr int32_t kIndex0 = 0;
    constexpr int32_t kIndex1 = 1;
    constexpr int32_t kIndex2 = 2;
    constexpr int32_t kIndex3 = 3;
    constexpr int32_t kIndex4 = 4;
    
    // 1. Traverse all nodes and search for the MatMul and Add nodes.
    bool FindNodes(GraphPtr &graph, GNode &src_node, GNode &dst_node) {
        auto all_nodes = graph->GetAllNodes();
        bool find_src_node = false;
        bool find_dst_node = false;
        for (auto &node: all_nodes) {
            AscendString node_name;
            auto ret = node.GetName(node_name);
            // Find the MatMul node.
            if (node_name == kOpNameMatMul) {
                src_node = node;
                find_src_node = true;
                cout << "Find src node: MatMul." << endl;
            //Find the Add node.
            } else if (node_name == kOpNameAdd) {
                dst_node = node;
                find_dst_node = true;
                cout << "Find dst node: Add." << endl;
            }
        }
        return (find_src_node && find_dst_node);
    }
    // 2. Check the edge connection between MatMul and Add nodes.
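    bool CheckNodesHaveEdge(GraphPtr &graph, const GNode &src_node, const GNode &dst_node) {
        AscendString dst_name;
        auto ret = dst_node.GetName(dst_name);
        for (auto &[out_node, _]: src_node.GetOutDataNodesAndPortIndexs(kIndex0)) {
            if (out_node == nullptr) {
                continue;
            }
            AscendString node_name;
            ret = out_node->GetName(node_name);
            // Compare with the Add node found in step 1 instead of a hard-coded name.
            if (node_name == dst_name) {
                return true; // Direct connection
            }
        }
        return false;
    }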
    // 3. Create and add a GEMM node.
    void CreateGEMMNode(GraphPtr &graph, const GNode &src_node, GNode &node_gemm) {
        bool transpose_a = false;
        bool transpose_b = false;
        src_node.GetAttr(kAttrNameTransposeA, transpose_a);
        src_node.GetAttr(kAttrNameTransposeB, transpose_b);
        constexpr float kValue1 = 1;
        TensorDesc alpha_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT);
        Tensor alpha_tensor(alpha_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float));
        auto alpha = op::Const(kOpNameAlpha).set_attr_value(alpha_tensor);
        TensorDesc beta_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT);
        Tensor beta_tensor(beta_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float));
        auto beta = op::Const(kOpNameBeta).set_attr_value(beta_tensor);
    
        auto gemm = op::GEMM(kOpNameGEMM);
        gemm.set_attr_transpose_a(transpose_a)
            .set_attr_transpose_b(transpose_b);
        gemm.update_input_desc_alpha(alpha_desc);
        gemm.update_input_desc_beta(beta_desc);
    
        //Add nodes alpha, beta, and gemm to the graph.
        auto node_alpha = graph->AddNodeByOp(alpha);
        auto node_beta = graph->AddNodeByOp(beta);
        node_gemm = graph->AddNodeByOp(gemm);
    
        //Create data edges between alpha/beta and GEMM.
        auto ret = graph->AddDataEdge(node_alpha, kIndex0, node_gemm, kIndex3);
        ret = graph->AddDataEdge(node_beta, kIndex0, node_gemm, kIndex4);
    }
    // 4. Add the input and output of the new node.
    bool AddInputsAndOutputs(GraphPtr &graph, const GNode &src_node, const GNode &dst_node, GNode &node_gemm) {
        auto [a, a_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex0);
        auto [b, b_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex1);
        // Both MatMul inputs must be connected before they are dereferenced below.
        if (a == nullptr || b == nullptr) {
            return false;
        }
        int32_t add_node_c_input_index = -1;
        for (size_t i = 0; i < dst_node.GetInputsSize(); ++i) {
            auto [in_node, _] = dst_node.GetInDataNodesAndPortIndexs(static_cast<int32_t>(i));
            if (in_node == nullptr) {
                continue;
            }
            AscendString node_name;
            auto ret = in_node->GetName(node_name);
            if (node_name != kOpNameMatMul) {
                add_node_c_input_index = static_cast<int32_t>(i);
                break;
            }
        }
        if (add_node_c_input_index == -1) {
            return false;
        }
        auto [c, c_output_index] = dst_node.GetInDataNodesAndPortIndexs(add_node_c_input_index);
        //Create data edges between a/b/c and GEMM.
        auto ret = graph->AddDataEdge(*a, a_output_index, node_gemm, kIndex0);
        if (ret != GRAPH_SUCCESS) {
            return false;
        }
        ret = graph->AddDataEdge(*b, b_output_index, node_gemm, kIndex1);
        ret = graph->AddDataEdge(*c, c_output_index, node_gemm, kIndex2);
    
        //Update the input/output descriptions of GEMM.
        TensorDesc input_desc_a;
        ret = src_node.GetInputDesc(kIndex0, input_desc_a);
        ret = node_gemm.UpdateInputDesc(kIndex0, input_desc_a);
    
        TensorDesc input_desc_b;
        ret = src_node.GetInputDesc(kIndex1, input_desc_b);
        ret = node_gemm.UpdateInputDesc(kIndex1, input_desc_b);
    
        TensorDesc input_desc_c;
        ret = dst_node.GetInputDesc(add_node_c_input_index, input_desc_c);
        ret = node_gemm.UpdateInputDesc(kIndex2, input_desc_c);
    
        TensorDesc output_desc_y;
        ret = dst_node.GetOutputDesc(kIndex0, output_desc_y);
        ret = node_gemm.UpdateOutputDesc(kIndex0, output_desc_y);
        return true;
    }
    // 5. Delete the old node and its edge connection, and connect the new GEMM node to the output node.
    void RemoveOldNodesEdgesAndAddGemmOutput(GraphPtr &graph, GNode &src_node, GNode &dst_node, GNode &node_gemm) {
        vector<GNode> node_vec{src_node, dst_node};
        // Delete input edges of the old node.
        for (auto &node: node_vec) {
            for (size_t i = 0; i < node.GetInputsSize(); ++i) {
                auto [in_node, in_id] = node.GetInDataNodesAndPortIndexs(i);
                if (in_node != nullptr) {
                    auto ret = graph->RemoveEdge(*in_node, in_id, node, i);
                }
            }
        }
    
        //Change the output edge of the old node to the new GEMM.
        for (auto &[out_node, out_id]: dst_node.GetOutDataNodesAndPortIndexs(kIndex0)) {
            if (out_node != nullptr) {
                auto ret = graph->RemoveEdge(dst_node, kIndex0, *out_node, out_id);
                ret = graph->AddDataEdge(node_gemm, kIndex0, *out_node, out_id);
            }
        }
        //Delete the old node.
        for (auto &node: node_vec) {
            auto ret = graph->RemoveNode(node);
        }
    }
    } // namespace
    
    // |o>-----------------------------------
    // |o>    a  b
    // |o>    \ /              a   b    c
    // |o>   MatMul  c   ==>   \   |   /
    // |o>     \    /            GEMM
    // |o>      Add
    // |o>-----------------------------------
    // Fusion description: In this example, the MatMul+Add structure on the left is identified and replaced with the GEMM node on the right by using the graph modification API.
    // Return values of the graph modification API: Return values should be checked. For readability, only the pass entry point function handles them; the other functions in this file store the return values without processing them. To report a failed check, you can call custom_context.SetErrorMessage("xxx").
    
  3. Register the custom graph modification pass.
    graphStatus FuseMatMulAndAddPass(GraphPtr &graph, CustomPassContext &custom_context) {
        cout << "FuseMatMulAndAddPass begin." << endl;
        GNode src_node;
        GNode dst_node;
        // 1. Traverse all nodes and search for the MatMul and Add nodes.
        if (!FindNodes(graph, src_node, dst_node)) {
            cout << "Do not find MatMul or Add node." << endl;
            return GRAPH_SUCCESS;
        }
    
        // 2. Check the edge connection between MatMul and Add nodes.
        if (!CheckNodesHaveEdge(graph, src_node, dst_node)) {
            cout << "There is no edge between src and dst node." << endl;
            return GRAPH_SUCCESS;
        }
    
        // 3. Create and add a GEMM node.
        GNode node_gemm;
        CreateGEMMNode(graph, src_node, node_gemm);
    
        // 4. Add the input and output of the new node.
        if (!AddInputsAndOutputs(graph, src_node, dst_node, node_gemm)) {
            custom_context.SetErrorMessage("Add inputs and outputs failed.");
            return -1;
        }
    
        // 5. Delete the old node and its edge connection, and connect the new GEMM node to the output node.
        RemoveOldNodesEdgesAndAddGemmOutput(graph, src_node, dst_node, node_gemm);
    
        cout << "FuseMatMulAndAddPass end." << endl;
        return GRAPH_SUCCESS;
    }
    // Use the REGISTER_CUSTOM_PASS registration macro to register the graph modification pass and specify the phase when the pass will be called.
    REGISTER_CUSTOM_PASS("FuseMatMulAndAddPass").CustomPassFn(FuseMatMulAndAddPass).Stage(CustomPassStage::kBeforeInferShape);
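The five steps above can be exercised without a CANN installation on a toy graph. The following C++17 sketch is a hypothetical model, not the Graph/GNode API: ToyNode, ToyGraph, CountConsumers, and FuseMatMulAdd are illustrative names, each node is assumed to have a single output, and nodes are identified by name as in the sample. Unlike the sample, it declines to fuse when the MatMul output has another consumer, because deleting MatMul would otherwise leave that consumer without an input.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy graph (not the CANN Graph/GNode API): every node has a
// single output, and `inputs` holds producer node names in input-port order.
struct ToyNode {
    std::string name;
    std::string type;
    std::vector<std::string> inputs;
};
using ToyGraph = std::map<std::string, ToyNode>;  // keyed by node name

// Counts the input slots, across all nodes, that read the output of `producer`.
int CountConsumers(const ToyGraph &graph, const std::string &producer) {
    int count = 0;
    for (const auto &entry : graph) {
        const auto &inputs = entry.second.inputs;
        count += static_cast<int>(std::count(inputs.begin(), inputs.end(), producer));
    }
    return count;
}

// Replaces matmul(a, b) -> add(., c) with gemm(a, b, c, alpha, beta).
bool FuseMatMulAdd(ToyGraph &graph) {
    // 1. Find the MatMul and Add nodes by name, as the sample does.
    auto mm_it = graph.find("matmul");
    auto add_it = graph.find("add");
    if (mm_it == graph.end() || add_it == graph.end()) {
        return false;
    }
    // 2. Require a direct MatMul -> Add edge, no other consumer of MatMul,
    //    and exactly two inputs on each node.
    const std::vector<std::string> mm_inputs = mm_it->second.inputs;
    const std::vector<std::string> add_inputs = add_it->second.inputs;
    if (mm_inputs.size() != 2 || add_inputs.size() != 2) {
        return false;
    }
    auto pos = std::find(add_inputs.begin(), add_inputs.end(), "matmul");
    if (pos == add_inputs.end() || CountConsumers(graph, "matmul") != 1) {
        return false;
    }
    // The Add input that is not MatMul becomes GEMM input c.
    const std::string c = (pos == add_inputs.begin()) ? add_inputs[1] : add_inputs[0];

    // 3./4. Add alpha/beta constants and a GEMM node wired as (a, b, c, alpha, beta).
    graph["alpha"] = ToyNode{"alpha", "Const", {}};
    graph["beta"] = ToyNode{"beta", "Const", {}};
    graph["gemm"] = ToyNode{"gemm", "GEMM", {mm_inputs[0], mm_inputs[1], c, "alpha", "beta"}};

    // 5. Point every consumer of Add at GEMM, then delete MatMul and Add.
    for (auto &entry : graph) {
        for (auto &input : entry.second.inputs) {
            if (input == "add") {
                input = "gemm";
            }
        }
    }
    graph.erase("matmul");
    graph.erase("add");
    return true;
}
```

Checking every precondition before the first mutation keeps the graph unchanged when the pattern does not match, which mirrors the early return GRAPH_SUCCESS paths in FuseMatMulAndAddPass.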
    

How to Use a Custom Pass

The following describes how to compile the graph modification function into a dynamic library plugin so that the registered pass can be called by the framework at the beginning of graph build. For details, see Sample Usage.

  1. Compile the graph modification functions in Example into a dynamic library file whose name ends with .so.
  2. Copy the .so dynamic library file to the ${INSTALL_DIR}/opp/vendors/xxx/custom_fusion_passes/ directory. (A soft link can be used instead. The .so file must be readable by the user who runs the program.)

    If there are multiple ${INSTALL_DIR}/opp/vendors/xxx directories, they are sorted in text order and traversed to find their custom_fusion_passes/ subdirectories. Within each subdirectory, the .so files are loaded in text order, and files whose names do not end with .so are skipped.

    • Replace ${INSTALL_DIR} with the CANN component directory. For example, if the installation is performed by the root user, the default file storage path is /usr/local/Ascend/cann.
    • xxx: a custom directory. Only one directory level is allowed.
    • custom_fusion_passes: The directory cannot contain subdirectories.
  3. Build the model. To check whether the custom pass takes effect, set the DUMP_GE_GRAPH environment variable so that the graphs are dumped, and then build the model from a supported model build entry.
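The loading rule described in step 2 can be written as a small pure function. This C++17 sketch is a hypothetical model of that rule, not the framework's loader: PassLoadOrder and EndsWithSo are illustrative names, the input maps each vendor directory (xxx) to the file names in its custom_fusion_passes/ subdirectory, and "text order" is assumed to mean byte-wise lexicographic comparison, which is what std::map and std::sort use for std::string.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Returns true if `name` ends with ".so" (so "libx.so.1" does not qualify).
bool EndsWithSo(const std::string &name) {
    const std::string suffix = ".so";
    return name.size() >= suffix.size() &&
           name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Hypothetical model of the documented rule. Input: vendor directory name ->
// file names in its custom_fusion_passes/ subdirectory. Output: the .so files
// in load order, as paths relative to ${INSTALL_DIR}/opp/vendors.
std::vector<std::string> PassLoadOrder(
    const std::map<std::string, std::vector<std::string>> &vendors) {
    std::vector<std::string> order;
    // std::map iterates its keys in sorted order, so vendors are visited in text order.
    for (const auto &[vendor, files] : vendors) {
        std::vector<std::string> libs;
        for (const auto &file : files) {
            if (EndsWithSo(file)) {
                libs.push_back(file);  // other files are skipped
            }
        }
        std::sort(libs.begin(), libs.end());  // .so files within one subdirectory
        for (const auto &lib : libs) {
            order.push_back(vendor + "/custom_fusion_passes/" + lib);
        }
    }
    return order;
}
```

With vendor_b containing z_pass.so, a_pass.so, notes.txt, and old.so.1, and vendor_a containing fuse.so, the sketch returns vendor_a's fuse.so first, then vendor_b's a_pass.so and z_pass.so; notes.txt and old.so.1 are skipped because their names do not end with .so.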

Result Verification

For details, see Sample Usage > Program Running > Execution Result Check.

After the dump environment variable is set, running the program generates graph files such as ge_onnx*.pbtxt in the current path. Depending on the execution phase of the pass, obtain the following two graphs and view them with visualization software such as Netron.

  • If the pass is executed before the InferShape phase:
    • ge_onnx_xxxx_PreRunBegin.pbtxt: graph before fusion
    • ge_onnx_xxxx_RunCustomPassBeforeInfershape.pbtxt: graph after fusion

    Take fusing the MatMul+Add structure into a GEMM node with a custom pass as an example. The figure below shows the graph structure before fusion.

    The following figure shows the graph structure after modification using the custom pass. The MatMul+Add structure has been replaced with a GEMM node.

  • If the pass is executed after the InferShape phase:
    • ge_onnx_xxxx_PrepareAfterInferFormatAndShape.pbtxt: graph before fusion
    • ge_onnx_xxxx_RunCustomPass_AfterInferShape.pbtxt: graph after fusion

    Take replacing the BatchMatMulV2 node with the Transpose+Mul+ReduceSum structure using a custom pass as an example. The figure below shows the graph structure before fusion.

    The following figure shows the graph structure after modification using the custom pass. The BatchMatMulV2 node has been replaced with the Transpose+Mul+ReduceSum structure.

  • If the pass is executed after the built-in original graph fusion passes:
    • ge_onnx_xxxx_OptimizeOriginalGraph_FeGraphFusionAfter.pbtxt: graph before fusion
    • ge_onnx_xxxx_RunCustomPassAfterBuiltinFusionPass.pbtxt: graph after fusion

    Take inserting an Abs operator into the Data-Sqrt structure of a subgraph as an example. The figure below shows the graph structure before fusion.

    The following figure shows the graph structure after modification using the custom pass. The Abs operator has been inserted into the Data-Sqrt structure.
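When many graphs are dumped, it can help to select the before/after pair for the registered phase programmatically. The following C++17 sketch maps each phase to the dump-file name suffixes listed above and scans a directory listing for names of the form ge_onnx_<anything>_<suffix>.pbtxt. PassPhase, DumpSuffixes, MatchesDump, and FindDumpPair are hypothetical names; PassPhase is not the CustomPassStage enum, and the part of the file name shown as xxxx is treated as an arbitrary string.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical enum for the three execution phases described in this section.
enum class PassPhase { kBeforeInferShape, kAfterInferShape, kAfterBuiltinFusionPass };

// Dump-file name suffixes listed in this section: {before fusion, after fusion}.
std::pair<std::string, std::string> DumpSuffixes(PassPhase phase) {
    switch (phase) {
        case PassPhase::kBeforeInferShape:
            return {"PreRunBegin", "RunCustomPassBeforeInfershape"};
        case PassPhase::kAfterInferShape:
            return {"PrepareAfterInferFormatAndShape", "RunCustomPass_AfterInferShape"};
        case PassPhase::kAfterBuiltinFusionPass:
            break;
    }
    return {"OptimizeOriginalGraph_FeGraphFusionAfter", "RunCustomPassAfterBuiltinFusionPass"};
}

// True if `file` looks like ge_onnx_<anything>_<suffix>.pbtxt.
bool MatchesDump(const std::string &file, const std::string &suffix) {
    const std::string prefix = "ge_onnx_";
    const std::string tail = "_" + suffix + ".pbtxt";
    return file.size() >= prefix.size() + tail.size() &&
           file.compare(0, prefix.size(), prefix) == 0 &&
           file.compare(file.size() - tail.size(), tail.size(), tail) == 0;
}

// Returns the first {before, after} dump pair for `phase`, if both exist.
std::optional<std::pair<std::string, std::string>> FindDumpPair(
    const std::vector<std::string> &files, PassPhase phase) {
    const auto [before_suffix, after_suffix] = DumpSuffixes(phase);
    std::optional<std::string> before;
    std::optional<std::string> after;
    for (const auto &file : files) {
        if (!before && MatchesDump(file, before_suffix)) {
            before = file;
        }
        if (!after && MatchesDump(file, after_suffix)) {
            after = file;
        }
    }
    if (before && after) {
        return std::make_pair(*before, *after);
    }
    return std::nullopt;
}
```

Obtaining the listing itself (for example, with std::filesystem::directory_iterator over the current path) is left out so that the matching logic stays testable on plain strings.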