实现
算子的IR用于进行算子的描述,包括算子输入输出信息,属性信息等,用于把算子注册到算子原型库中。
算子的IR需要在算子的工程目录的“op_proto/算子名称.h”和 “op_proto/算子名称.cc ”文件中进行实现。
下面详细讲解如何进行算子IR定义文件的实现。
算子IR 头文件.h 注册代码实现
- 宏定义。
使用如下语句进行算子IR注册宏的定义,宏名称固定为GE_OP_OPERATORTYPE_H,OPERATORTYPE为使用REG_OP(OpType)语句中OpType的大写。
#ifndef GE_OP_OPERATORTYPE_H //条件编译 #define GE_OP_OPERATORTYPE_H //进行宏定义
- 包含头文件。
在算子IR实现文件的头部使用预编译命令“#include”将算子注册的头文件包含到算子IR实现的文件中。
#include "graph/operator_reg.h"
operator_reg.h存在于CANN软件安装后文件存储路径的“include/graph/”路径下,包含此头文件,可使用算子类型注册相关的函数、宏、结构体等。
- 原型注册。
Graph Engine(GE)提供REG_OP宏,以“.”链接INPUT、OUTPUT、ATTR等接口注册算子的输入、输出和属性信息,最终以OP_END_FACTORY_REG接口结束,完成算子的注册。
其中输入输出的描述信息顺序需要与算子实现中定义保持一致,ATTR的顺序可变。
注册代码实现如下所示:
namespace ge{ REG_OP(OpType) //算子类型名称 .INPUT(x1, TensorType({ DT_FLOAT, DT_INT32 })) .INPUT(x2, TensorType({ DT_FLOAT, DT_INT32 })) // .DYNAMIC_INPUT(x, TensorType{DT_FLOAT, DT_INT32}) // .OPTIONAL_INPUT(b, TensorType{DT_FLOAT}) .OUTPUT(y, TensorType({ DT_FLOAT, DT_INT32 })) // .DYNAMIC_OUTPUT(y, TensorType{DT_FLOAT, DT_INT32}) .ATTR(x, Type, DefaultValue) // .REQUIRED_ATTR(x, Type) // .GRAPH(z1) // .DYNAMIC_GRAPH(z2) .OP_END_FACTORY_REG(OpType) }
- REG_OP(OpType)
OpType:注册到昇腾AI处理器BS9SX1A AI处理器的自定义算子库的算子类型,可以任意命名但不能和已有的算子命名冲突。
- INPUT(x1, TensorType({ DT_FLOAT,DT_UINT8,... }))
注册算子的输入参信息。
- x:宏参数,算子的输入名称,用户自定义。
- TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType。
若算子有多个输入,每个输入需要使用一条INPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行描述。
- DYNAMIC_INPUT(x, TensorType{DT_FLOAT, DT_INT32, ...})
算子为动态多输入场景下的输入信息注册。
- x:宏参数,算子的输入名称,图运行时,会根据输入的个数自动生成x0、x1、x2……,序号依次递增。
- TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType。
- OPTIONAL_INPUT(x, TensorType{DT_FLOAT, ...})
若算子输入为可选输入,可使用此接口进行算子输入的注册。
- x:宏参数,算子输入的名称。
- TensorType{DT_FLOAT, ...}:“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType。
- OUTPUT(y, TensorType({ DT_FLOAT,DT_UINT8,... }))
注册算子的输出信息。
- y:宏参数,算子的输出名称,用户自定义。
- TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输出支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType。
若算子有多个输出,每个输出需要使用一条OUTPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行注册。
- DYNAMIC_OUTPUT(y, TensorType{DT_FLOAT, DT_INT32})
算子为动态多输出场景下的输出信息注册。
- y:宏参数,算子的输出名称,图运行时,会根据输出的个数自动生成y0、y1、y2……,序号依次递增。
- TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输出支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType。
- ATTR(x, Type, DefaultValue)
注册算子的可选属性,包括算子的属性名称,属性类型以及属性值的默认值,当开发者不设置算子对象的属性值时需要使用默认值。ATTR接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)。
例如:ATTR(mode, Int, 1),注册属性mode,属性类型为int64_t,默认值为1。
若算子有多个属性,每个属性需要使用一条ATTR(x, Type, DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。
- REQUIRED_ATTR(x, Type)
注册算子的必选属性,包括算子的属性名称与属性类型,无默认值,开发者必须设置算子对象的属性值。此接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)。
若算子有多个属性,每个属性需要使用一条ATTR(x, Type, DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。
- GRAPH(z1)
注册算子中包含的子图信息,输入z1为子图名称,一般用于控制类算子(分支算子/循环算子等)。
注册完成后,会自动生成子图相关的接口,用户获取子图名称、获取子图描述信息、设置子图描述信息等,具体接口可参见GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。
- DYNAMIC_GRAPH(z2)
注册动态算子子图信息,输入z2为子图名称,一般用于控制类算子(分支算子/循环算子等)。
注册完成后,会自动生成动态算子子图的相关接口,用于创建动态子图、设置子图描述信息等,具体接口可参见DYNAMIC_GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。
- OP_END_FACTORY_REG(OpType):结束算子注册。OpType与REG_OP(OpType)中的OpType保持一致。
- REG_OP(OpType)
- 结束条件编译。
#endif
算子IR 定义的.cc 注册代码实现
IR实现的cc文件中主要实现如下两个功能:
- 算子参数的校验,实现程序健壮性并提高定位效率。
- 根据算子的输入张量描述、算子逻辑及算子属性,推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息。这样算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。
在/op_proto/算子名称.cc 实现Verify和InferShape方法时不需要声明,直接实现即可。
- 包含头文件。
#include "算子名称.h" #include <vector> #include <string>
表1 头文件说明 头文件
目录
作用
算子名称.h
算子IR 头文件.h 注册代码实现中实现的IR头文件
包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。
string
C++标准库。
包含该头文件,可使用string类构造对象并调用string相关接口。
vector
C++标准库。
包含该头文件,可使用vector类模板并调用vector相关接口。
- 实现InferShape方法。
算子IR中InferShape的定义可以使用如下接口:
IMPLEMT_COMMON_INFERFUNC(func_name):自动生成的一个类型为Operator类的对象op,可直接调用Operator类接口进行InferShape的实现,其中,func_name:用户自定义。
- 将输入描述直接赋给输出描述的实现样例如下所示:
1 2 3 4 5 6 7 8 9
IMPLEMT_COMMON_INFERFUNC(CacheUpdateInferShape) { TensorDesc out_desc = op.GetOutputDescByName("x"); out_desc.SetDataType(op.GetInputDescByName("x").GetDataType()); if (op.UpdateOutputDesc("x", out_desc) != GRAPH_SUCCESS) { return GRAPH_FAILED; } return GRAPH_SUCCESS; }
- 输出描述需要根据算子逻辑进行计算,样例如下所示:
IMPLEMT_COMMON_INFERFUNC(MatrixDiagPartInferShape) { Shape shape = op.GetInputDescByName("x").GetShape(); DataType input_dtype = op.GetInputDescByName("x").GetDataType(); Format input_format = op.GetInputDescByName("x").GetFormat(); std::vector<int64_t> dim_vector; int64_t dimsInput_1 = shape.GetDimNum() - 1; int64_t dimsInput_2 = shape.GetDimNum() - 2; int64_t dimNums_1 = shape.GetDim(dimsInput_1); int64_t dimNums_2 = shape.GetDim(dimsInput_2); if (dimNums_1 > dimNums_2) { for (size_t i = 0; i < shape.GetDimNum() - 1; i++) { dim_vector.push_back(shape.GetDim(i)); } } else { for (size_t i = 0; i < shape.GetDimNum() - 2; i++) { dim_vector.push_back(shape.GetDim(i)); } dim_vector.push_back(dimNums_1); } Shape output_shape(dim_vector); TensorDesc td = op.GetOutputDesc("y"); td.SetShape(output_shape); td.SetDataType(input_dtype); td.SetFormat(input_format); (void)op.UpdateOutputDesc("y", td); return GRAPH_SUCCESS; }
- 若存在动态输入输出,输入输出按列表形式处理,样例如下所示:
IMPLEMT_COMMON_INFERFUNC(BatchInfer) { for (size_t i = 0; i < op.GetInputsSize(); ++i) { Shape out_shapes; if (ReplaceDim(op.GetInputDesc(i).GetShape(), 0, ge::UNKNOWN_DIM, out_shapes, op.GetName().c_str()) == GRAPH_FAILED) { return GRAPH_FAILED; } auto y_tensor_type = op.GetDynamicInputDesc("x_tensors", i).GetDataType(); TensorDesc output_desc = op.GetDynamicOutputDesc("y_tensors", i); output_desc.SetShape(out_shapes); output_desc.SetDataType(y_tensor_type); op.UpdateDynamicOutputDesc("y_tensors", i, output_desc); } Shape scalar_shape; Scalar(scalar_shape); TensorDesc y_desc = op.GetOutputDesc("y_id"); y_desc.SetShape(scalar_shape); y_desc.SetDataType(DT_INT64); op.UpdateOutputDesc("y_id", y_desc); std::vector<int64_t> dims = { ge::UNKNOWN_DIM, 3 }; TensorDesc output_desc_batch_index = op.GetOutputDesc("y_index"); output_desc_batch_index.SetShape(Shape(dims)); output_desc_batch_index.SetDataType(DT_INT64); op.UpdateOutputDesc("y_index", output_desc_batch_index); return GRAPH_SUCCESS; } // ReplaceDim函数定义如下 graphStatus ReplaceDim(const Shape& s, int64_t dim_index_in, int64_t new_dim, Shape& out, const char* op_name) { if(shape.GetDims() == UNKNOWN_RANK) { out = Shape(ge::UNKNOWN_SHAPE); return GRAPH_SUCCESS; } int64_t dim_index = dim_index_in; if (dim_index < 0) { dim_index = (int64_t)s.GetDimNum() + dim_index; } std::vector<int64_t> dims = s.GetDims(); dims[dim_index] = new_dim; out = Shape(dims); return GRAPH_SUCCESS; }
- 将输入描述直接赋给输出描述的实现样例如下所示:
- 实现Verify方法。
算子Verify函数的实现使用如下接口:
IMPLEMT_VERIFIER (OpType, func_name)
传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见2。
- OpType:自定义算子的类型。
- func_name:自定义的verify函数名称。
Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。
实现样例如下所示:
IMPLEMT_VERIFIER(Pow, PowVerify) { DataType input_type_x = op.GetInputDescByName("x").GetDataType(); DataType input_type_y = op.GetInputDescByName("y").GetDataType(); if (input_type_x != input_type_y) { return GRAPH_FAILED; } return GRAPH_SUCCESS; }
- 注册InferShape方法与Verify方法。
调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:
COMMON_INFER_FUNC_REG(OpType, func_name); VERIFY_FUNC_REG(OpType, func_name);