昇腾社区首页
中文
注册

实现

算子的IR用于进行算子的描述,包括算子输入输出信息,属性信息等,用于把算子注册到算子原型库中。

算子的IR需要在算子的工程目录的“op_proto/算子名称.h”和 “op_proto/算子名称.cc ”文件中进行实现。

下面详细讲解如何进行算子IR定义文件的实现。

算子IR 头文件.h 注册代码实现

  1. 宏定义。

    使用如下语句进行算子IR注册宏的定义,宏名称固定为GE_OP_OPERATORTYPE_H,OPERATORTYPE为使用REG_OP(OpType)语句中OpType的大写。

    #ifndef GE_OP_OPERATORTYPE_H       //条件编译
    #define GE_OP_OPERATORTYPE_H       //进行宏定义
  2. 包含头文件。

    在算子IR实现文件的头部使用预编译命令“#include”将算子注册的头文件包含到算子IR实现的文件中。

    #include "graph/operator_reg.h"

    operator_reg.h存在于CANN软件安装后文件存储路径的“include/graph/”路径下,包含此头文件,可使用算子类型注册相关的函数、宏、结构体等。

  3. 原型注册。

    Graph Engine(GE)提供REG_OP宏,以“.”链接INPUT、OUTPUT、ATTR等接口注册算子的输入、输出和属性信息,最终以OP_END_FACTORY_REG接口结束,完成算子的注册。

    其中输入输出的描述信息顺序需要与算子实现中定义保持一致,ATTR的顺序可变。

    注册代码实现如下所示:

    namespace ge{
    REG_OP(OpType) //算子类型名称
        .INPUT(x1, TensorType({ DT_FLOAT, DT_INT32 }))
        .INPUT(x2, TensorType({ DT_FLOAT, DT_INT32 })) 
        // .DYNAMIC_INPUT(x, TensorType{DT_FLOAT, DT_INT32})
        // .OPTIONAL_INPUT(b, TensorType{DT_FLOAT})        
        .OUTPUT(y, TensorType({ DT_FLOAT, DT_INT32 })) 
        // .DYNAMIC_OUTPUT(y, TensorType{DT_FLOAT, DT_INT32})
        .ATTR(x, Type, DefaultValue)
        // .REQUIRED_ATTR(x, Type)
        // .GRAPH(z1)
        // .DYNAMIC_GRAPH(z2)
        .OP_END_FACTORY_REG(OpType)
    }
    • REG_OP(OpType)

      OpType:注册到昇腾AI处理器BS9SX1A AI处理器的自定义算子库的算子类型,可以任意命名但不能和已有的算子命名冲突。

    • INPUT(x1, TensorType({ DT_FLOAT,DT_UINT8,... }))
      注册算子的输入参信息。
      • x:宏参数,算子的输入名称,用户自定义。
      • TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType

      若算子有多个输入,每个输入需要使用一条INPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行描述。

    • DYNAMIC_INPUT(x, TensorType{DT_FLOAT, DT_INT32, ...})
      算子为动态多输入场景下的输入信息注册。
      • x:宏参数,算子的输入名称,图运行时,会根据输入的个数自动生成x0、x1、x2……,序号依次递增。
      • TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType
    • OPTIONAL_INPUT(x, TensorType{DT_FLOAT, ...})
      若算子输入为可选输入,可使用此接口进行算子输入的注册。
      • x:宏参数,算子输入的名称。
      • TensorType{DT_FLOAT, ...}:“{ }”中为此输入支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType
    • OUTPUT(y, TensorType({ DT_FLOAT,DT_UINT8,... }))
      注册算子的输出信息。
      • y:宏参数,算子的输出名称,用户自定义。
      • TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输出支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType

      若算子有多个输出,每个输出需要使用一条OUTPUT(x, TensorType({ DT_FLOAT,DT_UINT8,... }))语句进行注册。

    • DYNAMIC_OUTPUT(y, TensorType{DT_FLOAT, DT_INT32})
      算子为动态多输出场景下的输出信息注册。
      • y:宏参数,算子的输出名称,图运行时,会根据输出的个数自动生成y0、y1、y2……,序号依次递增。
      • TensorType({ DT_FLOAT,DT_UINT8,... }):“{ }”中为此输出支持的数据类型的列表,支持的数据类型请参见DataType,TensorType提供了一些接口指定支持的数据类型,详细定义请参见TensorType
    • ATTR(x, Type, DefaultValue)

      注册算子的可选属性,包括算子的属性名称,属性类型以及属性值的默认值,当开发者不设置算子对象的属性值时需要使用默认值。ATTR接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)

      例如:ATTR(mode, Int, 1),注册属性mode,属性类型为int64_t,默认值为1。

      若算子有多个属性,每个属性需要使用一条ATTR(x, Type, DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。

    • REQUIRED_ATTR(x, Type)

      注册算子的必选属性,包括算子的属性名称与属性类型,无默认值,开发者必须设置算子对象的属性值。此接口中Type的取值与对应的属性类型请参见原型定义接口(REG_OP)

      若算子有多个属性,每个属性需要使用一条ATTR(x, Type, DefaultValue)语句或者REQUIRED_ATTR(x, Type)语句进行注册。

    • GRAPH(z1)

      注册算子中包含的子图信息,输入z1为子图名称,一般用于控制类算子(分支算子/循环算子等)。

      注册完成后,会自动生成子图相关的接口,用户获取子图名称、获取子图描述信息、设置子图描述信息等,具体接口可参见GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。

    • DYNAMIC_GRAPH(z2)

      注册动态算子子图信息,输入z2为子图名称,一般用于控制类算子(分支算子/循环算子等)。

      注册完成后,会自动生成动态算子子图的相关接口,用于创建动态子图、设置子图描述信息等,具体接口可参见DYNAMIC_GRAPH,用户可使用生成的相关接口进行IR模型的构建。对于同一个算子,注册的算子子图名称需要保持唯一。

    • OP_END_FACTORY_REG(OpType):结束算子注册。OpType与REG_OP(OpType)中的OpType保持一致。
  4. 结束条件编译。
    #endif

算子IR 定义的.cc 注册代码实现

IR实现的cc文件中主要实现如下两个功能:

  • 算子参数的校验,实现程序健壮性并提高定位效率。
  • 根据算子的输入张量描述、算子逻辑及算子属性,推理出算子的输出张量描述,包括张量的形状、数据类型及数据排布格式等信息。这样算子构图准备阶段就可以为所有的张量静态分配内存,避免动态内存分配带来的开销。

/op_proto/算子名称.cc 实现Verify和InferShape方法时不需要声明,直接实现即可。

  1. 包含头文件。
    #include "算子名称.h"     
    #include <vector>
    #include <string>
    表1 头文件说明

    头文件

    目录

    作用

    算子名称.h

    算子IR 头文件.h 注册代码实现中实现的IR头文件

    包含该头文件,可以调用此文件中注册的Operator类的对象op或者Operator类派生出来的子类op。

    string

    C++标准库。

    包含该头文件,可使用string类构造对象并调用string相关接口。

    vector

    C++标准库。

    包含该头文件,可使用vector类模板并调用vector相关接口。

  2. 实现InferShape方法。

    算子IR中InferShape的定义可以使用如下接口:

    IMPLEMT_COMMON_INFERFUNC(func_name):自动生成的一个类型为Operator类的对象op,可直接调用Operator类接口进行InferShape的实现,其中,func_name:用户自定义。

    • 将输入描述直接赋给输出描述的实现样例如下所示:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      IMPLEMT_COMMON_INFERFUNC(CacheUpdateInferShape) 
      {  
        TensorDesc out_desc = op.GetOutputDescByName("x");  
        out_desc.SetDataType(op.GetInputDescByName("x").GetDataType());  
        if (op.UpdateOutputDesc("x", out_desc) != GRAPH_SUCCESS) {    
           return GRAPH_FAILED;  
        }  
        return GRAPH_SUCCESS;
      }
      
    • 输出描述需要根据算子逻辑进行计算,样例如下所示:
      IMPLEMT_COMMON_INFERFUNC(MatrixDiagPartInferShape)
      {    
        Shape shape = op.GetInputDescByName("x").GetShape();    
        DataType input_dtype = op.GetInputDescByName("x").GetDataType();    
        Format input_format = op.GetInputDescByName("x").GetFormat();   
        std::vector<int64_t> dim_vector;    
        int64_t dimsInput_1 = shape.GetDimNum() - 1;    
        int64_t dimsInput_2 = shape.GetDimNum() - 2;    
        int64_t dimNums_1 = shape.GetDim(dimsInput_1);    
        int64_t dimNums_2 = shape.GetDim(dimsInput_2);    
        if (dimNums_1 > dimNums_2) {        
          for (size_t i = 0; i < shape.GetDimNum() - 1; i++) {            
            dim_vector.push_back(shape.GetDim(i));
          }    
        } else {        
          for (size_t i = 0; i < shape.GetDimNum() - 2; i++) {            
            dim_vector.push_back(shape.GetDim(i));        
          }        
          dim_vector.push_back(dimNums_1);    
        }    
        Shape output_shape(dim_vector);    
        TensorDesc td = op.GetOutputDesc("y");    
        td.SetShape(output_shape);    
        td.SetDataType(input_dtype);    
        td.SetFormat(input_format);    
        (void)op.UpdateOutputDesc("y", td);    
        return GRAPH_SUCCESS;
      }
    • 若存在动态输入输出,输入输出按列表形式处理,样例如下所示:
      IMPLEMT_COMMON_INFERFUNC(BatchInfer)
      {
          for (size_t i = 0; i < op.GetInputsSize(); ++i) {
              Shape out_shapes;
              if (ReplaceDim(op.GetInputDesc(i).GetShape(), 
                      0, ge::UNKNOWN_DIM, out_shapes, op.GetName().c_str()) == GRAPH_FAILED) {
                  return GRAPH_FAILED;
              }
              auto y_tensor_type = op.GetDynamicInputDesc("x_tensors", i).GetDataType();
              TensorDesc output_desc = op.GetDynamicOutputDesc("y_tensors", i);
              output_desc.SetShape(out_shapes);
              output_desc.SetDataType(y_tensor_type);
              op.UpdateDynamicOutputDesc("y_tensors", i, output_desc);
          }
      
          Shape scalar_shape;
          Scalar(scalar_shape);
          TensorDesc y_desc = op.GetOutputDesc("y_id");
          y_desc.SetShape(scalar_shape);
          y_desc.SetDataType(DT_INT64);
          op.UpdateOutputDesc("y_id", y_desc);
      
          std::vector<int64_t> dims = { ge::UNKNOWN_DIM, 3 };
          TensorDesc output_desc_batch_index = op.GetOutputDesc("y_index");
          output_desc_batch_index.SetShape(Shape(dims));
          output_desc_batch_index.SetDataType(DT_INT64);
          op.UpdateOutputDesc("y_index", output_desc_batch_index);
          return GRAPH_SUCCESS;
      }
      // ReplaceDim函数定义如下
      graphStatus ReplaceDim(const Shape& s, 
                             int64_t dim_index_in, 
                             int64_t new_dim, 
                             Shape& out, 
                             const char* op_name) {
        if(shape.GetDims() == UNKNOWN_RANK) {
          out = Shape(ge::UNKNOWN_SHAPE);
          return GRAPH_SUCCESS;
        }
        int64_t dim_index = dim_index_in;
        if (dim_index < 0) {
          dim_index = (int64_t)s.GetDimNum() + dim_index;
        }
        std::vector<int64_t> dims = s.GetDims();
        dims[dim_index] = new_dim;
        out = Shape(dims);
        return GRAPH_SUCCESS;
      }
  3. 实现Verify方法。

    算子Verify函数的实现使用如下接口:

    IMPLEMT_VERIFIER (OpType, func_name)

    传入的OpType为基于Operator类派生出来的子类,会自动生成一个类型为此子类的对象op,可以使用子类的成员函数获取算子的相关属性,op对象的成员函数可参见2

    • OpType:自定义算子的类型。
    • func_name:自定义的verify函数名称。

    Verify函数主要校验算子内在关联关系,例如对于多输入算子,多个tensor的dtype需要保持一致,此时需要校验多个输入的dtype,其他情况dtype不需要校验。

    实现样例如下所示:

    IMPLEMT_VERIFIER(Pow, PowVerify) {
      DataType input_type_x = op.GetInputDescByName("x").GetDataType();
      DataType input_type_y = op.GetInputDescByName("y").GetDataType();
      if (input_type_x != input_type_y) {
        return GRAPH_FAILED;
      }
      return GRAPH_SUCCESS;
    }
  4. 注册InferShape方法与Verify方法。

    调用InferShape注册宏与Verify注册宏完成InferShape方法与Verify方法的注册,如下所示:

    COMMON_INFER_FUNC_REG(OpType, func_name);  
    VERIFY_FUNC_REG(OpType, func_name);

    func_name即为23中的func_name。