broadcast_shapes
Description
Broadcasts shape1 and shape2 against each other and returns the resulting shape1, the resulting shape2, and the broadcast shape.
Each dimension of the broadcast shape is the larger of the corresponding dimensions of shape1 and shape2.
If shape1 and shape2 have different ranks, the shorter shape is first padded with leading dimensions of size 1 until both ranks match.
If a pair of corresponding dimensions differs and neither value is 1, the shapes do not satisfy the broadcast rules and the following exception is raised: RuntimeError "input shapes not match!"
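The rules above can be sketched in plain Python. `broadcast_shapes_sketch` is a hypothetical helper written only for illustration: it is not the tbe implementation, it omits the op_name and param_name_input* arguments, and it assumes each shape is a sequence of positive integers. It returns tuples to match the example at the end of this page.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_shapes_sketch(
    shape1: Sequence[int], shape2: Sequence[int]
) -> Tuple[Shape, Shape, Shape]:
    """Hypothetical illustration of the documented broadcast rules (not tbe code)."""
    s1, s2 = list(shape1), list(shape2)
    rank = max(len(s1), len(s2))

    # Pad the shorter shape with leading dimensions of size 1.
    s1 = [1] * (rank - len(s1)) + s1
    s2 = [1] * (rank - len(s2)) + s2

    out = []
    for d1, d2 in zip(s1, s2):
        # A pair is compatible only if the sizes match or one of them is 1.
        if d1 != d2 and d1 != 1 and d2 != 1:
            raise RuntimeError("input shapes not match!")
        # Each broadcast dimension is the larger value of the pair.
        out.append(max(d1, d2))

    return tuple(s1), tuple(s2), tuple(out)


print(broadcast_shapes_sketch((3, 4, 1), (1, 5)))
# ((3, 4, 1), (1, 1, 5), (3, 4, 5))
```

Calling `broadcast_shapes_sketch((3, 4), (5,))` raises RuntimeError, because after padding the pair (4, 5) differs and neither value is 1.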
Prototype
def broadcast_shapes(shape1, shape2, op_name=OP_NAME, param_name_input1='', param_name_input2='')
Parameters
| Parameter | Description |
|---|---|
| shape1 | Shape to broadcast. |
| shape2 | Shape to broadcast. |
| op_name | Operator name, included as additional context in displayed messages. Defaults to OP_NAME. |
| param_name_input1 | Name of the first input, included as additional context in displayed messages. Defaults to an empty string. |
| param_name_input2 | Name of the second input, included as additional context in displayed messages. Defaults to an empty string. |
Returns
The resulting shape1 and shape2 (after any leading size-1 padding), and the broadcast shape
Restrictions
None
Example
from tbe.common.utils import shape_util
shape1, shape2, shape3 = shape_util.broadcast_shapes((3, 4, 1), (1, 5))
Result shapes:
shape1 = (3, 4, 1)
shape2 = (1, 1, 5)
shape3 = (3, 4, 5)
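As an optional cross-check of the broadcast shape in this example, NumPy's `numpy.broadcast_shapes` (available since NumPy 1.20) applies the same leading-size-1 padding and size-1 compatibility rules. It is a different library and returns only the broadcast shape, not the two rank-aligned input shapes, and it raises ValueError rather than RuntimeError on a mismatch. It can also disagree with a strict "larger value" rule for zero-sized dimensions, since NumPy broadcasts 0 against 1 to 0.

```python
import numpy as np

# Same inputs as the example above; NumPy returns only the broadcast shape.
print(np.broadcast_shapes((3, 4, 1), (1, 5)))  # (3, 4, 5)

# Incompatible pair after padding (4 vs. 5, neither is 1): NumPy raises ValueError.
try:
    np.broadcast_shapes((3, 4), (5,))
except ValueError as err:
    print("incompatible shapes:", err)
```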