Prototype Definition (REG_OP)

Applicability

Product

Supported or Not

Atlas A3 training products / Atlas A3 inference products

Atlas A2 training products / Atlas A2 inference products

Atlas 200I/500 A2 inference products

Atlas inference products

Atlas training products

Header File

#include <graph/operator_reg.h>

Function Usage

Defines the operator prototype, including inputs, outputs, attributes, and the corresponding data types.

Defining an operator prototype registers it with Graph Engine (GE), together with the operator's inputs, outputs, and attributes. It also defines the class op::xxx. To build an IR model, include the header file that contains the prototype and instantiate the class, for example:

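auto conv = op::Conv2D("conv");      // instance of the class generated for the Conv2D prototype
conv.set_input_x(feature_map_data);  // feature_map_data: an upstream operator defined elsewhere
conv.set_input_filter(weight_data);  // weight_data: an upstream operator defined elsewhere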

For details about how to build a model, see the Graph Mode Development Guide.

Prototype

The general form of an operator prototype definition is as follows:

REG_OP(xxx)
    .INPUT(x1, type)
    .OPTIONAL_INPUT(x2, type)
    .DYNAMIC_INPUT(x3, type)
    .OUTPUT(y1, type)
    .DYNAMIC_OUTPUT(y3, type)
    .REQUIRED_ATTR(a, type)
    .ATTR(b, type, default_value)
    .GRAPH(z1)
    .DYNAMIC_GRAPH(z2)
    .OP_END_FACTORY_REG(xxx)

Description

Each entry below gives the API, its description, and the related derivative API (for IR model build), if any.

REG_OP(xxx)

Defines an operator prototype for operator type xxx.

Derivative API (for IR model build): REG_OP

.INPUT(x, type)

Defines an input (x) and its data type (type).

The data type is of class TensorType. Examples:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

For details about the TensorType class, see TensorType.

Derivative API (for IR model build): INPUT
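Conceptually, a TensorType value is the set of data types that a port accepts. The following self-contained C++ sketch models only that idea. `DataTypeSketch`, `TensorTypeSketch`, and `Accepts` are hypothetical names and are not part of graph/operator_reg.h. Real prototypes use the TensorType class itself.

```cpp
#include <cassert>
#include <initializer_list>
#include <set>

// Hypothetical stand-in for a few data types (not the real GE DataType enum).
enum class DataTypeSketch { DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32, DT_BOOL };

// Hypothetical model of a TensorType: the set of data types a port accepts.
class TensorTypeSketch {
 public:
  TensorTypeSketch(std::initializer_list<DataTypeSketch> types) : types_(types) {}

  // Returns true if a tensor of data type dt may be connected to the port.
  bool Accepts(DataTypeSketch dt) const { return types_.count(dt) != 0; }

 private:
  std::set<DataTypeSketch> types_;
};
```

In this model, TensorType{DT_FLOAT} corresponds to a one-element set, and TensorType::ALL() corresponds to a set that contains every data type.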

.OPTIONAL_INPUT(x, type)

Defines an optional input (x) and its data type (type).

The data type is of class TensorType. Examples:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

For details about the TensorType class, see TensorType.

Derivative API (for IR model build): OPTIONAL_INPUT

.DYNAMIC_INPUT(x, type)

Defines a dynamic input (x), that is, an input that accepts a variable number of tensors, and its data type (type).

The data type is of class TensorType. Examples:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

For details about the TensorType class, see TensorType.

Derivative API (for IR model build): DYNAMIC_INPUT
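A dynamic input declares one name for a variable number of tensors. The actual count is fixed when the IR model is built. The sketch below expands one dynamic port into concrete port names. It assumes an indexed naming convention: the declared name followed by a 0-based index, such as x0, x1, and so on. Both this convention and the helper `ExpandDynamicPort` are assumptions for illustration. Check the methods generated for DYNAMIC_INPUT in your CANN version.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: expands a dynamic port declared as `name` into `count`
// concrete port names, assuming a name + 0-based index convention.
std::vector<std::string> ExpandDynamicPort(const std::string &name, std::size_t count) {
  std::vector<std::string> ports;
  ports.reserve(count);
  for (std::size_t i = 0; i < count; ++i) {
    ports.push_back(name + std::to_string(i));
  }
  return ports;
}
```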

.OUTPUT(x, type)

Defines an output (x) and its data type (type).

The data type is of class TensorType. Examples:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

For details about the TensorType class, see TensorType.

Derivative API (for IR model build): OUTPUT

.DYNAMIC_OUTPUT(x, type)

Defines a dynamic output (x), that is, an output that produces a variable number of tensors, and its data type (type).

The data type is of class TensorType. Examples:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

For details about the TensorType class, see TensorType.

Derivative API (for IR model build): DYNAMIC_OUTPUT

.REQUIRED_ATTR(x, type)

Defines a required attribute (x) and its data type (type).

The supported values of type and the corresponding C++ value types are as follows:
  • Int: int64_t.
  • Float: float.
  • String: string.
  • Bool: bool.
  • Tensor: Tensor.
  • Type: Type enum.
  • NamedAttrs: NamedAttrs.
  • AscendString: AscendString.
  • ListInt: vector<int64_t>, indicating an int64_t list.
  • ListFloat: vector<float>, indicating a float list.
  • ListString: vector<string>, indicating a string list.
  • ListBool: vector<bool>, indicating a bool list.
  • ListTensor: vector<Tensor>, indicating a tensor list.
  • Bytes: Buffer.
  • ListType: vector<Type>, indicating a type list.
  • ListListInt: vector<vector<int64_t>>, indicating a two-dimensional int64_t list.
  • ListAscendString: vector<AscendString>, indicating an AscendString list.
  • ListNamedAttrs: vector<NamedAttrs>, indicating a NamedAttrs list.

Derivative API (for IR model build): REQUIRED_ATTR
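The mapping from attribute type names to C++ value types in the list above can be written as a compile-time trait. The sketch below covers only a subset of the list. `AttrKind` and `AttrValueType` are hypothetical names. Tensor, Type, NamedAttrs, AscendString, and Bytes (Buffer) are omitted because they are GE classes that are not reproduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical tags for a subset of the type names accepted by REQUIRED_ATTR and ATTR.
enum class AttrKind { Int, Float, String, Bool, ListInt, ListFloat, ListString, ListBool, ListListInt };

// Hypothetical trait mapping each tag to the C++ value type listed in the table.
template <AttrKind K>
struct AttrValueType;

template <> struct AttrValueType<AttrKind::Int> { using type = std::int64_t; };
template <> struct AttrValueType<AttrKind::Float> { using type = float; };
template <> struct AttrValueType<AttrKind::String> { using type = std::string; };
template <> struct AttrValueType<AttrKind::Bool> { using type = bool; };
template <> struct AttrValueType<AttrKind::ListInt> { using type = std::vector<std::int64_t>; };
template <> struct AttrValueType<AttrKind::ListFloat> { using type = std::vector<float>; };
template <> struct AttrValueType<AttrKind::ListString> { using type = std::vector<std::string>; };
template <> struct AttrValueType<AttrKind::ListBool> { using type = std::vector<bool>; };
template <> struct AttrValueType<AttrKind::ListListInt> {
  using type = std::vector<std::vector<std::int64_t>>;
};

template <AttrKind K>
using AttrValueTypeT = typename AttrValueType<K>::type;

// Compile-time check that ListListInt is a list of int64_t lists.
static_assert(std::is_same<AttrValueTypeT<AttrKind::ListListInt>,
                           std::vector<std::vector<std::int64_t>>>::value,
              "ListListInt maps to vector<vector<int64_t>>");
```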

.ATTR(x, type, default_value)

Defines the name, type, and default value of an optional attribute.

The default value specified here is used if you do not set this attribute.

The supported values of type and the corresponding C++ value types are as follows:
  • Int: int64_t.
  • Float: float.
  • String: string.
  • Bool: bool.
  • Tensor: Tensor.
  • Type: Type enum.
  • NamedAttrs: NamedAttrs.
  • AscendString: AscendString.
  • ListInt: vector<int64_t>, indicating an int64_t list.
  • ListFloat: vector<float>, indicating a float list.
  • ListString: vector<string>, indicating a string list.
  • ListBool: vector<bool>, indicating a bool list.
  • ListTensor: vector<Tensor>, indicating a tensor list.
  • Bytes: Buffer.
  • ListType: vector<Type>, indicating a type list.
  • ListListInt: vector<vector<int64_t>>, indicating a two-dimensional int64_t list.
  • ListAscendString: vector<AscendString>, indicating an AscendString list.
  • ListNamedAttrs: vector<NamedAttrs>, indicating a NamedAttrs list.

Examples:

  • .ATTR(mode, Int, 1)
  • .ATTR(pad, ListInt, {0, 0, 0, 0})

Derivative API (for IR model build): ATTR
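The default-value rule for optional attributes can be modeled with std::optional: a value set by the user takes precedence, and otherwise the registered default is used. `OptionalAttr` is a hypothetical illustration of that rule, not the GE implementation. It requires C++17.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical model of an attribute declared with .ATTR(name, type, default_value).
template <typename T>
class OptionalAttr {
 public:
  explicit OptionalAttr(T default_value) : default_(std::move(default_value)) {}

  // Conceptually corresponds to setting the attribute on an operator instance.
  void Set(T value) { value_ = std::move(value); }

  // Returns the user-set value, or the registered default if none was set.
  const T &Get() const { return value_ ? *value_ : default_; }

 private:
  T default_;
  std::optional<T> value_;
};
```

For example, OptionalAttr<std::int64_t> mode(1) mirrors .ATTR(mode, Int, 1). Get() returns 1 until Set() is called.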

.GRAPH(z1)

Registers a subgraph of an operator. The argument z1 is the subgraph name.

For example, the following calls register the two subgraphs of the If operator:

    .GRAPH(then_branch)
    .GRAPH(else_branch)

Ensure that each subgraph of an operator has a unique name.

Derivative API (for IR model build): GRAPH

.DYNAMIC_GRAPH(z2)

Registers a dynamic subgraph of an operator, that is, a group containing a variable number of subgraphs. The argument z2 is the name of the group.

For example, the following call registers the branch subgraphs of the Case operator:

.DYNAMIC_GRAPH(branches)

Ensure that each subgraph of an operator has a unique name.

Derivative API (for IR model build): DYNAMIC_GRAPH

.INFER_SHAPE_AND_TYPE()

Reserved for compatibility considerations and not used in the current version.

Derivative API (for IR model build): none

.OP_END_FACTORY_REG(x)

Used together with REG_OP to end the operator prototype definition.

The operator type (x) must be the same as that in REG_OP(x).

Derivative API (for IR model build): none

Class OpReg provides the OpReg &N() method. This method allows registration calls such as .INPUT(x, type) and .OUTPUT(x, type) to be written in the chained .XXX form during operator registration.
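The chained .XXX style follows the fluent-builder pattern. Every call returns a reference to the same object, so the next call can follow it directly. The toy builder below illustrates this pattern. It also mirrors the pairing rule: its End check passes only when the closing name matches the opening one, as OP_END_FACTORY_REG(x) must repeat the type passed to REG_OP(x). `ProtoBuilder` is hypothetical. The real macros in graph/operator_reg.h expand differently and should be used as documented rather than reimplemented.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of a prototype builder; not the GE macro expansion.
class ProtoBuilder {
 public:
  explicit ProtoBuilder(std::string op_type) : op_type_(std::move(op_type)) {}

  // Each call records a declaration and returns *this so calls can be chained.
  ProtoBuilder &Input(const std::string &name) {
    inputs_.push_back(name);
    return *this;
  }
  ProtoBuilder &Output(const std::string &name) {
    outputs_.push_back(name);
    return *this;
  }

  // Mirrors the rule that the closing type must equal the type given at the start.
  bool End(const std::string &op_type) const { return op_type == op_type_; }

  const std::vector<std::string> &inputs() const { return inputs_; }
  const std::vector<std::string> &outputs() const { return outputs_; }

 private:
  std::string op_type_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};
```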

Returns

None

Constraints

  • The operator type passed to the REG_OP call must be globally unique.
  • The inputs of an operator must have unique names.
  • The outputs of an operator must have unique names.
  • The attributes of an operator must have unique names.
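These constraints can be restated as a validation pass over a simplified prototype description. The sketch below also checks subgraph names, following the uniqueness note under GRAPH and DYNAMIC_GRAPH. `ProtoSketch`, `AllUnique`, and `ValidatePrototypes` are hypothetical names. GE performs its own checks, and this code only expresses the listed rules. Uniqueness is checked within each category only, as the rules state.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified description of one REG_OP prototype.
struct ProtoSketch {
  std::string op_type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  std::vector<std::string> attrs;
  std::vector<std::string> graphs;
};

// Returns true if every name in the list appears only once.
bool AllUnique(const std::vector<std::string> &names) {
  return std::set<std::string>(names.begin(), names.end()).size() == names.size();
}

// Checks globally unique operator types, and unique input, output,
// attribute, and subgraph names within each operator.
bool ValidatePrototypes(const std::vector<ProtoSketch> &protos) {
  std::set<std::string> seen_types;
  for (const auto &p : protos) {
    if (!seen_types.insert(p.op_type).second) return false;  // duplicate operator type
    if (!AllUnique(p.inputs) || !AllUnique(p.outputs) || !AllUnique(p.attrs) ||
        !AllUnique(p.graphs)) {
      return false;  // duplicate name within one category
    }
  }
  return true;
}
```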

Call Examples and Related APIs

The following is an example of defining a dynamic-input operator prototype.

REG_OP(AddN)
    .DYNAMIC_INPUT(x, TensorType({NumberType(), DT_VARIANT}))
    .OUTPUT(y, TensorType({NumberType(), DT_VARIANT}))
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(AddN)

The following is an example of defining a multi-input operator prototype.

REG_OP(GreaterEqual)
    .INPUT(x1, TensorType::RealNumberType())
    .INPUT(x2, TensorType::RealNumberType())
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(GreaterEqual)

The following is an example of defining the prototype of an operator that registers subgraphs.

REG_OP(If)
    .INPUT(cond, TensorType::ALL())
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(then_branch)
    .GRAPH(else_branch)
    .OP_END_FACTORY_REG(If)
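Under the usual If semantics, the operator evaluates cond and runs exactly one of the two registered subgraphs, then_branch or else_branch. It passes the dynamic inputs to that branch and returns the branch's results as the dynamic outputs. The sketch below models only this control flow. `TensorSketch`, `Branch`, and `RunIf` are hypothetical names. Tensors are reduced to vectors of float and cond to a bool, so this is not how GE executes the subgraphs.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical simplifications: a tensor is a vector of floats, and a
// subgraph is a function from the input list to the output list.
using TensorSketch = std::vector<float>;
using Branch = std::function<std::vector<TensorSketch>(const std::vector<TensorSketch> &)>;

// Runs exactly one branch, chosen by cond, on the dynamic inputs.
std::vector<TensorSketch> RunIf(bool cond, const std::vector<TensorSketch> &input,
                                const Branch &then_branch, const Branch &else_branch) {
  return cond ? then_branch(input) : else_branch(input);
}
```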