Defining an Advanced Compute Operator (Conv2D)
The following takes the Conv2D operator as an example to describe how to define an advanced operator.
The prototype definition of the Conv2D operator is as follows.
REG_OP(Conv2D)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_BF16}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1})
    .ATTR(groups, Int, 1)
    .ATTR(data_format, String, "NHWC")
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(Conv2D)
The prototype definition shows that the Conv2D operator has two required inputs (x and filter), two optional inputs (bias and offset_w), one output (y), two required attributes (strides and pads), and four optional attributes with default values (dilations, groups, data_format, and offset_x). The following code defines a Conv2D operator:
auto conv2d = op::Conv2D("Conv2d")
    // quant, conv_weight, conv_bias are three input nodes.
    .set_input_x(quant)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
    .set_attr_strides({ 1, 1, 1, 1 })
    .set_attr_pads({ 0, 0, 0, 0 })
    .set_attr_dilations({ 1, 1, 1, 1 });

TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);
The major steps are as follows:
- Call an operator constructor to create an operator instance. For example, call the Conv2D(const char* name) constructor and pass it the operator name (Conv2d in the following example).
auto conv2d = op::Conv2D("Conv2d")
- Call set_input_inputName to set operator inputs.
.set_input_x(data)
.set_input_filter(conv_weight)
.set_input_bias(conv_bias)
data is an input node of the graph, constructed with the Data operator. For details, see Defining a Data Node (Data).
conv_weight is a constant node constructed with the Const operator. For details, see Defining a Data Node (Const).
conv_bias is a constant node constructed with the Const operator. For details, see Defining a Data Node (Const).
- Call set_attr_attributeName to set operator attributes.
.set_attr_strides({1, 1, 1, 1})     // Set the strides attribute values.
.set_attr_pads({0, 0, 0, 0})        // Set the pads attribute values.
.set_attr_dilations({1, 1, 1, 1});  // Set the dilations attribute values.
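- Call update_input_desc_inputName and update_output_desc_outputName to update the tensor descriptions of the operator's inputs and outputs. For convolution operators such as Conv2D, and for other operators that are sensitive to processing along the C axis, set the format in these descriptions to NCHW or NHWC as required.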
TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);
IR graph building does not support the following formats.
FORMAT_NC1HWC0
FORMAT_FRACTAL_Z
FORMAT_NC1C0HWPAD
FORMAT_NHWC1C0
FORMAT_FRACTAL_DECONV
FORMAT_C1HWNC0
FORMAT_FRACTAL_DECONV_TRANSPOSE
FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS
FORMAT_NC1HWC0_C04
FORMAT_FRACTAL_Z_C04
FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS
FORMAT_NC1KHKWHWC0
FORMAT_C1HWNCoC0
FORMAT_FRACTAL_ZZ
FORMAT_FRACTAL_NZ
FORMAT_NDC1HWC0
FORMAT_FRACTAL_Z_3D
FORMAT_FRACTAL_Z_3D_TRANSPOSE
FORMAT_FRACTAL_ZN_LSTM
FORMAT_FRACTAL_Z_G
FORMAT_ND_RNN_BIAS
FORMAT_FRACTAL_ZN_RNN
FORMAT_NYUV
FORMAT_NYUV_A