Prototype Definition (REG_OP)

Prototype

An example of the function prototype definition is as follows.

REG_OP(xxx)
    .INPUT(x1, type)
    .OPTIONAL_INPUT(x2, type)
    .DYNAMIC_INPUT(x3, type)
    .OUTPUT(y1, type)
    .DYNAMIC_OUTPUT(y3, type)
    .REQUIRED_ATTR(a, type)
    .ATTR(b, type, default_value)
    .GRAPH(z1)
    .DYNAMIC_GRAPH(z2)
    .OP_END_FACTORY_REG(xxx)

Description

Defines an operator prototype: the inputs, outputs, and attributes of an operator, together with their data types.

After an operator prototype is defined, it is registered with GE together with the operator's inputs, outputs, and attributes. In addition, the class op::xxx is defined. You can include the prototype header file and instantiate this class to build an IR model, as follows:

auto conv = op::Conv2D();
conv.set_input_x(feature_map_data);
conv.set_input_filter(weight_data);

For details about how to build a model, see the Ascend Graph Developer Guide.
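To illustrate what "class op::xxx is defined" amounts to, the following is a minimal hand-written mock, not the code that REG_OP actually generates. It assumes one setter per declared input, each returning a reference to the operator so that calls can be chained. The names mock_op, Node, input_source, and name are hypothetical and exist only for this sketch.

```cpp
#include <map>
#include <string>
#include <utility>

namespace mock_op {

// Hypothetical stand-in for an upstream operator; not a GE type.
struct Node {
    std::string name;
};

// Hand-written equivalent of a prototype that declares inputs "x" and "filter".
class Conv2D {
public:
    explicit Conv2D(std::string name) : name_(std::move(name)) {}

    // One setter per declared input; each returns *this to allow chaining.
    Conv2D &set_input_x(const Node &src) {
        inputs_["x"] = src.name;
        return *this;
    }
    Conv2D &set_input_filter(const Node &src) {
        inputs_["filter"] = src.name;
        return *this;
    }

    // Returns the name of the node wired to an input, or "" if it is unset.
    std::string input_source(const std::string &input) const {
        auto it = inputs_.find(input);
        return it == inputs_.end() ? std::string() : it->second;
    }
    const std::string &name() const { return name_; }

private:
    std::string name_;
    std::map<std::string, std::string> inputs_;
};

}  // namespace mock_op
```

With this mock, mock_op::Conv2D("conv").set_input_x(a).set_input_filter(b) wires both inputs in one expression.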

API Description

Each entry below gives, in order: the API name, its description, and its derivative APIs (for IR model building).

REG_OP(x)

Defines an operator prototype for operator type x.

Derivative APIs (for IR model building): REG_OP

.INPUT(x, type)

Defines an input (x) with its data type specified.

The data type is given as a TensorType object, for example:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

See Description of Class TensorType for more information.

Derivative APIs (for IR model building): INPUT
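The three TensorType forms listed above can be pictured as a set of accepted data types. The sketch below is a simplified mock, not the real TensorType class: the enum has only three members, ALL() covers only those three, and the accepts() method is a hypothetical helper added for illustration.

```cpp
#include <initializer_list>
#include <set>

namespace mock_type {

// Reduced stand-in for the data type enum; the real one has many more members.
enum DataType { DT_FLOAT, DT_INT8, DT_BOOL };

class TensorType {
public:
    // Supports both TensorType{DT_FLOAT} and TensorType({DT_FLOAT, DT_INT8}).
    TensorType(std::initializer_list<DataType> types) : types_(types) {}

    // In this mock, "all" means the three enum members defined above.
    static TensorType ALL() { return TensorType({DT_FLOAT, DT_INT8, DT_BOOL}); }

    // Hypothetical query: is the given data type in the accepted set?
    bool accepts(DataType dt) const { return types_.count(dt) > 0; }

private:
    std::set<DataType> types_;
};

}  // namespace mock_type
```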

.OPTIONAL_INPUT(x, type)

Defines an optional input (x) with its data type specified.

The data type is given as a TensorType object, for example:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

See Description of Class TensorType for more information.

Derivative APIs (for IR model building): OPTIONAL_INPUT

.DYNAMIC_INPUT(x, type)

Defines a dynamic input (x), that is, a group of inputs whose count is set when the operator is instantiated, with its data type specified.

The data type is given as a TensorType object, for example:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

See Description of Class TensorType for more information.

Derivative APIs (for IR model building): DYNAMIC_INPUT
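A dynamic input stands for several concrete inputs that share one base name. The sketch below assumes that the concrete inputs are named by appending a zero-based index to the base name (x0, x1, ...). This naming is an assumption of the sketch, and ExpandDynamicInput is a hypothetical helper, not a GE API.

```cpp
#include <string>
#include <vector>

namespace mock_dyn {

// Expands a dynamic input base name into `count` indexed names.
// Assumption: names follow the pattern base + index, starting at 0.
inline std::vector<std::string> ExpandDynamicInput(const std::string &base,
                                                   unsigned count) {
    std::vector<std::string> names;
    names.reserve(count);
    for (unsigned i = 0; i < count; ++i) {
        names.push_back(base + std::to_string(i));
    }
    return names;
}

}  // namespace mock_dyn
```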

.OUTPUT(x, type)

Defines an output (x) with its data type specified.

The data type is given as a TensorType object, for example:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

See Description of Class TensorType for more information.

Derivative APIs (for IR model building): OUTPUT

.DYNAMIC_OUTPUT(x, type)

Defines a dynamic output (x), that is, a group of outputs whose count is set when the operator is instantiated, with its data type specified.

The data type is given as a TensorType object, for example:
  • TensorType{DT_FLOAT}
  • TensorType({DT_FLOAT, DT_INT8})
  • TensorType::ALL()

See Description of Class TensorType for more information.

Derivative APIs (for IR model building): DYNAMIC_OUTPUT

.REQUIRED_ATTR(x, type)

Defines a required attribute (x) with its data type specified.

The data type is selected from:
  • Int: int64_t.
  • Float: float.
  • String: string.
  • Bool: bool.
  • Tensor: Tensor.
  • Type: Type enum.
  • NamedAttrs: NamedAttrs.
  • AscendString: AscendString.
  • ListInt: vector<int64_t>, an int64_t list.
  • ListFloat: vector<float>, a float list.
  • ListString: vector<string>, a string list.
  • ListBool: vector<bool>, a bool list.
  • ListTensor: vector<Tensor>, a Tensor list.
  • Bytes: Buffer.
  • ListType: vector<Type>, a Type list.
  • ListListInt: vector<vector<int64_t>>, a 2D list.
  • ListAscendString: vector<AscendString>, an AscendString list.
  • ListNamedAttrs: vector<NamedAttrs>, a NamedAttrs list.

Derivative APIs (for IR model building): REQUIRED_ATTR
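The plain C++ types in the list above can be written out as type aliases. The sketch below covers only the entries that map to standard C++ types. The GE-specific types (Tensor, NamedAttrs, AscendString, Buffer) are omitted, and the namespace mock_attr_types and the helper DefaultPad are hypothetical.

```cpp
#include <cstdint>
#include <string>
#include <vector>

namespace mock_attr_types {

// Scalar attribute types.
using Int = std::int64_t;
using Float = float;
using String = std::string;
using Bool = bool;

// List attribute types.
using ListInt = std::vector<std::int64_t>;
using ListFloat = std::vector<float>;
using ListString = std::vector<std::string>;
using ListBool = std::vector<bool>;
using ListListInt = std::vector<std::vector<std::int64_t>>;  // 2D list

// A ListInt value with four zero entries, such as a padding attribute might use.
inline ListInt DefaultPad() { return {0, 0, 0, 0}; }

}  // namespace mock_attr_types
```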

.ATTR(x, type, default_value)

Defines an optional attribute (x) with its data type and default value specified.

If the user does not set this attribute, the default value specified here is used.

The data type is selected from:
  • Int: int64_t.
  • Float: float.
  • String: string.
  • Bool: bool.
  • Tensor: Tensor.
  • Type: Type enum.
  • NamedAttrs: NamedAttrs.
  • AscendString: AscendString.
  • ListInt: vector<int64_t>, an int64_t list.
  • ListFloat: vector<float>, a float list.
  • ListString: vector<string>, a string list.
  • ListBool: vector<bool>, a bool list.
  • ListTensor: vector<Tensor>, a Tensor list.
  • Bytes: Buffer.
  • ListType: vector<Type>, a Type list.
  • ListListInt: vector<vector<int64_t>>, a 2D list.
  • ListAscendString: vector<AscendString>, an AscendString list.
  • ListNamedAttrs: vector<NamedAttrs>, a NamedAttrs list.

Examples:

  • .ATTR(mode, Int, 1)
  • .ATTR(pad, ListInt, {0, 0, 0, 0})

Derivative APIs (for IR model building): ATTR
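The default-value rule for optional attributes can be sketched as a lookup that falls back to the declared default. This is an illustrative mock limited to Int attributes. AttrStore and its methods are hypothetical names and do not reflect how GE stores attributes.

```cpp
#include <cstdint>
#include <map>
#include <string>

namespace mock_attr {

class AttrStore {
public:
    // Mirrors .ATTR(name, Int, default_value): records the default.
    void DeclareInt(const std::string &name, std::int64_t default_value) {
        defaults_[name] = default_value;
    }

    // Records a value set explicitly by the user.
    void SetInt(const std::string &name, std::int64_t value) {
        values_[name] = value;
    }

    // Returns the user value if one was set, otherwise the declared default.
    // Throws std::out_of_range if the attribute was never declared.
    std::int64_t GetInt(const std::string &name) const {
        auto it = values_.find(name);
        if (it != values_.end()) {
            return it->second;
        }
        return defaults_.at(name);
    }

private:
    std::map<std::string, std::int64_t> defaults_;
    std::map<std::string, std::int64_t> values_;
};

}  // namespace mock_attr
```

In this mock, declaring mode with a default of 1 and never calling SetInt means GetInt("mode") yields 1, which corresponds to .ATTR(mode, Int, 1).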

.GRAPH(z1)

Registers a subgraph of an operator. The argument z1 is the subgraph name.

For example, the following calls can be used to register the subgraphs of operator If:

.GRAPH(then_branch)
.GRAPH(else_branch)

Ensure that the subgraphs of an operator have unique names.

Derivative APIs (for IR model building): GRAPH

.DYNAMIC_GRAPH(z2)

Registers a dynamic subgraph of an operator, that is, a named group of subgraphs whose count is not fixed in the prototype. The argument z2 is the subgraph name.

For example, the following call can be used to register the branch subgraphs of operator Case:

.DYNAMIC_GRAPH(branches)

Ensure that the subgraphs of an operator have unique names.

Derivative APIs (for IR model building): DYNAMIC_GRAPH

.INFER_SHAPE_AND_TYPE()

Reserved for compatibility. It is not used in the current version.

Derivative APIs (for IR model building): none

.OP_END_FACTORY_REG(x)

Pairs with REG_OP to end the operator prototype definition.

Takes the same operator type (x) as REG_OP(x).

Derivative APIs (for IR model building): none

The OpReg &N() API in class OpReg enables the chained "." call style used during operator registration, for example, .INPUT(x, type) and .OUTPUT(x, type).
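The chaining idea can be sketched with a small recorder whose N() method returns the object itself, so that another "." call can follow. This is a simplified mock and not the actual OpReg or macro implementation. MockReg, MOCK_INPUT, MOCK_OUTPUT, and BuildGreaterEqual are hypothetical names.

```cpp
#include <string>
#include <vector>

namespace mock_reg {

// Hypothetical recorder; the real OpReg and registration macros differ.
class MockReg {
public:
    // Returns the same object, so another ".Call()" can follow.
    MockReg &N() { return *this; }

    MockReg &Input(const std::string &name) {
        inputs.push_back(name);
        return *this;
    }
    MockReg &Output(const std::string &name) {
        outputs.push_back(name);
        return *this;
    }

    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
};

// Illustrative macros: each turns a bare identifier into a chained call.
#define MOCK_INPUT(x) N().Input(#x)
#define MOCK_OUTPUT(x) N().Output(#x)

// Expands to MockReg().N().Input("x1").N().Input("x2").N().Output("y").
inline MockReg BuildGreaterEqual() {
    return MockReg()
        .MOCK_INPUT(x1)
        .MOCK_INPUT(x2)
        .MOCK_OUTPUT(y);
}

}  // namespace mock_reg
```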

Returns

None

Restrictions

  • The operator type passed to the REG_OP call must be globally unique.
  • The inputs of an operator must have unique names.
  • The outputs of an operator must have unique names.
  • The attributes of an operator must have unique names.
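The uniqueness restrictions above can be checked with set insertion. The sketch below is one possible way to express such checks and does not describe how GE enforces them. AllUnique and TypeRegistry are hypothetical names.

```cpp
#include <set>
#include <string>
#include <vector>

namespace mock_check {

// Returns true if every name in the list occurs exactly once.
inline bool AllUnique(const std::vector<std::string> &names) {
    std::set<std::string> seen;
    for (const auto &n : names) {
        if (!seen.insert(n).second) {
            return false;  // insertion fails when the name was already seen
        }
    }
    return true;
}

// Hypothetical registry that refuses a second registration of the same type.
class TypeRegistry {
public:
    // Returns true on first registration, false if the type already exists.
    bool Register(const std::string &op_type) {
        return types_.insert(op_type).second;
    }

private:
    std::set<std::string> types_;
};

}  // namespace mock_check
```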

Call Examples and Related APIs

The following is an example of defining a dynamic-input operator prototype.

REG_OP(AddN)
    .DYNAMIC_INPUT(x, TensorType({NumberType(), DT_VARIANT}))
    .OUTPUT(y, TensorType({NumberType(), DT_VARIANT}))
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(AddN)

The following is an example of defining a multi-input operator prototype.

REG_OP(GreaterEqual)
    .INPUT(x1, TensorType::RealNumberType())
    .INPUT(x2, TensorType::RealNumberType())
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(GreaterEqual)

The following is an example of defining an operator prototype that registers subgraphs.

REG_OP(If)
    .INPUT(cond, TensorType::ALL())
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(then_branch)
    .GRAPH(else_branch)
    .OP_END_FACTORY_REG(If)