适配插件开发(ONNX框架)
简介
您可以参考本章节进行算子适配插件的开发,将基于第三方框架的算子映射成适配昇腾AI处理器BS9SX1A AI处理器的算子,将算子信息注册到GE中。基于ONNX框架的网络运行时,首先会加载并调用GE中的插件信息,将原始框架网络中的算子进行解析并映射成适配昇腾AI处理器BS9SX1A AI处理器中的算子。
下文我们将适配昇腾AI处理器BS9SX1A AI处理器的算子称为CANN算子。
原理介绍
算子插件的实现包含CANN算子类型的注册、原始框架中算子类型的注册以及原始框架中算子属性到CANN算子属性的映射,算子的映射通过Parser模块完成。插件在整网络运行场景下的实现流程如图1所示。
- 首先GE接收到第三方框架的原始网络模型,并进行初始化,网络模型的拓扑图我们简称为图。
- GE从Register注册模块中加载AI CPU算子插件生成的.so文件,在CANN软件安装后文件存储路径的“opp/framework/”路径中。
- 读取算子插件.so中的算子相关信息,并将其注册到算子插件的map文件中(所有算子插件的相关信息都会以map的形式存储到一个文件中)。
- GE向Parser模块发送调用Parser方法的请求。
- Parser模块根据算子类型(OpType)从算子插件的map文件中取出对应的Parser函数,并返回实现函数ParseParamsByOperatorFn给Parser模块,Parser模块根据实现函数将第三方网络算子中的属性映射到CANN算子的属性,即算子原型中的属性定义,完成第三方网络中算子到CANN算子的映射。
- 后续会进行图准备、图拆分及图优化等一系列操作,并将算子编译生成二进制文件。
插件实现
GE提供REGISTER_CUSTOM_OP宏,按照指定的算子名称完成算子的注册。
原始框架为ONNX的自定义算子注册代码:
#include "register/register.h" #include "graph/operator.h" #include "json.hpp" namespace domi { REGISTER_CUSTOM_OP("OpType") .FrameworkType(ONNX) .OriginOpType("OriginOpType") .ParseParamsByOperatorFn(ParseParamByOpFunc) //用来注册解析算子属性的函数 .ImplyType(ImplyType::AI_CPU); }
- 在代码实现文件顶部使用预编译命令“#include”将插件实现函数相关的头文件包含到插件实现源文件中。
- register.h存储在CANN软件安装后文件存储路径的“include/register/”目录下,包含该头文件,可使用算子注册相关类,调用算子注册相关的接口。
- operator.h(可选),存储在CANN软件安装后文件存储路径的“include/graph/”目录下,包含该头文件,可以使用Operator类相关接口,获取算子输入输出及属性等算子信息。
- json.hpp:用于进行ONNX数据定义的解析,将String类型的算子参数定义转换为json格式。
若样例工程中未提供“json.hpp”文件,用户可以自行下载,并将“json.hpp”放在工程可以找到的任意路径下,然后包含此头文件即可,下载路径可参见json.hpp。
- REGISTER_CUSTOM_OP:注册自定义算子,OpType作为注册到GE中的算子类型,可以任意命名但不能和已有的算子命名冲突,且需要与3中的OpType保持一致。
- FrameworkType:ONNX代表原始框架为ONNX。
- OriginOpType:算子在原始框架中的类型。例如自定义算子OpTypeA,对应ONNX算子库版本opset_version=11的原始框架类型为“ai.onnx::11::OpTypeA”。
- ParseParamsByOperatorFn(ParseParamByOpFunc):用来注册解析算子属性的函数,需要用户自定义实现回调函数ParseParamByOpFunc。
回调函数ParseParamByOpFunc的声明如下所示:
Status ParseParamByOpFunc(const ge::Operator& op_src, ge::Operator& op_dest)
- ParseParamByOpFunc:函数名称,用户自定义,需要保持唯一。
- op_src:ONNX框架定义的Operator类对象,包含ONNX模型中自定义的算子属性信息,定义来源ONNX框架的原始模型文件。
- op_dest:CANN算子数据结构,保存算子信息,Operator类的详细描述请参见Operator类。
ONNX原始模型中,属性为repeated message类型,如下所示:
message NodeProto { repeated string input = 1; // namespace Value repeated string output = 2; // namespace Value string name = 3; // namespace Node string op_type = 4; // namespace Operator string domain = 7; // namespace Domain // Additional named attributes. repeated AttributeProto attribute = 5; }
GE对属性进行解析时,对于repeated message类型的参数,可使用GetAttr(const char *name, ge::AscendString &attr_value)接口获取其属性值,然后将AscendString类型的属性值转换为String类型,再将其转换为json格式进行属性字段的解析。
实现如下所示:
using namespace ge; using json = nlohmann::json; namespace domi { namespace { const int kTypeFloat = 1; } Status ParseOnnxParamsLeakyRelu(const ge::Operator& op_src, ge::Operator& op_dest) { // trans op_src to op_dest // if op_src get required attr failed, need to return Failed // if op_src get optional attr failed, need to return Failed or set a default value float negative_slope = 0.01f; string negative_slope_str; AscendString attrs_string; // 使用固定属性名称“attribute”获取ONNX算子中的属性,并赋值给AscendString类型对象 if (ge::GRAPH_SUCCESS == op_src.GetAttr("attribute", attrs_string)) { // 转换为json格式 json attrs = json::parse(attrs_string.GetString()); for (json attr : attrs["attribute"]) { if (attr["name"] == "alpha" && attr["type"] == kTypeFloat) { negative_slope_str = attr["f"]; // float type in json has accuracy loss, so we use string type to store it negative_slope = atof(negative_slope_str.c_str()); } } } op_dest.SetAttr("negative_slope", negative_slope); return SUCCESS; }
- 当前版本GetAttr与SetAttr接口不支持对原始文件中数据类型为double和uint64的字段进行解析。
- 使用ATC工具执行模型转换时,对属性的获取情况不会进行强校验。所以进行算子适配插件实现时,若用户调用GetAttr失败,建议根据算子实际情况增加相应的处理逻辑,例如,针对必选属性,可返回失败,针对可选属性,可设置默认值。
- ImplyType:指定算子的实现方式。ImplyType::AI_CPU表示该算子是AI CPU算子;ImplyType::TVM表示该算子是TBE算子。
父主题: 算子适配