Pattern Mapping-based Pass

This section describes a set of APIs for pattern-based subgraph matching and replacement during pass build, which makes custom fusion passes faster to develop. Fusion logic consists of the following three phases:

1. Matching: A pattern defines a graph structure for subgraph matching and search.

2. Decision-making: The system determines whether fusion can be performed based on the matching results and specific conditions.

3. Replacement: If fusion is permitted, the system rewrites the graph with the fused structure.

Figure 1 Logical architecture

Figure 1 shows the logical architecture. The core concepts are described as follows:

  • Pattern: a template or rule set that describes the structural characteristics of a specific subgraph. The graph matching algorithm uses the Pattern to search for subgraphs that satisfy the defined rules or structural constraints.
  • PatternMatcher: a core object that runs the matching algorithm. It takes a Pattern as input and searches the graph for subgraphs that satisfy the defined structural constraints.
  • GraphRewriter: a core object that modifies a graph. It receives the boundary edges of the matched subgraph and the replacement graph, and replaces the original subgraph nodes with the replacement structure to rebuild the graph.
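To make the roles of Pattern, PatternMatcher, and GraphRewriter concrete, here is a minimal C++ sketch on a toy graph that is only a linear chain of op types. MatchChain and RewriteChain are hypothetical helpers, not GE APIs. A real graph is a DAG, so this sketch only illustrates the idea of matching a template and rewriting each hit.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy model (hypothetical, not the GE API): a "graph" is a linear chain of op types.
using Chain = std::vector<std::string>;

// Plays the role of PatternMatcher: returns the start index of every
// non-overlapping occurrence of `pattern` in `graph`, scanning left to right.
std::vector<std::size_t> MatchChain(const Chain &graph, const Chain &pattern) {
  std::vector<std::size_t> starts;
  if (pattern.empty() || pattern.size() > graph.size()) return starts;
  std::size_t i = 0;
  while (i + pattern.size() <= graph.size()) {
    bool hit = true;
    for (std::size_t k = 0; k < pattern.size(); ++k) {
      if (graph[i + k] != pattern[k]) { hit = false; break; }
    }
    if (hit) { starts.push_back(i); i += pattern.size(); }
    else { ++i; }
  }
  return starts;
}

// Plays the role of GraphRewriter: replaces each matched span with `target`.
Chain RewriteChain(const Chain &graph, std::size_t pattern_len,
                   const std::vector<std::size_t> &starts, const Chain &target) {
  Chain out;
  std::size_t next = 0;  // index of the next match start to handle
  for (std::size_t i = 0; i < graph.size();) {
    if (next < starts.size() && starts[next] == i) {
      out.insert(out.end(), target.begin(), target.end());
      i += pattern_len;
      ++next;
    } else {
      out.push_back(graph[i]);
      ++i;
    }
  }
  return out;
}
```

For example, matching {"MatMul", "Add"} in the chain Data, MatMul, Add, Relu, MatMul, Add finds two hits. Rewriting them with {"GEMM"} yields Data, GEMM, Relu, GEMM.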

You can implement a custom fusion pass by inheriting a base class provided by GE and overriding its methods. Then register the pass with the provided registration macro and specify its execution phase. GE provides two base classes for the following scenarios:

  • General subgraph fusion (one-to-one or complex topology replacement): A complete subgraph structure needs to be matched and replaced with another subgraph.

    Inherit from the PatternFusionPass class to implement the custom fusion pass. Then register the pass with the REG_FUSION_PASS macro and specify its execution phase.

  • Single-node replacement (one-to-N replacement): Inherit from the DecomposePass class to implement the custom fusion pass. Then register the pass with the REG_DECOMPOSE_PASS macro and specify its execution phase.

Scenarios

  • Fusion pass development in the general subgraph fusion scenario (one-to-one or complex topology replacement)

    This part first introduces the core class PatternFusionPass, then covers the three functions to override (Patterns, MeetRequirements, and Replacement), and finally shows how to register the pass and specify its execution phase.

    The PatternFusionPass class is declared as follows:

    class PatternFusionPass : public FusionBasePass {
     public:
      Status Run(GraphPtr &graph, CustomPassContext &pass_context) override;
     protected:
      virtual std::vector<PatternUniqPtr> Patterns() = 0;
      virtual bool MeetRequirements(const std::unique_ptr<MatchResult> &match_result);
      virtual GraphUniqPtr Replacement(const std::unique_ptr<MatchResult> &match_result) = 0;
    };
    

    The Run function calls Patterns to obtain the template topologies and matches each Pattern against the nodes of the target graph. For each match, it calls MeetRequirements to decide whether the matched subgraph should be replaced. For each match that qualifies, it calls Replacement to obtain the target structure and substitutes that structure for the matched subgraph.
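    The control flow of Run can be sketched as a template method in plain C++. ToyFusionPass, ToyNode, and ToyReluPass are hypothetical stand-ins that reduce a Pattern to a single op type. They only mirror the Patterns, MeetRequirements, and Replacement call order described above and are not the GE implementation.

    ```cpp
    #include <cassert>
    #include <string>
    #include <vector>

    // Hypothetical toy types: a graph is a list of nodes with an op type and one attribute.
    struct ToyNode {
      std::string type;
      int attr;
    };
    using ToyGraph = std::vector<ToyNode>;

    // Mirrors the documented order: Patterns -> match -> MeetRequirements -> Replacement.
    class ToyFusionPass {
     public:
      virtual ~ToyFusionPass() = default;
      // Returns the number of nodes that were replaced.
      int Run(ToyGraph &graph) {
        int replaced = 0;
        for (const std::string &pattern : Patterns()) {  // 1. Matching
          for (ToyNode &node : graph) {
            if (node.type != pattern) continue;
            if (!MeetRequirements(node)) continue;          // 2. Decision-making
            node = Replacement(node);                       // 3. Replacement
            ++replaced;
          }
        }
        return replaced;
      }

     protected:
      virtual std::vector<std::string> Patterns() = 0;
      // Default behavior: accept every match, like the documented default of true.
      virtual bool MeetRequirements(const ToyNode &) { return true; }
      virtual ToyNode Replacement(const ToyNode &matched) = 0;
    };

    // Example pass: replaces "Relu" nodes whose attr is positive with "FastRelu".
    class ToyReluPass : public ToyFusionPass {
     protected:
      std::vector<std::string> Patterns() override { return {"Relu"}; }
      bool MeetRequirements(const ToyNode &n) override { return n.attr > 0; }
      ToyNode Replacement(const ToyNode &m) override { return {"FastRelu", m.attr}; }
    };
    ```

    Only Patterns and Replacement are pure virtual here. This mirrors the "Must Override" column of the table that follows.
    
    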

    To develop a custom fusion pass, inherit from PatternFusionPass and override Patterns, MeetRequirements, and Replacement. The following table describes these functions.

    | Function | Description | Must override |
    |---|---|---|
    | Patterns | Defines the template topologies used to match the target graph. Returns one or more Pattern pointers. | Yes |
    | MeetRequirements | Filters the matches found for Patterns based on conditions. Takes a match result as input and returns a Boolean value. | No. Returns true by default. |
    | Replacement | Defines the replacement structure. Takes a match result as input and returns a graph pointer. | Yes |

    • Patterns

      Patterns defines one or more template topologies used to match the target graph. Express each Pattern as a directed acyclic graph (DAG) built with EsGraphBuilder, as in the following skeleton:

      std::vector<PatternUniqPtr> Patterns() override {
        std::vector<PatternUniqPtr> patterns;
        // Use EsGraphBuilder to construct a Pattern.
        auto graph_builder = es::EsGraphBuilder("pattern");
        // Define the Pattern.
        // ...
        // Initialize the Pattern object.
        auto graph = graph_builder.BuildAndReset({xxx});
        auto pattern = std::make_unique<Pattern>(std::move(*graph));
        patterns.emplace_back(std::move(pattern));
        // Add multiple patterns to Patterns.
        // ...
        return patterns;
      }
      

      EsGraphBuilder is a graph builder class for constructing computational graphs. Define Patterns with its C/C++ APIs, which provide functions for defining inputs, constants, and operators. The following example defines a single-operator ReLU Pattern with the ES API:

      std::vector<PatternUniqPtr> patterns;
      // Create an EsGraphBuilder instance to construct a computational graph named pattern.
      auto graph_builder = es::EsGraphBuilder("pattern");
      auto data = graph_builder.CreateInput(0);
      auto relu = es::Relu(data);
      // Construct and reset the graph, and use {relu} as the output node.
      auto graph = graph_builder.BuildAndReset({relu});
      // Move the graph to the Pattern constructor to create a pattern object.
      auto pattern = std::make_unique<Pattern>(std::move(*graph));
      patterns.emplace_back(std::move(pattern));
      

      The subgraph matched by a Pattern must be self-contained. Apart from the boundary outputs, no output of an operator inside the subgraph may be consumed by nodes outside it. Subgraphs that are not self-contained are not matched.
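      The self-containment rule can be expressed as a check over edges. The sketch below is a simplified, hypothetical helper (IsSelfContained is not a GE API). It works at node granularity rather than per output index. A matched subgraph passes only if every edge that leaves it starts at a boundary-output node.

      ```cpp
      #include <cassert>
      #include <set>
      #include <utility>
      #include <vector>

      // An edge is (producer node id, consumer node id). Hypothetical toy model.
      using Edge = std::pair<int, int>;

      // True if every edge leaving `matched` starts at a node in `boundary_outputs`.
      bool IsSelfContained(const std::vector<Edge> &edges,
                           const std::set<int> &matched,
                           const std::set<int> &boundary_outputs) {
        for (const Edge &e : edges) {
          const bool from_inside = matched.count(e.first) > 0;
          const bool to_outside = matched.count(e.second) == 0;
          if (from_inside && to_outside && boundary_outputs.count(e.first) == 0) {
            return false;  // an internal output escapes the matched subgraph
          }
        }
        return true;
      }
      ```

      Take the MatMul+Add structure with a = 0, b = 1, MatMul = 2, Add = 3, c = 4, and a downstream consumer 5, where Add is the boundary output. The check passes. If the MatMul output also feeds a node outside the match, the check fails, so such a subgraph would not be matched.
      
      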

      In addition to EsGraphBuilder, GE provides two APIs for defining a Pattern at a finer granularity:

      • CaptureTensor

        While defining a Pattern, you can capture tensors in it. Captured tensors can later be retrieved from MatchResult by index, in the order they were captured. The declaration is shown below. The parameter node_output is of type NodeIo. A NodeIo consists of a node and an output index, so it identifies one output of a node.

        // CaptureTensor declaration
        Pattern &CaptureTensor(const NodeIo &node_output);
        
        // NodeIo structure
        struct NodeIo {
          GNode node;
          int64_t index;
        };
        

        The following example calls CaptureTensor to capture the output of the ReLU node:

        std::vector<PatternUniqPtr> patterns;
        // Create an EsGraphBuilder instance to construct a computational graph named pattern.
        auto graph_builder = es::EsGraphBuilder("pattern");
        auto data = graph_builder.CreateInput(0);
        auto relu = es::Relu(data);
        // Construct a computational graph which contains only the data -> ReLU(relu) structure.
        auto graph = graph_builder.BuildAndReset({relu});
        // Create a Pattern instance and initialize it using the constructed graph.
        auto pattern = std::make_unique<Pattern>(std::move(*graph));
        // Call CaptureTensor to capture output 0 of the ReLU node (the producer of relu).
        pattern->CaptureTensor({*relu.GetProducer(), 0});
        patterns.emplace_back(std::move(pattern));
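        Captured tensors are retrieved in the order in which they were captured, which is why index 0 is used later with GetCapturedTensor. The following toy C++ model shows only this ordering. ToyNodeIo and ToyCaptures are hypothetical and do not model how GE maps Pattern nodes to matched graph nodes.

        ```cpp
        #include <cassert>
        #include <cstddef>
        #include <cstdint>
        #include <string>
        #include <vector>

        // Hypothetical stand-in for NodeIo: a node (by name here) and an output index.
        struct ToyNodeIo {
          std::string node;
          int64_t index;
        };

        class ToyCaptures {
         public:
          // Like CaptureTensor: records in call order and returns *this for chaining.
          ToyCaptures &Capture(const ToyNodeIo &io) {
            captured_.push_back(io);
            return *this;
          }
          // Like GetCapturedTensor: fetches by capture order; false if out of range.
          bool Get(std::size_t i, ToyNodeIo &out) const {
            if (i >= captured_.size()) return false;
            out = captured_[i];
            return true;
          }

         private:
          std::vector<ToyNodeIo> captured_;
        };
        ```
        
        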
        
      • PatternMatcherConfig

        A PatternMatcherConfig can be passed to PatternFusionPass to make the pass configurable. This lets you enable capabilities such as Const value matching and IR attribute matching. The base class PatternFusionPass provides the following constructor:

        explicit PatternFusionPass(std::unique_ptr<PatternMatcherConfig> match_config);
        

        Use PatternMatcherConfigBuilder to construct a PatternMatcherConfig. PatternMatcherConfigBuilder provides two functions that turn these matching capabilities on:

        • EnableConstValueMatch: enables Const value matching. During matching, the system compares the values of the Const/Constant nodes defined in the Pattern with those in the graph. A match is found only when the values are identical.
        • EnableIrAttrMatch: enables IR attribute matching. During matching, the number and values of the IR attributes of each node in the Pattern must match those of the corresponding graph node.

        The following example shows a constructor that enables Const value matching for a custom pass class named CustomFusionPass:

        explicit CustomFusionPass() : PatternFusionPass(PatternMatcherConfigBuilder().EnableConstValueMatch().Build()) {}
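        The effect of the builder switches can be pictured with the following self-contained sketch. ToyMatcherConfig, ToyMatcherConfigBuilder, and ConstNodesMatch are hypothetical stand-ins for PatternMatcherConfig and PatternMatcherConfigBuilder. The sketch assumes that, when Const value matching is disabled, Const values are simply not compared.

        ```cpp
        #include <cassert>
        #include <memory>

        // Hypothetical stand-in for PatternMatcherConfig: all capabilities off by default.
        struct ToyMatcherConfig {
          bool const_value_match = false;
          bool ir_attr_match = false;
        };

        // Hypothetical stand-in for PatternMatcherConfigBuilder with chainable switches.
        class ToyMatcherConfigBuilder {
         public:
          ToyMatcherConfigBuilder &EnableConstValueMatch() { cfg_.const_value_match = true; return *this; }
          ToyMatcherConfigBuilder &EnableIrAttrMatch() { cfg_.ir_attr_match = true; return *this; }
          std::unique_ptr<ToyMatcherConfig> Build() const { return std::make_unique<ToyMatcherConfig>(cfg_); }

         private:
          ToyMatcherConfig cfg_;
        };

        // Assumed rule: values are compared only when Const value matching is on.
        bool ConstNodesMatch(const ToyMatcherConfig &cfg, float pattern_value, float graph_value) {
          return !cfg.const_value_match || pattern_value == graph_value;
        }
        ```
        
        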
        
    • MeetRequirements

      MeetRequirements filters the matches found for Patterns. As described for the Run function, each match is represented by a MatchResult object and passed to MeetRequirements, so you can query the match details through it. The returned Boolean value determines whether the match is replaced:

      bool MeetRequirements(const std::unique_ptr<MatchResult> &match_result) override {
        // Use the input match_result to filter the matching results.
        // If the defined rules or structural constraints are satisfied, true is returned.
        // IsSatisfy is a placeholder for your own filtering logic.
        if (IsSatisfy(match_result)) {
          return true;
        }
        // Otherwise, false is returned.
        return false;
      }
      

      MatchResult holds the result of a match, including information such as the matched nodes and edges. Its member functions let you retrieve match details for filtering. The following example uses GetCapturedTensor to check whether the ReLU output has a dynamic shape:

      NodeIo relu_output;
      // Try to obtain the first captured output tensor from match_result and store it to relu_output.
      if (match_result->GetCapturedTensor(0, relu_output) != GRAPH_SUCCESS) {
        return false;
      }
      TensorDesc relu_out_tensor_desc;
      // Obtain the output tensor description from relu_output.
      relu_output.node.GetOutputDesc(relu_output.index, relu_out_tensor_desc);
      // A shape size of -1 indicates a dynamic shape; only dynamic-shape matches are replaced.
      if (relu_out_tensor_desc.GetShape().GetShapeSize() != -1) {
        return false;
      }
      return true;
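      The -1 comparison above relies on the shape size being reported as -1 for a dynamic shape. The following toy model makes that assumption explicit by treating any negative dimension as unknown. ToyShapeSize and AcceptReluOutput are hypothetical helpers, not the ge::Shape implementation.

      ```cpp
      #include <cassert>
      #include <cstdint>
      #include <vector>

      // Assumed semantics (toy model): any negative dim marks the shape as dynamic,
      // and the total size is then reported as -1.
      int64_t ToyShapeSize(const std::vector<int64_t> &dims) {
        int64_t size = 1;
        for (int64_t d : dims) {
          if (d < 0) return -1;  // dynamic dimension: total size unknown
          size *= d;
        }
        return size;
      }

      // Same decision as the MeetRequirements example: accept only dynamic shapes.
      bool AcceptReluOutput(const std::vector<int64_t> &dims) {
        return ToyShapeSize(dims) == -1;
      }
      ```
      
      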
      
    • Replacement

      Replacement defines the structure that replaces a matched subgraph. The system replaces only the matches for which MeetRequirements returns true. As with Patterns, the replacement structure is defined with EsGraphBuilder:

      GraphUniqPtr Replacement(const std::unique_ptr<MatchResult> &match_result) override {
        auto replacement_graph_builder = es::EsGraphBuilder("replacement");
        // Define the replacement structure (r_a below denotes its output tensor).
        // ...
        return replacement_graph_builder.BuildAndReset({r_a});
      }
      

      If the pass is registered for a phase after InferShape, call GeUtils::InferShape in Replacement. In addition, if you use GeUtils::CheckNodeSupportOnAicore to check whether the target structure is supported, call it only after InferShape.

    Registration of the custom fusion pass

    After defining a fusion pass, register it with the REG_FUSION_PASS macro and specify its execution phase. The following example registers a custom pass named CustomFusionPass to run before the InferShape phase, specified by CustomPassStage::kBeforeInferShape:

    REG_FUSION_PASS(CustomFusionPass).Stage(CustomPassStage::kBeforeInferShape);
    

    For details about each phase, see Stage.
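    A common way to implement registration macros of this kind is a static object whose initializer records the pass when the library is loaded, with a chained Stage call that attaches the execution phase. The following sketch illustrates only that idea. TOY_REG_PASS, ToyPassRegistrar, ToyPassRegistry, and ToyStage are invented names, and the actual REG_FUSION_PASS implementation in GE may differ.

    ```cpp
    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical execution phases, standing in for CustomPassStage.
    enum class ToyStage { kBeforeInferShape, kAfterInferShape };

    // Process-wide registry: pass names grouped by execution phase.
    class ToyPassRegistry {
     public:
      static ToyPassRegistry &Instance() {
        static ToyPassRegistry registry;  // created on first use
        return registry;
      }
      std::map<ToyStage, std::vector<std::string>> passes_by_stage;
    };

    class ToyPassRegistrar {
     public:
      explicit ToyPassRegistrar(std::string name) : name_(std::move(name)) {}
      // Records the pass under `stage` and returns *this so calls can chain.
      ToyPassRegistrar &Stage(ToyStage stage) {
        ToyPassRegistry::Instance().passes_by_stage[stage].push_back(name_);
        return *this;
      }

     private:
      std::string name_;
    };

    // The static variable's initializer runs before main, registering the pass.
    #define TOY_REG_PASS(name) \
      static ToyPassRegistrar toy_registrar_##name = ToyPassRegistrar(#name)

    TOY_REG_PASS(CustomFusionPass).Stage(ToyStage::kBeforeInferShape);
    ```
    
    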

  • Fusion pass development (one-to-N node replacement)

    In this scenario, the pass inherits from the base class DecomposePass. Because only a single node is replaced, no Patterns function is needed to define a Pattern. Instead, the operator types to match are passed directly to the constructor:

    class CustomOne2NPass : public DecomposePass {
     public:
      CustomOne2NPass(const std::vector<AscendString> &op_types) : DecomposePass(op_types) {}
      // MeetRequirements and Replacement are overridden here, as shown below.
    };
    

    As in the general subgraph fusion scenario, a pass that inherits from DecomposePass also overrides MeetRequirements and Replacement. However, their input parameter is a GNode rather than a MatchResult. The GNode is the graph node whose type matches one of the op_types passed to the constructor.

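    bool MeetRequirements(const GNode &matched_node) override {
      // Inspect matched_node and return true if it should be replaced.
      // ...
    }
    GraphUniqPtr Replacement(const GNode &matched_node) override {
      // Build and return the replacement graph.
      // ...
    }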
    

    Registration of the custom fusion pass

    The following example uses the REG_DECOMPOSE_PASS macro to register CustomOne2NPass with {"Conv2D"} as op_types and sets its execution phase to kAfterInferShape:

    REG_DECOMPOSE_PASS(CustomOne2NPass, {"Conv2D"}).Stage(CustomPassStage::kAfterInferShape);
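    One-to-N replacement differs from the general scenario mainly in how matches are found: each node's type is compared with op_types instead of running a Pattern. The following toy C++ sketch shows that flow on a linear list of op types. DecomposeNodes and the replacement op names PartA, PartB, and PartC are hypothetical.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Toy model of one-to-N replacement (not the GE DecomposePass API):
    // every node whose type is listed in op_types and that passes `meets`
    // is replaced by the node list returned by `replace`.
    template <typename Meets, typename Replace>
    std::vector<std::string> DecomposeNodes(const std::vector<std::string> &graph,
                                            const std::vector<std::string> &op_types,
                                            Meets meets, Replace replace) {
      std::vector<std::string> out;
      for (const std::string &node : graph) {
        const bool typed = std::find(op_types.begin(), op_types.end(), node) != op_types.end();
        if (typed && meets(node)) {  // type match, then the MeetRequirements check
          const std::vector<std::string> parts = replace(node);  // the Replacement step
          out.insert(out.end(), parts.begin(), parts.end());
        } else {
          out.push_back(node);  // node kept unchanged
        }
      }
      return out;
    }
    ```
    
    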
    

Example

Assume that, in the general subgraph fusion scenario, you need a custom pass that fuses a MatMul+Add structure into a single GEMM node. For the complete code, see the sample source code. The sample repository provides more samples; for details, see fusion pass samples.

The following diagram shows the graph structure before and after the modification. The MatMul+Add structure on the left is replaced with the GEMM node on the right.

      a   b
       \ /                 a    b    c
      MatMul    c   ==>     \   |   /
         \     /              GEMM
           Add
  1. Include the header files.
    #include <iostream>
    // Include the header file for the custom fusion pass APIs.
    #include "ge/fusion/pass/pattern_fusion_pass.h"
    // Include the header file for the ES APIs.
    #include "es_all_ops.h"
    
  2. Modify the graph using a custom pass.
    class FuseMatMulAndAddPass : public PatternFusionPass {
     protected:
      // Override Patterns.
      std::vector<PatternUniqPtr> Patterns() override {
        std::cout << "Define pattern for FuseMatMulAndAddPass" << std::endl;
        std::vector<PatternUniqPtr> patterns;
        // Create an EsGraphBuilder instance to construct a computational graph named pattern0.
        auto graph_builder0 = es::EsGraphBuilder("pattern0");
        auto [a0, b0, c0] = graph_builder0.CreateInputs<3>();
        auto matmul0 = es::MatMul(a0, b0);
        auto add0 = es::Add(matmul0, c0);
        // Construct and reset the graph.
        auto graph0 = graph_builder0.BuildAndReset({add0});
        auto pattern0 = std::make_unique<Pattern>(std::move(*graph0));
        patterns.emplace_back(std::move(pattern0));
    
        return patterns;
      }
      // Override Replacement.
      GraphUniqPtr Replacement(const std::unique_ptr<MatchResult> &match_result) override {
        std::cout << "Define replacement for FuseMatMulAndAddPass" << std::endl;
        // Construct the replacement graph.
        auto replace_graph_builder = es::EsGraphBuilder("replacement");
        auto [r_a, r_b, r_c] = replace_graph_builder.CreateInputs<3>();
        auto alpha_const = replace_graph_builder.CreateScalar(1);
        auto beta_const = replace_graph_builder.CreateScalar(1);
        auto gemm = es::GEMM(r_a, r_b, r_c, alpha_const, beta_const);
        // Construct and reset the graph.
        return replace_graph_builder.BuildAndReset({gemm});
      }
    };
    
  3. Register the custom fusion pass.
    // Use the REG_FUSION_PASS registration macro to register the graph modification pass and specify the phase when the pass will be called.
    REG_FUSION_PASS(FuseMatMulAndAddPass).Stage(CustomPassStage::kBeforeInferShape);
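The replacement in this example is valid because GEMM computes alpha * (A x B) + beta * C. Assuming no transposition of the inputs, this reduces to MatMul followed by Add when alpha = beta = 1, which is why both scalars are created with the value 1. The following reference check in plain C++ compares the two forms on small inputs. It uses hypothetical helpers, row-major float matrices, and an Add without broadcasting.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major dense matrix used only for this check.
struct Mat {
  std::size_t rows, cols;
  std::vector<float> v;
  float at(std::size_t r, std::size_t c) const { return v[r * cols + c]; }
};

// Reference MatMul: (m x k) * (k x n) -> (m x n).
Mat RefMatMul(const Mat &a, const Mat &b) {
  Mat y{a.rows, b.cols, std::vector<float>(a.rows * b.cols, 0.0f)};
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t p = 0; p < a.cols; ++p)
        y.v[i * y.cols + j] += a.at(i, p) * b.at(p, j);
  return y;
}

// Element-wise Add of two matrices of the same shape (no broadcasting).
Mat RefAdd(const Mat &x, const Mat &c) {
  Mat y = x;
  for (std::size_t i = 0; i < y.v.size(); ++i) y.v[i] += c.v[i];
  return y;
}

// GEMM: alpha * (A * B) + beta * C.
Mat RefGemm(const Mat &a, const Mat &b, const Mat &c, float alpha, float beta) {
  Mat y = RefMatMul(a, b);
  for (std::size_t i = 0; i < y.v.size(); ++i) y.v[i] = alpha * y.v[i] + beta * c.v[i];
  return y;
}
```

With A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]], and C filled with ones, both forms produce [[20, 23], [44, 51]].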
    

How to Use a Custom Pass

The following describes how to build the graph modification code into a dynamic library plugin so that the framework calls the registered pass during graph build. For details, see Sample Usage.

  1. Compile the pass defined in Example into a dynamic library whose file name must end with .so.
  2. After the build succeeds, run make install to install the .so file to the ${INSTALL_DIR}/opp/vendors/xxx/custom_fusion_passes/ directory. Soft links are allowed. The .so file must be readable by the user who runs the program.

    If there are multiple ${INSTALL_DIR}/opp/vendors/xxx directories, they are sorted in text order and traversed to find their custom_fusion_passes/ subdirectories. Within each subdirectory, .so files are loaded in text order, and files whose names do not end with .so are skipped.

    • ${INSTALL_DIR}: the CANN component installation directory. For example, if CANN is installed by the root user, the default path is /usr/local/Ascend/cann.
    • xxx: a custom directory. Only one directory level is allowed.
    • custom_fusion_passes: The directory cannot contain subdirectories.
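    The load order described above can be sketched as follows. LoadOrder and EndsWithSo are hypothetical helpers that work on in-memory name lists instead of the file system. "Text order" is assumed here to mean byte-wise lexicographic comparison, as done by std::string.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <utility>
    #include <vector>

    // True if `name` ends with ".so" (C++17-compatible; no std::string::ends_with).
    bool EndsWithSo(const std::string &name) {
      const std::string suffix = ".so";
      return name.size() >= suffix.size() &&
             name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
    }

    // Input: vendor directory name -> file names in its custom_fusion_passes/ directory.
    // Output: the paths that would be loaded, in load order.
    std::vector<std::string> LoadOrder(
        std::vector<std::pair<std::string, std::vector<std::string>>> vendors) {
      std::sort(vendors.begin(), vendors.end());  // vendor directories in text order
      std::vector<std::string> order;
      for (auto &[vendor, files] : vendors) {
        std::sort(files.begin(), files.end());  // files in text order
        for (const std::string &f : files)
          if (EndsWithSo(f)) order.push_back(vendor + "/custom_fusion_passes/" + f);
      }
      return order;
    }
    ```

    Under this rule, a file such as libpass.so.1 is skipped because its name does not end with .so.
    
    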
  3. To check whether the custom pass takes effect, set the DUMP_GE_GRAPH environment variable to dump the graph before model build. Then build the model from a supported entry. The supported entries for building model files include, but are not limited to:

Result Verification

For details, see Sample Usage > Program Running > Training Result Check.

With the dump environment variable set, running the program generates graph files such as ge_onnx*.pbtxt in the current path. In this example the pass runs before InferShape, so compare the following two graphs. You can view them with visualization software such as Netron.
  • ge_onnx_xxxx_PreRunBegin.pbtxt: graph before fusion
  • ge_onnx_xxxx_RunCustomPassBeforeInfershape.pbtxt: graph after fusion

The figure below shows the graph structure before fusion.

The following figure shows the graph structure after modification using the custom pass. The MatMul+Add structure has been replaced with a GEMM node.