Development Workflow
Follow the development process described in Project-based Operator Development. In addition to the operator implementation file described there, a code file that integrates the operator into the graph also needs to be delivered. This section covers only the development of that graph-integration code file.
Assume that the following figure shows the network model to be used. You might expect to call the operators one by one, obtaining each output tensor from its input tensors, to run the network. However, when a network model is actually generated in graph mode, tensor shapes and data types are inferred first. This way, the data type and shape of every tensor are known before the graph runs, and each tensor can be validated in advance. Because the output tensor description of each operator (shape, data type, and data format) is inferred ahead of time, memory can be statically allocated to all tensors in the preparation phase of operator graph construction, avoiding the overhead of dynamic memory allocation.
After shape and data type inference is performed on the following network model, the inference results shown in the gray shaded boxes can be obtained.
In addition to the tiling implementation, the following additional implementation code needs to be provided when integrating operators into a graph:
- Type inference: infers the output type of an operator based on the input type, logic, and attributes of the operator.
- Shape inference: infers the output shape of an operator based on the input shape, logic, and attributes of the operator.
- Dependency declaration: Some operators depend on a specific input value during InferShape. These operators are called data-dependent operators, and the corresponding input is referred to as a data-dependent input. When registering such an operator, you need to declare which input its data depends on.
The following table lists the requirements of different types of operators on the preceding implementation code.
| Category | Code Requirements |
|---|---|
| The output shape can be inferred based on the input shape. | Type inference and shape inference |
| The output shape can be inferred only based on the input value, that is, the data-dependent operator. For example, for the Reshape operator, the output shape can be inferred only based on the value of the shape input. | Type inference, shape inference, and dependency declaration |
During actual development, you implement the inference functions against their fixed prototypes and then associate them with the operator through the SetInferShape and SetInferDataType APIs. The following is an example:
```cpp
namespace ge {
static graphStatus InferShape(gert::InferShapeContext *context)
{
    ...
    return GRAPH_SUCCESS;
}

static graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    ...
    return ge::GRAPH_SUCCESS;
}
}  // namespace ge

namespace ops {
class AddCustom : public OpDef {
public:
    AddCustom(const char* name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("y")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("z")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_INT32})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        // Determine whether registration is required based on the operator calling mode.
        // Registration is required in graph mode.
        this->SetInferShape(ge::InferShape);
        this->SetInferDataType(ge::InferDataType);
        this->AICore()
            .SetTiling(optiling::TilingFunc);
        // Replace it with the actual Ascend AI Processor model.
        this->AICore().AddConfig("ascendxxx");
    }
};
OP_ADD(AddCustom);
}  // namespace ops
```
Type Inference
The following uses the AddCustom operator as an example to describe the implementation of InferDataType. In this example, the data type of the output tensor is the same as that of the input tensor. Therefore, you can directly assign the data type of any input tensor to the output tensor.
```cpp
namespace ge {
static graphStatus InferDataType(gert::InferDataTypeContext* context)
{
    const auto inputDataType = context->GetInputDataType(0);
    context->SetOutputDataType(0, inputDataType);
    return ge::GRAPH_SUCCESS;
}
}  // namespace ge
```
The following example provides a more flexible type inference example. When the input type is DT_INT4, the output type is DT_INT32.
```cpp
ge::graphStatus InferDataTypeForFoo(gert::InferDataTypeContext* context)
{
    if (context->GetInputDataType(0) == ge::DT_INT4) {
        context->SetOutputDataType(0, ge::DT_INT32);
    } else {
        // Otherwise, the output follows the input type.
        context->SetOutputDataType(0, context->GetInputDataType(0));
    }
    return ge::GRAPH_SUCCESS;
}
```
Shape Inference
Simple shape inference logic can be expressed using the Follow API, for example, when the output shape is the same as the input shape. In the following example, if the output is y1 and the input is x1, set the Follow mode to SHAPE. In this case, the shape of y1 is the same as that of x1.
```cpp
this->Input("x1")
    .ParamType(REQUIRED)
    .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
    .Format({ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("x2")
    .ParamType(REQUIRED)
    .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
    .Format({ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("y1")
    .ParamType(REQUIRED)
    .DataType({ge::DT_FLOAT, ge::DT_FLOAT})
    .Format({ge::FORMAT_ND, ge::FORMAT_ND})
    .Follow("x1", FollowType::SHAPE);
```
If the inference logic cannot be expressed through Follow in the prototype definition, you need to write an InferShape function. It has a fixed prototype and accepts an InferShapeContext as input, from which the input and output shape pointers can be obtained. The input shapes are of the const type: during InferShape they are read-only and cannot be modified. If InferShape succeeds, it returns ge::GRAPH_SUCCESS; any other return value means the inference failed, in which case the execution process ends and exits.
The following uses the Reshape operator as an example to describe the implementation of InferShape. Based on the value of input 1 (shape), the Reshape operator transforms the shape of input 0 (x) and writes the transformed shape to output 0 (y). The InferShape implementation of Reshape is as follows:
```cpp
ge::graphStatus InferShapeForReshape(InferShapeContext *context)
{
    const gert::Shape *x_shape = context->GetInputShape(0);        // Obtain the shape of input 0 (x).
    const gert::Tensor *shape_tensor = context->GetInputTensor(1); // Obtain the tensor of input 1 (shape).
    gert::Shape *output_shape = context->GetOutputShape(0);
    if (x_shape == nullptr || shape_tensor == nullptr || output_shape == nullptr) {
        // Defensive programming. In an undesirable scenario, an error is printed and a failure message is returned.
        return ge::GRAPH_FAILED;
    }
    auto reshape_size = static_cast<int32_t>(shape_tensor->GetShapeSize());
    if (reshape_size < 1) {
        // Defensive programming. In an undesirable scenario, an error is printed and a failure message is returned.
        return ge::GRAPH_FAILED;
    }
    // According to the prototype information, the shape input of Reshape supports INT32 and INT64.
    // Dispatch to the corresponding template function based on the type.
    if (shape_tensor->GetDataType() == ge::DT_INT32) {
        const int32_t *reshape_data = shape_tensor->GetData<int32_t>();
        return ReshapeInferShapeImpl<int32_t>(reshape_data, *x_shape, *output_shape, reshape_size);
    } else {
        const int64_t *reshape_data = shape_tensor->GetData<int64_t>();
        return ReshapeInferShapeImpl<int64_t>(reshape_data, *x_shape, *output_shape, reshape_size);
    }
}
```
InferShapeContext publicly inherits from ExtendedKernelContext. Therefore, the methods provided by ExtendedKernelContext, such as the APIs for obtaining the operator type, name, and attributes, can be called on an InferShapeContext instance.
- The InferShape function and the Follow API cannot be used together. If an InferShape function is provided, it takes precedence; ensure that it infers all output shapes.
- For efficiency, the framework does not initialize the output shapes before calling the InferShape function, so treat them as uninitialized inside it. If you build an output shape in append mode during InferShape, first reset its DimNum to prevent undefined behavior.
Obtaining Attributes and Inputs During InferShape
During InferShape and tiling, the IR attribute value of the operator can be obtained through the context instance. The IR attribute refers to the attribute defined during IR registration. The following uses the TransData operator as an example:
```cpp
namespace ops {
class TransData : public OpDef {
public:
    explicit TransData(const char *name) : OpDef(name)
    {
        this->Input("src")
            ...
        this->Output("dst")
            ...
        this->Attr("src_format")
            .AttrType(REQUIRED)
            .String();
        this->Attr("dst_format")
            .AttrType(REQUIRED)
            .String();
        this->Attr("group")
            .AttrType(OPTIONAL)
            .Int(1);
        ...
    }
};
OP_ADD(TransData);
}  // namespace ops
```
The src_format, dst_format, and group attributes are declared in the prototype definition. You can obtain the operator attributes as follows:
```cpp
ge::graphStatus ExampleGetTransDataAttr(TilingContext *context)
{
    // Obtain all attributes.
    const RuntimeAttrs *attrs = context->GetAttrs();
    ASSERT_NOT_NULL(attrs);
    // Use the index to obtain attributes based on their order in the prototype definition. The index starts from 0.
    const char *src_format = attrs->GetAttrPointer<char>(0);    // src_format is the first attribute, so the index is 0.
    const char *dst_format = attrs->GetAttrPointer<char>(1);    // dst_format is the second attribute, so the index is 1.
    const int64_t *group = attrs->GetAttrPointer<int64_t>(2);   // group is the third attribute, so the index is 2.
    return ge::GRAPH_SUCCESS;
}
```
Inputs and outputs are accessed by index rather than by name. For operators with OPTIONAL or DYNAMIC inputs, an input's index after instantiation may not match its index in the prototype. The following uses the DynamicRNNV3 operator as an example:
```cpp
namespace ops {
class DynamicRNNV3 : public OpDef {
public:
    explicit DynamicRNNV3(const char *name) : OpDef(name)
    {
        this->Input("x")
            .ParamType(REQUIRED)
            ...
        this->Input("w")
            .ParamType(REQUIRED)
            ...
        this->Input("b")
            .ParamType(REQUIRED)
            ...
        this->Input("seq_length")
            .ParamType(OPTIONAL)
            ...
        this->Input("init_h")
            .ParamType(OPTIONAL)
            ...
        this->Input("init_c")
            .ParamType(OPTIONAL)
            ...
        this->Input("wci")
            .ParamType(OPTIONAL)
            ...
        this->Input("wcf")
            .ParamType(OPTIONAL)
            ...
        this->Input("wco")
            .ParamType(OPTIONAL)
            ...
        this->Input("mask")
            .ParamType(OPTIONAL)
            ...
        this->Input("real_mask")
            .ParamType(OPTIONAL)
            ...
        this->Input("project")
            .ParamType(OPTIONAL)
            ...
        ...
    }
};
OP_ADD(DynamicRNNV3);
}  // namespace ops
```
The DynamicRNNV3 operator has multiple consecutive optional inputs, so the instantiated indexes of init_h and subsequent inputs are uncertain. For this type of operator, call GetOptionalInputShape with the index from the prototype to obtain data such as the input shape. The following uses InferShape as an example:
```cpp
ge::graphStatus InferShapeForDynamicRNNV3(InferShapeContext *context)
{
    // The first two inputs are not affected by optional or dynamic inputs, so their shapes can be obtained in the common way.
    auto x_shape = context->GetInputShape(0);
    auto w_shape = context->GetInputShape(1);
    if (x_shape == nullptr || w_shape == nullptr) {
        return ge::GRAPH_FAILED;
    }
    int64_t state_size = 0;
    // In the prototype definition, project is the 11th input (counted from 0).
    constexpr int64_t kProjectInputIndex = 11;
    // Because of the preceding optional inputs, the instantiated index of project is uncertain.
    // Use GetOptionalInputShape to obtain the corresponding input shape; its parameter is the index in the prototype.
    auto project_shape = context->GetOptionalInputShape(kProjectInputIndex);
    if (project_shape != nullptr) {
        if (project_shape->GetDimNum() < 2) {
            return ge::GRAPH_FAILED;
        }
        state_size = project_shape->GetDim(1);
    }
    // More InferShape logic...
    return ge::GRAPH_SUCCESS;
}
```
A dynamic input may be instantiated as one or more inputs. Such inputs can be obtained as follows:
```cpp
// ir_index: index of the input in the prototype definition, starting from 0.
// relative_index: relative index after instantiation, starting from 0. For example, if three
// DYNAMIC_INPUT instances are instantiated and the second one is needed, relative_index is 1.
auto shape = context->GetDynamicInputShape(ir_index, relative_index);
```
The methods of obtaining optional and dynamic inputs in this section can be invoked in the InferShape and Tiling functions.
Data Dependency
Generally, once the input shapes are available, an operator can derive its output shape in InferShape. However, some operators depend on a specific input value during InferShape. These operators are called data-dependent operators, and the corresponding input is referred to as a data-dependent input. Take the Reshape operator as an example: it adjusts the input shape based on the shape input, so it depends on the value of the shape input. This type of operator needs to declare the corresponding input as data-dependent through the ValueDepend API in the prototype definition.
```cpp
namespace ops {
class Reshape : public OpDef {
public:
    explicit Reshape(const char *name) : OpDef(name)
    {
        ...
        this->Input("shape")
            .ParamType(REQUIRED)
            ...
            .ValueDepend(REQUIRED)  // Declare the shape input of the Reshape operator as a data-dependent input.
        ...
    }
};
OP_ADD(Reshape);
}  // namespace ops
```
Based on the value of input 1 (shape), the Reshape operator transforms the shape of input 0 (x) and writes the transformed shape to output 0 (y). The InferShape implementation of Reshape is as follows:
```cpp
// Reshape implementation. According to the prototype information, the shape input of Reshape
// supports INT32 and INT64, so the implementation is a template over the element type.
template<typename T>
ge::graphStatus ReshapeInferShapeImpl(const T *reshape_dims, const gert::Shape &x_shape,
                                      gert::Shape &output_shape, int32_t reshape_rank)
{
    constexpr T UNKNOWN_DIM = -1;
    // Set the number of dimensions of the output to reshape_rank.
    output_shape.SetDimNum(reshape_rank);
    auto x_shape_size = x_shape.GetShapeSize();
    int64_t output_shapesize = 1;
    size_t unknown_dim_idx = std::numeric_limits<size_t>::max();
    for (int32_t i = 0; i < reshape_rank; i++) {
        if (reshape_dims[i] != UNKNOWN_DIM) {
            // The dimension value of this axis after reshaping is not -1.
            output_shape.SetDim(i, reshape_dims[i]);  // Set the output dimension to the reshaped value.
            output_shapesize *= reshape_dims[i];      // Compute the number of output elements.
        } else {
            // The dimension value of this axis after reshaping is -1. Temporarily set the output
            // dimension to 1 and record the index of the unknown dimension; subsequent computation
            // checks whether a definite value can be deduced.
            output_shape.SetDim(i, 1);
            unknown_dim_idx = i;
        }
    }
    if (unknown_dim_idx == std::numeric_limits<size_t>::max() && output_shapesize == x_shape_size) {
        // No unknown dimension, and the output element count matches that of input x.
        return ge::GRAPH_SUCCESS;
    } else if (unknown_dim_idx != std::numeric_limits<size_t>::max() && x_shape_size % output_shapesize == 0) {
        // There is an unknown dimension. Derive it from the input shape so the total element count is unchanged.
        output_shape.SetDim(unknown_dim_idx, x_shape_size / output_shapesize);
        return ge::GRAPH_SUCCESS;
    }
    return ge::GRAPH_FAILED;
}

ge::graphStatus InferShapeForReshape(InferShapeContext *context)
{
    const gert::Shape *x_shape = context->GetInputShape(0);        // Obtain the shape of input 0 (x).
    const gert::Tensor *shape_tensor = context->GetInputTensor(1); // Obtain the tensor of input 1 (shape).
    gert::Shape *output_shape = context->GetOutputShape(0);
    if (x_shape == nullptr || shape_tensor == nullptr || output_shape == nullptr) {
        // Defensive programming. In an undesirable scenario, an error is printed and a failure message is returned.
        return ge::GRAPH_FAILED;
    }
    auto reshape_size = static_cast<int32_t>(shape_tensor->GetShapeSize());
    if (reshape_size < 1) {
        // Defensive programming. In an undesirable scenario, an error is printed and a failure message is returned.
        return ge::GRAPH_FAILED;
    }
    // Dispatch to the template function based on the data type of the shape input.
    if (shape_tensor->GetDataType() == ge::DT_INT32) {
        const int32_t *reshape_data = shape_tensor->GetData<int32_t>();
        return ReshapeInferShapeImpl<int32_t>(reshape_data, *x_shape, *output_shape, reshape_size);
    } else {
        const int64_t *reshape_data = shape_tensor->GetData<int64_t>();
        return ReshapeInferShapeImpl<int64_t>(reshape_data, *x_shape, *output_shape, reshape_size);
    }
}
```
- The GetInputTensor API can obtain the corresponding tensor data during InferShape only after the input has been declared as data-dependent. If GetInputTensor is called without that declaration, the correct shape, format, and data type can still be read from the tensor, but the actual tensor data address cannot be obtained (the obtained address is nullptr).
- When obtaining tensor data from a tensor (GetData<int32_t> or GetData<int64_t>), ensure that the requested type matches the actual data type. Otherwise, the behavior is undefined.