昇腾社区首页
中文
注册

使用自定义Pass修改Graph

改图功能主要有两种调用方式,用户可以在构建完Graph后,直接调用改图接口修改Graph,也可以将改图函数封装为自定义Pass,并通过REGISTER_CUSTOM_PASS注册宏进行改图Pass注册,通过把改图函数编译成动态库插件方式,注册的Pass可以在指定的阶段被调用。

功能介绍

一个改图函数可看作是一个自定义Pass,用户可调用REGISTER_CUSTOM_PASS注册宏,按照指定的Pass名称完成Pass的注册,通过把改图函数编译成动态库插件方式,注册的Pass在指定阶段被调用,示例代码如下:
1
2
3
4
5
6
7
8
#include "register_custom_pass.h"
// 用户自定义改图函数
Status CustomPassFunc(GraphPtr &graph, CustomPassContext &custom_context) {
    // 此处定义图修改具体行为
    return GRAPH_SUCCESS;
}
// 改图Pass注册,注册后的Pass在指定阶段被调用
REGISTER_CUSTOM_PASS("pass_name").CustomPassFn(CustomPassFunc).Stage(CustomPassStage::kBeforeInferShape);
  • register_custom_pass.h:存储在“CANN软件安装目录/latest/include/register/”目录下,包含该头文件,可使用Pass注册相关类,使用Pass注册相关接口。
  • Status:成功返回ge::GRAPH_SUCCESS,返回其他全为失败。建议使用小于0的值作为返回的错误码,大于0的值可能会和框架已使用的错误码产生冲突。
  • CustomPassFunc:自定义Pass的入口函数,详情请参见回调函数CustomPassFunc
  • graph:自定义Pass要修改的图,类型为GraphPtr。
  • custom_context:CustomPassContext类对象,可参考CustomPassContext提供的方法。
  • REGISTER_CUSTOM_PASS:注册自定义Pass,"pass_name"可任意命名,详情请参见REGISTER_CUSTOM_PASS
  • CustomPassFn:自定义Pass的执行函数,详情请参见CustomPassFn
  • Stage:指定Pass执行阶段,详情请参见Stage

如果用户改图过程中,需要替换成其他功能的算子,但是该算子CANN不支持,可以通过如下方式自定义该算子:

通过Ascend C自定义该算子,详情请参见Ascend C算子开发指南

将该算子开发完成后,才能正常使用自定义Pass修改graph的功能。

开发示例

此处以MatMul+Add融合为GEMM自定义Pass为例,详细介绍如何通过自定义Pass修改Graph,详细可以参见样例源码

