Shape Derivation and Verification

Overview

The IR definition file (.cc) serves the following purposes:

  • Verifies operator parameters, which improves program robustness and makes faults easier to locate.
  • Infers the output tensor description of the operator from the input tensor description, operator logic, and operator attributes. The output tensor description includes the tensor shape, data type, and format. With this information, memory for all tensors can be allocated statically during the preparation for graph construction, which avoids the overhead of dynamic memory allocation.
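
To illustrate why a fully inferred shape and data type allow static allocation, the following minimal sketch computes a tensor's buffer size ahead of time. TensorBytes is a hypothetical helper written for this illustration and is not part of the operator API. It assumes that a negative dimension marks an unknown size.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the operator API): returns the buffer size
// in bytes for a tensor with the given dimensions and element size, or -1 if
// any dimension is unknown (negative), in which case the size cannot be
// decided before execution.
int64_t TensorBytes(const std::vector<int64_t>& dims, int64_t elem_size) {
  int64_t count = 1;  // A scalar (empty dims) holds one element.
  for (int64_t d : dims) {
    if (d < 0) {
      return -1;
    }
    count *= d;
  }
  return count * elem_size;
}
```

With a shape of (2, 3, 4) and a 4-byte element type, the sketch yields a fixed size that can be reserved before the graph runs. A shape containing an unknown dimension cannot be sized this way.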

When implementing the Verify and InferShape methods in op_proto/op_name.cc, you do not need to declare the methods.

Dependency Header Files

#include "op_name.h"
#include <vector>
#include <string>

Table 1 Description of header files

Header File | Directory | Description
------------|-----------|------------
op_name.h | IR header file implemented in Operator IR Registration | Once this header file is included, the op object of the Operator class registered in this file, or the op object of a subclass derived from the Operator class, can be called.
string | C++ standard library | Once this header file is included, string objects can be used and the APIs of class string can be called.
vector | C++ standard library | Once this header file is included, vector templates can be used and the APIs of class vector can be called.

Implementation of the InferShape Function

The following API is used to define InferShape in the operator IR:

IMPLEMT_COMMON_INFERFUNC(func_name)

This macro automatically generates an object of the Operator class named op. You can call the APIs of Operator directly on op to implement InferShape. The func_name argument is a user-defined function name.

  • The following is an example of assigning the input description to the output description.
    IMPLEMT_COMMON_INFERFUNC(SoftmaxInferShape)
    {
      TensorDesc tensordesc_output = op.GetOutputDescByName("y");
    
      tensordesc_output.SetShape(op.GetInputDescByName("x").GetShape());
      tensordesc_output.SetDataType(op.GetInputDescByName("x").GetDataType());
      tensordesc_output.SetFormat(op.GetInputDescByName("x").GetFormat());
    
      (void)op.UpdateOutputDesc("y", tensordesc_output);
      return GRAPH_SUCCESS;
    }
    
  • The output description is computed based on the operator logic. The following is an example.
    IMPLEMT_COMMON_INFERFUNC(NotEqualInferShape) {
      Shape x_shape = op.GetInputDescByName("x1").GetShape();
      Shape y_shape = op.GetInputDescByName("x2").GetShape();
      TensorDesc td = op.GetOutputDescByName("y");
      std::vector<int64_t> dims_x = x_shape.GetDims();
      std::vector<int64_t> dims_y = y_shape.GetDims();
      // Make dims_x the longer of the two dimension lists.
      if (dims_x.size() < dims_y.size()) {
        std::vector<int64_t> dims_tmp = dims_x;
        dims_x = dims_y;
        dims_y = dims_tmp;
      }
      // Left-pad the shorter list with 1s so that both lists have the same rank.
      if (dims_x.size() != dims_y.size()) {
        size_t dec = dims_x.size() - dims_y.size();
        for (size_t i = 0; i < dec; i++) {
          dims_y.insert(dims_y.begin(), (int64_t)1);
        }
      }
      std::vector<int64_t> dim_vec;
      for (size_t i = 0; i < dims_x.size(); i++) {
        // Broadcast rule: the two dimensions must be equal, or one of them must be 1.
        if ((dims_x[i] != dims_y[i]) && (dims_x[i] != 1) && (dims_y[i] != 1)) {
          printf("The %s op dimensions do not match the broadcast rule (%lld, %lld).\n", op.GetName().c_str(), (long long)dims_x[i], (long long)dims_y[i]);
          return GRAPH_FAILED;
        }
        int64_t dims = dims_x[i] > dims_y[i] ? dims_x[i] : dims_y[i];
        dim_vec.push_back(dims);
      }
      td.SetShape(ge::Shape(dim_vec));
      td.SetDataType(DT_BOOL);
      (void)op.UpdateOutputDesc("y", td);
      return GRAPH_SUCCESS;
    }
  • If the operator has dynamic inputs or outputs, iterate over them and process each one by index. An example is as follows.
    IMPLEMT_COMMON_INFERFUNC(BatchInfer)
    {
        for (size_t i = 0; i < op.GetInputsSize(); ++i) {
            Shape out_shapes;
            if (ReplaceDim(op.GetInputDesc(i).GetShape(), 
                    0, ge::UNKNOWN_DIM, out_shapes, op.GetName().c_str()) == GRAPH_FAILED) {
                return GRAPH_FAILED;
            }
            auto y_tensor_type = op.GetDynamicInputDesc("x_tensors", i).GetDataType();
            TensorDesc output_desc = op.GetDynamicOutputDesc("y_tensors", i);
            output_desc.SetShape(out_shapes);
            output_desc.SetDataType(y_tensor_type);
            op.UpdateDynamicOutputDesc("y_tensors", i, output_desc);
        }
    
        Shape scalar_shape;
        Scalar(scalar_shape);
        TensorDesc y_desc = op.GetOutputDesc("y_id");
        y_desc.SetShape(scalar_shape);
        y_desc.SetDataType(DT_INT64);
        op.UpdateOutputDesc("y_id", y_desc);
    
        std::vector<int64_t> dims = { ge::UNKNOWN_DIM, 3 };
        TensorDesc output_desc_batch_index = op.GetOutputDesc("y_index");
        output_desc_batch_index.SetShape(Shape(dims));
        output_desc_batch_index.SetDataType(DT_INT64);
        op.UpdateOutputDesc("y_index", output_desc_batch_index);
        return GRAPH_SUCCESS;
    }
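    // ReplaceDim must be declared before BatchInfer. It is defined as follows.
    graphStatus ReplaceDim(const Shape& s, 
                           int64_t dim_index_in, 
                           int64_t new_dim, 
                           Shape& out, 
                           const char* op_name) {
      // If the rank of the input shape is unknown, the output shape is unknown too.
      if (s.GetDims() == ge::UNKNOWN_RANK) {
        out = Shape(ge::UNKNOWN_SHAPE);
        return GRAPH_SUCCESS;
      }
      // A negative index counts from the last dimension.
      int64_t dim_index = dim_index_in;
      if (dim_index < 0) {
        dim_index = (int64_t)s.GetDimNum() + dim_index;
      }
      if (dim_index < 0 || dim_index >= (int64_t)s.GetDimNum()) {
        printf("The %s op dimension index %lld is out of range.\n", op_name, (long long)dim_index_in);
        return GRAPH_FAILED;
      }
      std::vector<int64_t> dims = s.GetDims();
      dims[dim_index] = new_dim;
      out = Shape(dims);
      return GRAPH_SUCCESS;
    }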
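
The broadcast rule applied in the NotEqualInferShape example above can be tried in isolation. The following is a minimal standalone sketch that works on plain dimension vectors instead of Shape and TensorDesc objects. BroadcastDims is a hypothetical helper introduced only for this illustration and is not part of the operator API.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper (not part of the operator API): computes the broadcast
// result of two dimension lists. Returns false if the lists are incompatible.
bool BroadcastDims(std::vector<int64_t> a, std::vector<int64_t> b,
                   std::vector<int64_t>& out) {
  // Make `a` the longer list, then left-pad `b` with 1s to the same rank.
  if (a.size() < b.size()) {
    std::swap(a, b);
  }
  b.insert(b.begin(), a.size() - b.size(), static_cast<int64_t>(1));

  out.clear();
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Two dimensions are compatible if they are equal or one of them is 1.
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) {
      return false;
    }
    out.push_back(std::max(a[i], b[i]));
  }
  return true;
}
```

For example, dimension lists (2, 3, 4) and (3, 1) broadcast to (2, 3, 4), while (2, 3) and (4, 3) are rejected because 2 and 4 differ and neither is 1.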

Implementation of the Verify Function

The following API is used to implement the Verify function of an operator:

IMPLEMT_VERIFIER(OpType, func_name)

The input OpType is a subclass derived from the class Operator. An op object of this subclass is automatically generated. You can use the member functions of the subclass to obtain the operator attributes. For details about the member functions of the op object, see Operator.

  • OpType: type of the custom operator.
  • func_name: user-defined name of the Verify function.

The Verify function checks that the parameters of an operator are consistent with one another. For example, if a multi-input operator requires all input tensors to have the same data type, the Verify function must check the data types of the inputs. Operators without such a constraint do not need this check.

An implementation example is as follows.

IMPLEMT_VERIFIER(Pow, PowVerify) {
  DataType input_type_x = op.GetInputDescByName("x").GetDataType();
  DataType input_type_y = op.GetInputDescByName("y").GetDataType();
  if (input_type_x != input_type_y) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

Registration of the InferShape and Verify Functions

Call the InferShape registration macro and the Verify registration macro to register the InferShape and Verify functions, as follows:

COMMON_INFER_FUNC_REG(OpType, func_name);  
VERIFY_FUNC_REG(OpType, func_name);

In each macro, func_name must match the function name used in Implementation of the InferShape Function or Implementation of the Verify Function, respectively.
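
The implementation of the registration macros is not shown in this document. As a rough illustration of the general idea only, the following standalone sketch shows a common C++ registration idiom in which a macro creates a static object that records a function under an operator type name. All names here (Status, InferRegistry, InferRegistrar, DEMO_INFER_FUNC_REG, DemoSoftmaxInfer) are hypothetical and do not correspond to the framework's actual types or macros.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-ins; the real framework types and macros differ.
using Status = int;
constexpr Status kSuccess = 0;
using InferFunc = std::function<Status()>;

// A function-local static avoids static-initialization-order problems.
std::map<std::string, InferFunc>& InferRegistry() {
  static std::map<std::string, InferFunc> registry;
  return registry;
}

// Constructing a registrar at namespace scope records the function
// before main() runs.
struct InferRegistrar {
  InferRegistrar(const std::string& op_type, InferFunc func) {
    InferRegistry()[op_type] = std::move(func);
  }
};

// Hypothetical macro: binds a function to an operator type name.
#define DEMO_INFER_FUNC_REG(op_type, func_name) \
  static InferRegistrar g_##op_type##_infer_registrar(#op_type, func_name)

// The function name passed to the macro must match the implemented function.
Status DemoSoftmaxInfer() { return kSuccess; }
DEMO_INFER_FUNC_REG(Softmax, DemoSoftmaxInfer);
```

In this sketch, a lookup by operator type finds the registered function, which is why the name passed to the registration macro has to be the same as the name of the implemented function.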