BROADCAST_INFER

Description

Encapsulates a common shape-inference routine to simplify writing InferShape functions. The macro sets the output shape from the shapes of two inputs. It sets the shape only; it does not set the output data type (dtype).

  • If the two inputs have the same shape, the output shape is set to the same shape.
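  • If the two inputs have different shapes, the output shape is derived from both input shapes according to the broadcast rules: dimensions are aligned starting from the last (rightmost) dimension, each pair of aligned dimensions must be equal or one of them must be 1, and the output takes the larger value of each pair.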

    For example, if the two input shapes are (1, 2, 3, 4) and (3, 1, 3, 4), the macro sets the output shape of the operator to (3, 2, 3, 4).
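The broadcast rule can be illustrated with a small standalone sketch. The helper name BroadcastShape is hypothetical and is not part of the CANN API; the sketch assumes NumPy-style broadcasting and does not handle dynamic (unknown, -1) dimensions.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not a CANN API): computes the broadcast output shape
// of two input shapes. Trailing dimensions are aligned; each aligned pair
// must be equal or one of them must be 1. Dynamic (-1) dims are not handled.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t> &a,
                                    const std::vector<int64_t> &b) {
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> out(rank, 1);
  for (std::size_t i = 0; i < rank; ++i) {
    // Walk from the last dimension backwards; missing leading dims count as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::invalid_argument("shapes cannot be broadcast");
    }
    // Take the non-1 dimension (or the shared value when both are equal).
    out[rank - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
```

With the shapes from the example above, BroadcastShape({1, 2, 3, 4}, {3, 1, 3, 4}) yields {3, 2, 3, 4}.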

Prototype

BROADCAST_INFER(in1_name, in2_name, out_name)

Calling this macro internally invokes the following function:

graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
                           const function<vector<int64_t>()> &get_in2_shape,
                           const function<void(const std::vector<int64_t> &y_shape)> &set_out_shape);
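The callback-based signature above can be sketched as follows: two getters supply the input shapes and a setter receives the inferred output shape. This is not the actual implementation. BroadCastInferSketch and SketchStatus are hypothetical stand-ins for BroadCastInfer and graphStatus, and the behavior on incompatible shapes (return a failure status without calling the setter) is an assumption made for illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-in for graphStatus; the real type and its values are
// defined in the GE headers and are not reproduced here.
enum class SketchStatus { kSuccess, kFailed };

// Hypothetical sketch: read both input shapes through the getters, compute
// the broadcast shape, and pass it to the setter on success.
SketchStatus BroadCastInferSketch(
    const std::function<std::vector<int64_t>()> &get_in1_shape,
    const std::function<std::vector<int64_t>()> &get_in2_shape,
    const std::function<void(const std::vector<int64_t> &y_shape)> &set_out_shape) {
  const std::vector<int64_t> a = get_in1_shape();
  const std::vector<int64_t> b = get_in2_shape();
  const std::size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> y(rank, 1);
  for (std::size_t i = 0; i < rank; ++i) {
    // Align from the rightmost dimension; missing leading dims count as 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      return SketchStatus::kFailed;  // incompatible: output shape is not set
    }
    y[rank - 1 - i] = (da == 1) ? db : da;
  }
  set_out_shape(y);
  return SketchStatus::kSuccess;
}
```

In the real macro, the getters and setter presumably wrap the operator's input and output descriptors; the sketch uses plain lambdas so it can run without the GE headers.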

Restrictions

None

Parameters

Parameter    Input/Output    Description
in1_name     Input           Name of the first input of the operator.
in2_name     Input           Name of the second input of the operator.
out_name     Input           Name of the operator output whose shape is set.

Returns

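A graphStatus value indicating success or failure. The macro expands to a callable that is invoked with the operator, as in BROADCAST_INFER("x", "y", "z")(op).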

Example

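IMPLEMT_INFERFUNC(RightShift, RightShiftInfer) {
  // BROADCAST_INFER does not set the dtype, so propagate it from input "x" first.
  DataType type = op.GetInputDesc("x").GetDataType();
  SET_OUTPUT_TYPE(op, "z", type);
  // Set the shape of output "z" by broadcasting the shapes of inputs "x" and "y".
  return BROADCAST_INFER("x", "y", "z")(op);
}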