Defining an Advanced Compute Operator (Conv2D)

This section uses the Conv2D operator as an example to describe how to define an advanced compute operator.

The prototype definition of the Conv2D operator is as follows:

REG_OP(Conv2D)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_BF16}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1})
    .ATTR(groups, Int, 1)
    .ATTR(data_format, String, "NHWC")
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(Conv2D)
The prototype definition shows that the Conv2D operator has two required inputs (x and filter), two optional inputs (bias and offset_w), one output (y), two required attributes (strides and pads), and four optional attributes with default values (dilations, groups, data_format, and offset_x). The following code defines a Conv2D operator instance:
auto conv2d = op::Conv2D("Conv2d")
    // quant, conv_weight, and conv_bias are input nodes created beforehand.
    .set_input_x(quant)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
    .set_attr_strides({1, 1, 1, 1})
    .set_attr_pads({0, 0, 0, 0})
    .set_attr_dilations({1, 1, 1, 1});

// Describe each tensor: default-constructed shape, NCHW format, and data type.
TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
// Attach the descriptions to the operator's inputs and output.
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);
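The chained calls above work because each set_input_* and set_attr_* call returns the operator object itself. To make that pattern, and the prototype's split between required members and defaulted attributes, concrete, here is a minimal, self-contained mock. MockConv2D, IsComplete, and BuildExampleConv2D are hypothetical names invented for illustration; this is not the class that REG_OP generates in GE, and inputs are represented by plain node-name strings rather than real graph nodes.

```cpp
// Hypothetical mock for illustration only; NOT the class generated by REG_OP.
// It mirrors the Conv2D prototype: x and filter are required inputs, strides
// and pads are required attributes, and the other attributes keep the
// prototype's default values unless set.
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

class MockConv2D {
 public:
  explicit MockConv2D(std::string name) : name_(std::move(name)) {}

  // Every setter returns *this, which is what allows the calls to be chained.
  MockConv2D& set_input_x(std::string node) { x_ = std::move(node); return *this; }
  MockConv2D& set_input_filter(std::string node) { filter_ = std::move(node); return *this; }
  MockConv2D& set_input_bias(std::string node) { bias_ = std::move(node); return *this; }
  MockConv2D& set_attr_strides(std::vector<int64_t> v) {
    assert(v.size() == 4);  // this example always uses 4-element lists
    strides_ = std::move(v);
    return *this;
  }
  MockConv2D& set_attr_pads(std::vector<int64_t> v) {
    assert(v.size() == 4);
    pads_ = std::move(v);
    return *this;
  }
  MockConv2D& set_attr_dilations(std::vector<int64_t> v) {
    assert(v.size() == 4);
    dilations_ = std::move(v);
    return *this;
  }

  // True once every required input and required attribute has been set.
  bool IsComplete() const { return x_ && filter_ && strides_ && pads_; }

  const std::string& name() const { return name_; }
  bool has_bias() const { return bias_.has_value(); }
  const std::vector<int64_t>& dilations() const { return dilations_; }
  int64_t groups() const { return groups_; }
  const std::string& data_format() const { return data_format_; }
  int64_t offset_x() const { return offset_x_; }

 private:
  std::string name_;
  std::optional<std::string> x_, filter_, bias_;        // inputs (node names)
  std::optional<std::vector<int64_t>> strides_, pads_;  // required attributes
  std::vector<int64_t> dilations_{1, 1, 1, 1};          // .ATTR defaults below
  int64_t groups_ = 1;
  std::string data_format_ = "NHWC";
  int64_t offset_x_ = 0;
};

// Mirrors the chained construction in the Conv2D definition code above.
inline MockConv2D BuildExampleConv2D() {
  return MockConv2D("Conv2d")
      .set_input_x("quant")
      .set_input_filter("conv_weight")
      .set_input_bias("conv_bias")
      .set_attr_strides({1, 1, 1, 1})
      .set_attr_pads({0, 0, 0, 0})
      .set_attr_dilations({1, 1, 1, 1});
}
```

BuildExampleConv2D().IsComplete() returns true, while an instance that only sets x returns false because filter, strides, and pads are still missing; groups, data_format, and offset_x keep the prototype defaults (1, "NHWC", 0) because the example never sets them.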

The major steps are as follows:

  1. Call the operator constructor to create an operator instance. For example, call the Conv2D(const char* name) constructor and pass it the operator name, Conv2d in this example.
    auto conv2d = op::Conv2D("Conv2d")
    
  2. Call set_input_inputName to set operator inputs.
        .set_input_x(data)
        .set_input_filter(conv_weight)
        .set_input_bias(conv_bias)
    

    data is the input node of the graph, constructed from the Data operator. For details, see Defining a Data Node (Data). In the complete example at the beginning of this section, the node quant is used as the x input instead.

    conv_weight and conv_bias are constant nodes constructed from the Const operator. For details, see Defining a Data Node (Const).

  3. Call set_attr_attributeName to set operator attributes.
    .set_attr_strides({1, 1, 1, 1})       // Set the strides attribute values.
    .set_attr_pads({0, 0, 0, 0})          // Set the pads attribute values.
    .set_attr_dilations({1, 1, 1, 1});    // Set the dilations attribute values.
    
  4. For convolution operators such as Conv2D, and for other operators that are sensitive to the C (channel) axis, it is recommended that you explicitly set the format to NCHW or NHWC, as required, in the update_input_desc_inputName and update_output_desc_outputName calls.
    TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
    TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
    conv2d.update_input_desc_x(conv2d_input_desc_x);
    conv2d.update_input_desc_filter(conv2d_input_desc_filter);
    conv2d.update_input_desc_bias(conv2d_input_desc_bias);
    conv2d.update_output_desc_y(conv2d_output_desc_y);
    

    IR graph building does not support the following formats.

    FORMAT_NC1HWC0
    FORMAT_FRACTAL_Z
    FORMAT_NC1C0HWPAD
    FORMAT_NHWC1C0
    FORMAT_FRACTAL_DECONV
    FORMAT_C1HWNC0
    FORMAT_FRACTAL_DECONV_TRANSPOSE
    FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS
    FORMAT_NC1HWC0_C04
    FORMAT_FRACTAL_Z_C04
    FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS
    FORMAT_NC1KHKWHWC0
    FORMAT_C1HWNCoC0
    FORMAT_FRACTAL_ZZ
    FORMAT_FRACTAL_NZ
    FORMAT_NDC1HWC0
    FORMAT_FRACTAL_Z_3D
    FORMAT_FRACTAL_Z_3D_TRANSPOSE
    FORMAT_FRACTAL_ZN_LSTM
    FORMAT_FRACTAL_Z_G
    FORMAT_ND_RNN_BIAS
    FORMAT_FRACTAL_ZN_RNN
    FORMAT_NYUV
    FORMAT_NYUV_A
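The attributes set in step 3 and the format chosen in step 4 both affect the output tensor: strides, pads, and dilations determine its spatial size, and the format decides which axis is C. The following sketch computes a Conv2D output shape for NCHW and NHWC under stated assumptions: pads are ordered {top, bottom, left, right}, strides and dilations follow the axis order of the format, output sizes use floor division, and the filter uses the same format as x with out_channels in the N position and in_channels / groups in the C position. AxesOf, ConvOutDim, and Conv2DOutputShape are hypothetical helpers, not GE functions; the operator's own shape inference in CANN remains authoritative.

```cpp
// Sketch only (not a GE function): Conv2D output-shape arithmetic for the
// NCHW and NHWC formats. Assumes pads = {top, bottom, left, right} and that
// strides/dilations follow the axis order of the tensor format.
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

using Shape4 = std::array<int64_t, 4>;

// Index of each logical axis inside a 4-D shape.
struct AxisIndex {
  int n, c, h, w;
};

inline AxisIndex AxesOf(const std::string& format) {
  if (format == "NCHW") return {0, 1, 2, 3};
  if (format == "NHWC") return {0, 3, 1, 2};
  throw std::invalid_argument("format not handled by this sketch: " + format);
}

// Output size along one spatial axis (floor division).
inline int64_t ConvOutDim(int64_t in, int64_t kernel, int64_t stride,
                          int64_t dilation, int64_t pad_before, int64_t pad_after) {
  assert(stride > 0 && dilation > 0 && kernel > 0);
  const int64_t effective_kernel = dilation * (kernel - 1) + 1;
  const int64_t padded = in + pad_before + pad_after;
  if (padded < effective_kernel) {
    throw std::invalid_argument("kernel is larger than the padded input");
  }
  return (padded - effective_kernel) / stride + 1;
}

// x, filter, and y share `format`; the filter's N position holds out_channels
// and its C position holds in_channels / groups (assumed layout).
inline Shape4 Conv2DOutputShape(const Shape4& x, const Shape4& filter,
                                const Shape4& strides, const Shape4& dilations,
                                const Shape4& pads, const std::string& format) {
  const AxisIndex a = AxesOf(format);
  Shape4 y{};
  y[a.n] = x[a.n];       // batch size is unchanged
  y[a.c] = filter[a.n];  // one output channel per filter
  y[a.h] = ConvOutDim(x[a.h], filter[a.h], strides[a.h], dilations[a.h], pads[0], pads[1]);
  y[a.w] = ConvOutDim(x[a.w], filter[a.w], strides[a.w], dilations[a.w], pads[2], pads[3]);
  return y;
}
```

For example, with the attribute values used in this section (unit strides and dilations, zero pads), an NCHW input of {1, 3, 32, 32} and a {16, 3, 3, 3} filter give {1, 16, 30, 30}; the same data laid out as NHWC, {1, 32, 32, 3}, gives {1, 30, 30, 16}, with C moved to the last axis.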
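When tensor descriptions are assembled programmatically, it can help to reject the formats listed above before building a TensorDesc. The helper below is a hypothetical sketch, not part of GE: it compares format names as strings because the numeric ge::Format values are not reproduced in this document, and a name that is absent from the list is not thereby guaranteed to be supported.

```cpp
// Hypothetical helper (not part of GE): checks a format name against the list
// of formats that IR graph building does not support. Names are compared as
// strings here; real code would work with ge::Format values instead.
#include <algorithm>
#include <array>
#include <cassert>
#include <string_view>

inline constexpr std::array<std::string_view, 24> kUnsupportedIrFormats = {
    "FORMAT_NC1HWC0", "FORMAT_FRACTAL_Z", "FORMAT_NC1C0HWPAD",
    "FORMAT_NHWC1C0", "FORMAT_FRACTAL_DECONV", "FORMAT_C1HWNC0",
    "FORMAT_FRACTAL_DECONV_TRANSPOSE", "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS",
    "FORMAT_NC1HWC0_C04", "FORMAT_FRACTAL_Z_C04",
    "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", "FORMAT_NC1KHKWHWC0",
    "FORMAT_C1HWNCoC0", "FORMAT_FRACTAL_ZZ", "FORMAT_FRACTAL_NZ",
    "FORMAT_NDC1HWC0", "FORMAT_FRACTAL_Z_3D", "FORMAT_FRACTAL_Z_3D_TRANSPOSE",
    "FORMAT_FRACTAL_ZN_LSTM", "FORMAT_FRACTAL_Z_G", "FORMAT_ND_RNN_BIAS",
    "FORMAT_FRACTAL_ZN_RNN", "FORMAT_NYUV", "FORMAT_NYUV_A"};

// True if the name appears in the unsupported list.
inline bool IsUnsupportedInIrBuild(std::string_view format) {
  return std::find(kUnsupportedIrFormats.begin(), kUnsupportedIrFormats.end(),
                   format) != kUnsupportedIrFormats.end();
}

// Debug-build guard that a graph-building helper could call before creating
// a TensorDesc with the given format.
inline void AssertNotUnsupportedInIrBuild(std::string_view format) {
  assert(!IsUnsupportedInIrBuild(format) &&
         "format is not supported in IR graph building");
}
```

Keeping the list in a single constant means it can be compared directly against the documentation when the set of unsupported formats changes between releases.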