样例仓还提供了“如何通过自定义Pass把Tile+Concat图结构修改为Concat+Tile+Concat图结构”的详细说明,可以从完整样例参考获取链接。

  1. 包含的头文件。
    1
    2
    3
    4
    5
    6
    7
    #include <iostream>
    // 自定义Pass接口头文件
    #include "register_custom_pass.h"
    // 新增算子头文件
    #include "all_ops.h"
    // 如果使用Ascend C自定义了算子,需要包含如下头文件:
    #include "CANN软件安装目录/latest/opp/vendors/customize/op_proto/inc/op_proto.h"
    

    其中:CANN软件安装目录:请修改为CANN软件包的实际安装路径。customize:请修改为Ascend C自定义算子实际工程名。

  2. 使用自定义Pass修改Graph。(如下代码仅为示例,不可执行;修改Graph时,只能使用CustomPassContextGraph~GNodePassReceiverPassRegistrationDataREGISTER_CUSTOM_PASSStreamPassContext中的接口)
      1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
     85
     86
     87
     88
     89
     90
     91
     92
     93
     94
     95
     96
     97
     98
     99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    namespace {
    constexpr const char *kOpNameAdd = "add";
    constexpr const char *kOpNameMatMul = "matmul";
    constexpr const char *kOpNameGEMM = "gemm";
    constexpr const char *kOpNameAlpha = "alpha";
    constexpr const char *kOpNameBeta = "beta";
    constexpr const char *kAttrNameTransposeA = "transpose_a";
    constexpr const char *kAttrNameTransposeB = "transpose_b";
    constexpr int32_t kIndex0 = 0;
    constexpr int32_t kIndex1 = 1;
    constexpr int32_t kIndex2 = 2;
    constexpr int32_t kIndex3 = 3;
    constexpr int32_t kIndex4 = 4;
    
    // 1.遍历所有节点,寻找MatMul和Add节点
    bool FindNodes(GraphPtr &graph, GNode &src_node, GNode &dst_node) {
        auto all_nodes = graph->GetAllNodes();
        bool find_src_node = false;
        bool find_dst_node = false;
        for (auto &node: all_nodes) {
            AscendString node_name;
            auto ret = node.GetName(node_name);
            if (node_name == kOpNameMatMul) {
                src_node = node;
                find_src_node = true;
                cout << "Find src node: MatMul." << endl;
            } else if (node_name == kOpNameAdd) {
                dst_node = node;
                find_dst_node = true;
                cout << "Find dst node: Add." << endl;
            }
        }
        return (find_src_node && find_dst_node);
    }
    // 2.判断MatMul和Add节点是否有连边关系
    bool CheckNodesHaveEdge(GraphPtr &graph, const GNode &src_node, const GNode &dst_node) {
        for (auto &[out_node, _]: src_node.GetOutDataNodesAndPortIndexs(kIndex0)) {
            AscendString node_name;
            auto ret = out_node->GetName(node_name);
            if (node_name == kOpNameAdd) {
                return true;
            }
        }
        return false;
    }
    // 3.创建和添加GEMM节点
    void CreateGEMMNode(GraphPtr &graph, const GNode &src_node, GNode &node_gemm) {
        bool transpose_a = false;
        bool transpose_b = false;
        src_node.GetAttr(kAttrNameTransposeA, transpose_a);
        src_node.GetAttr(kAttrNameTransposeB, transpose_b);
        constexpr float kValue1 = 1;
        TensorDesc alpha_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT);
        Tensor alpha_tensor(alpha_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float));
        auto alpha = op::Const(kOpNameAlpha).set_attr_value(alpha_tensor);
        TensorDesc beta_desc(ge::Shape({1}), FORMAT_ND, DT_FLOAT);
        Tensor beta_tensor(beta_desc, reinterpret_cast<const uint8_t *>(&kValue1), sizeof(float));
        auto beta = op::Const(kOpNameBeta).set_attr_value(beta_tensor);
    
        auto gemm = op::GEMM(kOpNameGEMM);
        gemm.set_attr_transpose_a(transpose_a)
            .set_attr_transpose_b(transpose_b);
        gemm.update_input_desc_alpha(alpha_desc);
        gemm.update_input_desc_beta(beta_desc);
    
        auto node_alpha = graph->AddNodeByOp(alpha);
        auto node_beta = graph->AddNodeByOp(beta);
        node_gemm = graph->AddNodeByOp(gemm);
    
        auto ret = graph->AddDataEdge(node_alpha, kIndex0, node_gemm, kIndex3);
        ret = graph->AddDataEdge(node_beta, kIndex0, node_gemm, kIndex4);
    }
    // 4.添加新节点的输入输出
    bool AddInputsAndOutputs(GraphPtr &graph, const GNode &src_node, const GNode &dst_node, GNode &node_gemm) {
        auto [a, a_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex0);
        auto [b, b_output_index] = src_node.GetInDataNodesAndPortIndexs(kIndex1);
        int32_t add_node_c_input_index = -1;
        for (size_t i = 0; i < dst_node.GetInputsSize(); ++i) {
            auto [in_node, _] = dst_node.GetInDataNodesAndPortIndexs(i);
            AscendString node_name;
            auto ret = in_node->GetName(node_name);
            if (node_name != kOpNameMatMul) {
                add_node_c_input_index = i;
                break;
            }
        }
        if (add_node_c_input_index == -1) {
            return false;
        }
        auto [c, c_output_index] = dst_node.GetInDataNodesAndPortIndexs(add_node_c_input_index);
        auto ret = graph->AddDataEdge(*a, a_output_index, node_gemm, kIndex0);
        if (ret != GRAPH_SUCCESS) {
            return false;
        }
        ret = graph->AddDataEdge(*b, b_output_index, node_gemm, kIndex1);
        ret = graph->AddDataEdge(*c, c_output_index, node_gemm, kIndex2);
    
        TensorDesc input_desc_a;
        ret = src_node.GetInputDesc(kIndex0, input_desc_a);
        ret = node_gemm.UpdateInputDesc(kIndex0, input_desc_a);
    
        TensorDesc input_desc_b;
        ret = src_node.GetInputDesc(kIndex1, input_desc_b);
        ret = node_gemm.UpdateInputDesc(kIndex1, input_desc_b);
    
        TensorDesc input_desc_c;
        ret = dst_node.GetInputDesc(add_node_c_input_index, input_desc_c);
        ret = node_gemm.UpdateInputDesc(kIndex2, input_desc_c);
    
        TensorDesc output_desc_y;
        ret = dst_node.GetOutputDesc(kIndex0, output_desc_y);
        ret = node_gemm.UpdateOutputDesc(kIndex0, output_desc_y);
        return true;
    }
    // 5.删除旧节点和其连边关系,连接新GEMM节点和输出节点
    void RemoveOldNodesEdgesAndAddGemmOutput(GraphPtr &graph, GNode &src_node, GNode &dst_node, GNode &node_gemm) {
        vector<GNode> node_vec{src_node, dst_node};
        for (auto &node: node_vec) {
            for (size_t i = 0; i < node.GetInputsSize(); ++i) {
                auto [in_node, in_id] = node.GetInDataNodesAndPortIndexs(i);
                if (in_node != nullptr) {
                    auto ret = graph->RemoveEdge(*in_node, in_id, node, i);
                }
            }
        }
    
        for (auto &[out_node, out_id]: dst_node.GetOutDataNodesAndPortIndexs(kIndex0)) {
            if (out_node != nullptr) {
                auto ret = graph->RemoveEdge(dst_node, kIndex0, *out_node, out_id);
                ret = graph->AddDataEdge(node_gemm, kIndex0, *out_node, out_id);
            }
        }
    
        for (auto &node: node_vec) {
            auto ret = graph->RemoveNode(node);
        }
    }
    } // namespace
    
    // |o>-----------------------------------
    // |o>    a  b
    // |o>    \ /              a   b    c
    // |o>   MatMul  c   ==>   \   |   /
    // |o>     \    /            GEMM
    // |o>      Add
    // |o>-----------------------------------
    // 融合说明:本例识别上图中左边的MatMul+Add结构并通过图修改接口替换为右边的单个GEMM节点
    // 改图接口返回值说明:本文件中的改图接口需要判断返回值, 基于可读性考虑除了Pass入口函数外,其他函数中的改图接口只接收返回值,但不增加返回值处理代码;如需判断返回值,可配合使用custom_context.SetErrorMessage("xxx")方法
    graphStatus FuseMatMulAndAddPass(GraphPtr &graph, CustomPassContext &custom_context) {
        cout << "FuseMatMulAndAddPass begin." << endl;
        GNode src_node;
        GNode dst_node;
        // 1.遍历所有节点,寻找MatMul和Add节点
        if (!FindNodes(graph, src_node, dst_node)) {
            cout << "Do not find MatMul or Add node." << endl;
            return GRAPH_SUCCESS;
        }
    
        // 2.判断MatMul和Add节点是否有连边关系
        if (!CheckNodesHaveEdge(graph, src_node, dst_node)) {
            cout << "There is no edge between src and dst node." << endl;
            return GRAPH_SUCCESS;
        }
    
        // 3.创建和添加GEMM节点
        GNode node_gemm;
        CreateGEMMNode(graph, src_node, node_gemm);
    
        // 4.添加新节点的输入输出
        if (!AddInputsAndOutputs(graph, src_node, dst_node, node_gemm)) {
            custom_context.SetErrorMessage("Add inputs and outputs failed.");
            return -1;
        }
    
        // 5.删除旧节点和其连边关系,连接新GEMM节点和输出节点
        RemoveOldNodesEdgesAndAddGemmOutput(graph, src_node, dst_node, node_gemm);
    
        cout << "FuseMatMulAndAddPass end." << endl;
        return GRAPH_SUCCESS;
    }
    
    REGISTER_CUSTOM_PASS("FuseMatMulAndAddPass").CustomPassFn(FuseMatMulAndAddPass).Stage(CustomPassStage::kBeforeInferShape);
    

