broadcast_shapes

Description

Broadcasts shape1 and shape2 against each other and returns three shapes: shape1 and shape2 padded to a common rank, followed by the resulting broadcast shape.

Each dimension of the broadcast shape is the larger of the corresponding dimensions of shape1 and shape2.
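The per-dimension rule can be illustrated with a small pure-Python sketch. `broadcast_dims` is a hypothetical helper, not part of `shape_util`, and it assumes both shapes already have the same rank and are compatible.

```python
def broadcast_dims(shape1, shape2):
    """Return the per-dimension maximum of two equal-rank, compatible shapes.

    Hypothetical helper: assumes rank padding and the compatibility check
    have already been done.
    """
    if len(shape1) != len(shape2):
        raise ValueError("shapes must have the same rank")
    # In a compatible pair the two sizes are equal or one of them is 1,
    # so the larger size is the broadcast size.
    return tuple(max(d1, d2) for d1, d2 in zip(shape1, shape2))


print(broadcast_dims((3, 4, 1), (1, 1, 5)))  # (3, 4, 5)
```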

If shape1 and shape2 have different ranks, size-1 dimensions are prepended to the shorter shape until both have the same rank, before the broadcast shape is computed.
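Rank padding amounts to prepending 1s. The sketch below uses a hypothetical helper, `pad_to_rank`, to show the padding applied to the shapes from the example on this page; it is an illustration of the rule, not the library's code.

```python
def pad_to_rank(shape, rank):
    """Prepend size-1 dimensions until `shape` has `rank` dimensions (hypothetical helper)."""
    missing = rank - len(shape)
    if missing < 0:
        raise ValueError("shape already has more dimensions than rank")
    return (1,) * missing + tuple(shape)


shape1, shape2 = (3, 4, 1), (1, 5)
rank = max(len(shape1), len(shape2))
print(pad_to_rank(shape1, rank), pad_to_rank(shape2, rank))  # (3, 4, 1) (1, 1, 5)
```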

If any pair of corresponding dimensions differs and neither dimension is 1, the shapes do not satisfy the broadcast rules and the following exception is raised: RuntimeError "input shapes not match!"
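The compatibility check can be sketched as a loop over dimension pairs. `check_broadcastable` is a hypothetical helper that assumes the shapes are already padded to the same rank; its message mirrors the one quoted above, although the real function may attach extra detail such as the operator and input names.

```python
def check_broadcastable(shape1, shape2):
    """Raise RuntimeError if two equal-rank shapes violate the broadcast rule.

    Hypothetical helper, not the shape_util implementation.
    """
    for d1, d2 in zip(shape1, shape2):
        # A pair is valid if the sizes are equal or either size is 1.
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise RuntimeError("input shapes not match!")


check_broadcastable((3, 4, 1), (1, 1, 5))  # passes: every pair is equal or contains a 1
try:
    check_broadcastable((3, 4), (1, 5))  # 4 vs 5: unequal and neither is 1
except RuntimeError as err:
    print(err)  # input shapes not match!
```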

Prototype

def broadcast_shapes(shape1, shape2, op_name=OP_NAME, param_name_input1='', param_name_input2='')

Parameters

Parameter

Description

shape1

First shape to broadcast, for example (3, 4, 1).

shape2

Second shape to broadcast, for example (1, 5).

op_name

Operator name, shown as additional context in displayed messages, such as the error raised for incompatible shapes. Defaults to OP_NAME, as shown in the prototype.

param_name_input1

Name of the first input, shown as additional context in displayed messages. Defaults to an empty string.

param_name_input2

Name of the second input, shown as additional context in displayed messages. Defaults to an empty string.

Returns

Three shapes: shape1 and shape2 padded to a common rank, followed by the broadcast shape.

Restrictions

None

Example

from tbe.common.utils import shape_util
shape1, shape2, shape3 = shape_util.broadcast_shapes((3, 4, 1), (1, 5))

Result shapes:

shape1 = (3, 4, 1)

shape2 = (1, 1, 5)

shape3 = (3, 4, 5)
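To reproduce the example without the TBE package installed, the whole rule (pad, check, take the maximum) can be sketched in pure Python. `broadcast_shapes_sketch` is a hypothetical stand-in, not the `shape_util` implementation: it omits the `op_name` and `param_name_*` arguments and returns tuples, which may differ from the types the real function returns.

```python
def broadcast_shapes_sketch(shape1, shape2):
    """Pure-Python sketch of the broadcast rule (hypothetical, not shape_util's code)."""
    rank = max(len(shape1), len(shape2))
    # Prepend size-1 dimensions so both shapes have the same rank.
    padded1 = (1,) * (rank - len(shape1)) + tuple(shape1)
    padded2 = (1,) * (rank - len(shape2)) + tuple(shape2)
    out = []
    for d1, d2 in zip(padded1, padded2):
        # Each pair must be equal, or one side must be 1.
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise RuntimeError("input shapes not match!")
        out.append(max(d1, d2))
    return padded1, padded2, tuple(out)


print(broadcast_shapes_sketch((3, 4, 1), (1, 5)))
# ((3, 4, 1), (1, 1, 5), (3, 4, 5))
```

Padding is applied to whichever input is shorter, so swapping the arguments swaps the first two results but leaves the broadcast shape unchanged.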