昇腾社区首页
中文
注册

定义复杂计算算子(Conv2D)

下面以一个较复杂的Conv2D为例,介绍如何进行算子定义。

Conv2D算子原型定义:

REG_OP(Conv2D)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_BF16}))
    .REQUIRED_ATTR(strides, ListInt)
    .REQUIRED_ATTR(pads, ListInt)
    .ATTR(dilations, ListInt, {1, 1, 1, 1})
    .ATTR(groups, Int, 1)
    .ATTR(data_format, String, "NHWC")
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(Conv2D)

从Conv2D算子原型定义可以看到,Conv2D算子包括:两个必选输入x和filter,两个可选输入bias和offset_w,两个必选属性strides、pads,四个可选属性dilations、groups、data_format、offset_x。则Conv2D算子定义的代码为:

auto conv2d = op::Conv2D("Conv2d")
    .set_input_x(quant)
    .set_input_filter(conv_weight)
    .set_input_bias(conv_bias)
    .set_attr_strides({ 1, 1, 1, 1 })
    .set_attr_pads({ 0, 0, 0, 0 })
    .set_attr_dilations({ 1, 1, 1, 1 });

TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
conv2d.update_input_desc_x(conv2d_input_desc_x);
conv2d.update_input_desc_filter(conv2d_input_desc_filter);
conv2d.update_input_desc_bias(conv2d_input_desc_bias);
conv2d.update_output_desc_y(conv2d_output_desc_y);

主要过程为:

  1. 调用算子类型构造函数,例如“Conv2D(const char* name)”创建算子实例,并传算子名称(例如Conv2d)作为入参。
    auto conv2d1 = op::Conv2D("Conv2d")
  2. 调用“set_input_输入名称”接口设置算子的输入。
        .set_input_x(data)
        .set_input_filter(conv_weight)
        .set_input_bias(conv_bias)

    data为整个graph的输入节点,通过Data算子构造,具体请参考定义数据节点(Data)

    conv_weight为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)

    conv_bias为常量数据,通过Const算子构造,具体请参考定义数据节点(Const)

  3. 调用“set_attr_属性名称”接口设置算子的属性。
    .set_attr_strides({1, 1, 1, 1})       //设置strides属性值
    .set_attr_pads({0, 0, 0, 0})          //设置pads属性值
    .set_attr_dilations({1, 1, 1, 1});    //设置dilations属性值
  4. 对于Conv2D等卷积类或对C轴处理敏感的算子,建议通过“update_input_desc_输入名称”接口将Format信息设置为NCHW或者NHWC等,具体和用户需要处理的Format格式保持一致。
    TensorDesc conv2d_input_desc_x(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_filter(ge::Shape(), FORMAT_NCHW, DT_INT8);
    TensorDesc conv2d_input_desc_bias(ge::Shape(), FORMAT_NCHW, DT_INT32);
    TensorDesc conv2d_output_desc_y(ge::Shape(), FORMAT_NCHW, DT_INT32);
    conv2d.update_input_desc_x(conv2d_input_desc_x);
    conv2d.update_input_desc_filter(conv2d_input_desc_filter);
    conv2d.update_input_desc_bias(conv2d_input_desc_bias);
    conv2d.update_output_desc_y(conv2d_output_desc_y);

    IR构图不支持输入以下FORMAT:

    NC1HWC0
    FRACTAL_Z
    NC1C0HWPAD
    NHWC1C0
    FRACTAL_DECONV
    C1HWNC0
    FRACTAL_DECONV_TRANSPOSE
    FRACTAL_DECONV_SP_STRIDE_TRANS
    NC1HWC0_C04
    FRACTAL_Z_C04
    FRACTAL_DECONV_SP_STRIDE8_TRANS
    NC1KHKWHWC0
    C1HWNCoC0
    FRACTAL_ZZ
    FRACTAL_NZ
    NDC1HWC0
    FORMAT_FRACTAL_Z_3D
    FORMAT_FRACTAL_Z_3D_TRANSPOSE
    FORMAT_FRACTAL_ZN_LSTM
    FORMAT_FRACTAL_Z_G
    FORMAT_ND_RNN_BIAS
    FORMAT_FRACTAL_ZN_RNN
    FORMAT_NYUV
    FORMAT_NYUV_A