如何使用自定义Pass

完成上述自定义Pass后,本节简单介绍如何把改图函数编译成动态库插件方式,以便注册的Pass在图编译的最开始被框架调用。详细使用说明请参见样例使用指导

  1. 开发示例中的改图函数编译成仅以".so"结尾的动态库文件。
  2. 把上述".so"动态库文件复制到${INSTALL_DIR}/opp/vendors/xxx/custom_fusion_passes/目录下。(支持设置软链接的方式;".so"文件对可执行用户,需要有可读权限)

    多个"${INSTALL_DIR}/opp/vendors/xxx"目录按照文本序排序后遍历寻找"custom_fusion_passes/"子目录,单个子目录内的".so"按照文本序加载,非".so"结尾的文件在加载时跳过。

    • ${INSTALL_DIR}请替换为CANN软件安装后文件存储路径。若安装的Ascend-cann-toolkit软件包,以root安装举例,则安装后文件存储路径为:/usr/local/Ascend/ascend-toolkit/latest。
    • xxx:有且仅有一层自定义目录。
    • custom_fusion_passes:该目录下不能有子目录。
  3. 支持但不限于如下几种入口编译模型文件:
    如果要查看上述自定义Pass有没有生效,在编译模型前,需要dump图进行查看:在执行之前设置DUMP_GE_GRAPH环境变量,然后使用如下入口编译模型:

结果验证

请参见样例使用指导>程序运行>检查执行结果

设置了dump环境变量后,程序执行完毕,会在当前路径生成ge_onnx*.pbtxt等图文件,用户可以获取如下两张图,然后使用Netron等可视化软件查看:

  • 指定Pass执行阶段在InferShape之前:
    • ge_onnx_xxxxxxxx_RunCustomPassBegin.pbtxt:融合前的图
    • ge_onnx_xxxxxxxx_RunCustomPassEnd.pbtxt:融合后的图

      MatMul+Add融合为GEMM自定义Pass为例,查看融合前的图结构为:

      通过自定义Pass修改后的图结构如下所示,可以看出MatMul+Add结构已经替换为单个GEMM节点。

  • 指定Pass执行阶段在InferShape之后:
    • ge_onnx_xxxx_PrepareAfterInferFormatAndShape.pbtxt:融合前的图
    • ge_onnx_xxxx_RunCustomPass_AfterInferShape.pbtxt:融合后的图

    BatchMatMulV2融合为Transpose+Mul+ReduceSum自定义Pass为例,查看融合前的图结构为:

    通过自定义Pass修改后的图结构如下所示,可以看出BatchMatMulV2已经替换为Transpose+Mul+ReduceSum三个节点。