BROADCAST_INFER

Applicability

Product

Supported or Not

Atlas A3 training products/Atlas A3 inference products

Atlas A2 training products/Atlas A2 inference products

Atlas 200I/500 A2 inference products

Atlas inference products

Atlas training products

Header File

#include <graph/operator_reg.h>

Function Usage

A macro that wraps a common shape-inference routine to simplify writing InferShape functions. It sets the output shape from the shapes of two inputs. The macro sets only the shape; it does not set the output data type.

  • If the two inputs have the same shape, the output shape is set to the same shape.
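  • If the two inputs have different shapes, the output shape is derived from the two input shapes according to the broadcast rules: a dimension of size 1 is expanded to the corresponding dimension of the other input.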

    For example, if the two input shapes are (1, 2, 3, 4) and (3, 1, 3, 4), the macro sets the output shape of the operator to (3, 2, 3, 4).
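The broadcast rule described above can be sketched in plain C++ as follows. BroadcastShapes and BroadcastShapesExample are hypothetical helpers, not part of graph/operator_reg.h. The sketch assumes NumPy-style alignment from the trailing dimension (missing leading dimensions count as 1) and does not model dynamic or unknown dimensions, which the real implementation may handle differently. It requires C++17 for std::optional.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the CANN API): computes the broadcast
// shape of two inputs. Shapes are aligned at their trailing dimension and a
// missing leading dimension is treated as 1. Each aligned pair must be equal
// or contain a 1; otherwise std::nullopt is returned.
std::optional<std::vector<int64_t>> BroadcastShapes(const std::vector<int64_t> &a,
                                                    const std::vector<int64_t> &b) {
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> out(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    // Walk from the last dimension backwards.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da == db || db == 1) {
      out[rank - 1 - i] = da;  // equal dims, or b's dim of 1 is expanded
    } else if (da == 1) {
      out[rank - 1 - i] = db;  // a's dim of 1 is expanded
    } else {
      return std::nullopt;     // e.g. 2 vs 3: not broadcastable
    }
  }
  return out;
}

// Checks the two cases listed above.
void BroadcastShapesExample() {
  // Same shapes: the output shape equals the input shape.
  assert((BroadcastShapes({2, 3}, {2, 3}) == std::vector<int64_t>{2, 3}));
  // Different shapes: (1, 2, 3, 4) and (3, 1, 3, 4) broadcast to (3, 2, 3, 4).
  assert((BroadcastShapes({1, 2, 3, 4}, {3, 1, 3, 4}) == std::vector<int64_t>{3, 2, 3, 4}));
}
```

Returning std::optional lets a caller distinguish incompatible shapes from a valid result without a separate status flag.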

Prototype

BROADCAST_INFER(in1_name, in2_name, out_name)

The macro is implemented on top of the following function, to which it passes callbacks that read the two input shapes and write the output shape:

graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
                           const function<vector<int64_t>()> &get_in2_shape,
                           const function<void(const std::vector<int64_t> &y_shape)> &set_out_shape);
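Because BroadCastInfer receives the input shapes and the output setter as callbacks, the shape computation does not depend on how an operator stores its tensor descriptions. The stand-alone sketch below mirrors that signature using plain vectors instead of an Operator. BroadCastInferSketch, BroadCastInferSketchExample, GraphStatusSketch, kSketchSuccess, and kSketchFailed are hypothetical names; the status values are placeholders, not the real graphStatus codes. The sketch assumes NumPy-style trailing-dimension alignment, and the real function may differ, for example in how it handles dynamic dimensions or reports errors.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-ins for graphStatus values (assumption for this sketch).
using GraphStatusSketch = uint32_t;
constexpr GraphStatusSketch kSketchSuccess = 0;
constexpr GraphStatusSketch kSketchFailed = 1;

// Same callback shape as BroadCastInfer: read both input shapes through the
// getters, compute the broadcast shape, and pass it to the setter.
GraphStatusSketch BroadCastInferSketch(
    const std::function<std::vector<int64_t>()> &get_in1_shape,
    const std::function<std::vector<int64_t>()> &get_in2_shape,
    const std::function<void(const std::vector<int64_t> &y_shape)> &set_out_shape) {
  const std::vector<int64_t> a = get_in1_shape();
  const std::vector<int64_t> b = get_in2_shape();
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> y(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    // Align from the trailing dimension; missing leading dimensions count as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      return kSketchFailed;  // incompatible dimensions; the setter is not called
    }
    y[rank - 1 - i] = (da == 1) ? db : da;
  }
  set_out_shape(y);
  return kSketchSuccess;
}

// Usage: the callbacks here are lambdas over local vectors; in an InferShape
// function they would read and update the operator's tensor descriptions.
void BroadCastInferSketchExample() {
  std::vector<int64_t> x_shape{1, 2, 3, 4};
  std::vector<int64_t> y_shape{3, 1, 3, 4};
  std::vector<int64_t> z_shape;  // receives the inferred output shape
  GraphStatusSketch status = BroadCastInferSketch(
      [&]() { return x_shape; },
      [&]() { return y_shape; },
      [&](const std::vector<int64_t> &s) { z_shape = s; });
  assert(status == kSketchSuccess);
  assert((z_shape == std::vector<int64_t>{3, 2, 3, 4}));
}
```

Passing the setter as a callback means a failed inference leaves the output description unchanged, since the setter is only invoked after every dimension pair has been checked.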

Parameters

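Parameter | Input/Output | Description
in1_name | Input | Name of the operator's first input, as a string (for example, "x").
in2_name | Input | Name of the operator's second input, as a string (for example, "y").
out_name | Input | Name of the operator output whose shape is set, as a string (for example, "z").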

Returns

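BROADCAST_INFER evaluates to a callable that takes the operator (op). Invoking it returns a graphStatus indicating success (GRAPH_SUCCESS) or failure.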

Constraints

None

Examples

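IMPLEMT_INFERFUNC(RightShift, RightShiftInfer) {
  // BROADCAST_INFER sets only the shape, so set the output data type explicitly:
  // output "z" takes the data type of input "x".
  DataType type = op.GetInputDesc("x").GetDataType();
  SET_OUTPUT_TYPE(op, "z", type);
  // Infer the shape of "z" by broadcasting the shapes of inputs "x" and "y".
  // BROADCAST_INFER(...) yields a callable; invoking it with op returns the status.
  return BROADCAST_INFER("x", "y", "z")(op);
}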