算子的IR用于进行算子的描述,包括算子输入输出信息,属性信息等,用于把算子注册到算子原型库中。
算子的IR需要在算子的工程目录的/op_proto/算子名称.h 和 /op_proto/算子名称.cc 文件中进行实现。
下面详细讲解如何进行算子IR定义文件的实现。
使用如下语句进行算子IR注册宏的定义,宏名称固定为GE_OP_OPERATORTYPE_H,OPERATORTYPE为使用REG_OP(OpType)语句中OpType的大写。
#ifndef GE_OP_OPERATORTYPE_H //条件编译 #define GE_OP_OPERATORTYPE_H //进行宏定义
在算子IR实现文件的头部使用预编译命令“#include”将算子注册的头文件包含到算子IR实现的文件中。
#include "graph/operator_reg.h"
operator_reg.h存在于CANN软件安装后文件存储路径的“include/graph/”路径下,包含此头文件,可使用算子类型注册相关的函数、宏、结构体等。
Graph Engine(GE)提供REG_OP宏,以“.”链接INPUT、OUTPUT、ATTR等接口注册算子的输入、输出和属性信息,最终以OP_END_FACTORY_REG接口结束,完成算子的注册。
注册代码实现如下所示:
namespace ge{ REG_OP(OpType) //算子类型名称 .INPUT(x1, TensorType({ DT_FLOAT, DT_INT32 })) .INPUT(x2, TensorType({ DT_FLOAT, DT_INT32 })) // .DYNAMIC_INPUT(x, TensorType{DT_FLOAT, DT_INT32}) // .OPTIONAL_INPUT(b, TensorType{DT_FLOAT}) .OUTPUT(y, TensorType({ DT_FLOAT, DT_INT32 })) // .DYNAMIC_OUTPUT(y, TensorType{DT_FLOAT, DT_INT32}) .ATTR(x, Type, DefaultValue) // .REQUIRED_ATTR(x, Type) // .GRAPH(z1) // .DYNAMIC_GRAPH(z2) .OP_END_FACTORY_REG(OpType) }
OpType:注册到昇腾AI处理器的自定义算子库的算子类型,可以任意命名但不能和已有的算子命名冲突。
若算子有多个输入,每个输入需要使用一条INPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行描述。
若算子有多个输出,每个输出需要使用一条OUTPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行注册。
注册算子的可选属性,包括算子的属性名称,属性类型以及属性值的默认值,当开发者不设置算子对象的属性值时需要使用默认值。ATTR接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)。
例如:ATTR(mode, Int, 1),注册属性mode,属性类型为int64_t,默认值为1。
若算子有多个属性,每个属性需要使用一条ATTR(x, Type, DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。
注册算子的必选属性,包括算子的属性名称与属性类型,无默认值,开发者必须设置算子对象的属性值。此接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)。
若算子有多个属性,每个属性需要使用一条ATTR(x, Type,DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。
注册算子中包含的子图信息,输入z1为子图名称,一般用于控制类算子(分支算子/循环算子等)。
注册完成后,会自动生成子图相关的接口,用户获取子图名称、获取子图描述信息、设置子图描述信息等,具体接口可参见GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。
注册动态算子子图信息,输入z2为子图名称,一般用于控制类算子(分支算子/循环算子等)。
注册完成后,会自动生成动态算子子图的相关接口,用于创建动态子图、设置子图描述信息等,具体接口可参见DYNAMIC_GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。
#endif
IR实现的cc文件中主要实现如下两个功能:
在“op_proto/算子名称.cc”实现Verify和InferShape方法时不需要声明,直接实现即可。
#include "算子名称.h" #include <vector> #include <string>
头文件 |
目录 |
作用 |
---|---|---|
算子名称.h |
算子IR 头文件.h 注册代码实现中实现的IR头文件 |
包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。 |
string |
C++标准库。 |
包含该头文件,可使用string类构造对象并调用string相关接口。 |
vector |
C++标准库。 |
包含该头文件,可使用vector类模板并调用vector相关接口。 |
算子IR中InferShape的定义可以使用如下接口:
IMPLEMT_COMMON_INFERFUNC(func_name):此接口自动生成一个类型为Operator类的对象op,开发者可直接调用Operator类接口进行InferShape的实现。其中,func_name:用户自定义。
1 2 3 4 5 6 7 8 9 10 11 |
IMPLEMT_COMMON_INFERFUNC(SoftmaxInferShape) { TensorDesc tensordesc_output = op.GetOutputDescByName("y"); tensordesc_output.SetShape(op.GetInputDescByName("x").GetShape()); tensordesc_output.SetDataType(op.GetInputDescByName("x").GetDataType()); tensordesc_output.SetFormat(op.GetInputDescByName("x").GetFormat()); (void)op.UpdateOutputDesc("y", tensordesc_output); return GRAPH_SUCCESS; } |
IMPLEMT_COMMON_INFERFUNC(NotEqualInferShape) { Shape x_shape = op.GetInputDescByName("x1").GetShape(); Shape y_shape = op.GetInputDescByName("x2").GetShape(); TensorDesc td = op.GetOutputDescByName("y"); std::vector<int64_t> dims_x = x_shape.GetDims(); std::vector<int64_t> dims_y = y_shape.GetDims(); if (dims_x.size() < dims_y.size()) { std::vector<int64_t> dims_tmp = dims_x; dims_x = dims_y; dims_y = dims_tmp; } if (dims_x.size() != dims_y.size()) { int dec = dims_x.size() - dims_y.size(); for (int i = 0; i < dec; i++) { dims_y.insert(dims_y.begin(), (int64_t)1); } } std::vector<int64_t> dim_vec; for (size_t i = 0; i < dims_x.size(); i++) { if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) { printf( "The %s op dimensions does not match the broadcast rule(%lu %lu).",op.GetName().c_str(), dims_x[i], dims_y[i]); } int64_t dims = dims_x[i] > dims_y[i] ? dims_x[i] : dims_y[i]; dim_vec.push_back(dims); } td.SetShape(ge::Shape(dim_vec)); td.SetDataType(DT_BOOL); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; }
算子Verify函数的实现使用如下接口:
IMPLEMT_VERIFIER (OpType, func_name)
传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见2。
Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。
实现样例如下所示:
IMPLEMT_VERIFIER(Pow, PowVerify) { DataType input_type_x = op.GetInputDescByName("x").GetDataType(); DataType input_type_y = op.GetInputDescByName("y").GetDataType(); if (input_type_x != input_type_y) { return GRAPH_FAILED; } return GRAPH_SUCCESS; }
调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:
COMMON_INFER_FUNC_REG(OpType, func_name); VERIFY_FUNC_REG(OpType, func_name);