refine_shapes_for_broadcast

Description

Pads the inputs so that shape1 and shape2 meet the broadcast requirements, that is, so that both inputs have the same rank (the shorter shape is padded with leading 1s). Any axis whose dimension is 1 in both shapes after padding is removed before the subsequent fuse operation. The API then fuses, within each shape, adjacent broadcast axes that share the same broadcast direction into one axis, and adjacent unbroadcast axes into one axis.
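The pad, remove, and fuse steps described above can be sketched in pure Python as follows. This is a hypothetical reimplementation written from this description only: the name refine_shapes_sketch is illustrative and not part of tbe, it returns lists, and edge cases (for example, shapes made up entirely of 1s) may not behave the same way as the real API.

```python
from typing import List, Sequence, Tuple


def refine_shapes_sketch(shape1: Sequence[int],
                         shape2: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch of the pad / remove / fuse steps (not the tbe code)."""
    # Step 1: pad the shorter shape with leading 1s so both have the same rank.
    rank = max(len(shape1), len(shape2))
    s1 = [1] * (rank - len(shape1)) + list(shape1)
    s2 = [1] * (rank - len(shape2)) + list(shape2)

    # Step 2: remove axes whose dimension is 1 in both shapes.
    pairs = [(a, b) for a, b in zip(s1, s2) if not (a == 1 and b == 1)]

    out1: List[int] = []
    out2: List[int] = []
    prev_kind = None
    for a, b in pairs:
        if a != b and a != 1 and b != 1:
            raise ValueError(f"dimensions {a} and {b} are not broadcastable")
        # Classify the axis: unbroadcast, or broadcast in one of two directions.
        if a == b:
            kind = "same"
        elif a == 1:
            kind = "1to2"  # shape1 is broadcast to shape2 along this axis
        else:
            kind = "2to1"  # shape2 is broadcast to shape1 along this axis
        # Step 3: fuse with the previous axis if both share the same category.
        if kind == prev_kind:
            out1[-1] *= a
            out2[-1] *= b
        else:
            out1.append(a)
            out2.append(b)
        prev_kind = kind
    return out1, out2


print(refine_shapes_sketch((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7)))
# ([2, 12, 1, 5, 42], [2, 1, 2, 1, 42])
```

Tracking only the previous axis's category is enough, because fusion only ever merges an axis into its immediate left neighbor.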

  • For shapes with adjacent broadcast axes

    As shown in the preceding figure:
    • Given input shape1 and shape2, the adjacent broadcast axes are axis 1 and axis 2, and both have the same broadcast direction (from shape1 to shape2). Therefore, axis 1 and axis 2 of shape1 are fused into one axis, giving the refined shape1 (3, 1). Likewise, axis 1 and axis 2 of shape2 are fused into one axis, giving the refined shape2 (3, 20).
    • Given shape1' and shape2', the adjacent broadcast axes are again axis 1 and axis 2; however, their broadcast directions differ. Axis 1 broadcasts from shape1' to shape2' (1 versus 4), while axis 2 broadcasts from shape2' to shape1' (5 versus 1). Therefore, there are no adjacent axes to fuse, and refine_shapes_for_broadcast returns shape1' and shape2' unchanged as (3, 1, 5) and (3, 4, 1), respectively.
  • For shapes with adjacent unbroadcast axes

    As shown in the preceding figure, for both shape1 and shape2, the adjacent unbroadcast axes are axis 0 and axis 1. Therefore, axis 0 and axis 1 of shape1 are fused into one axis, giving the refined shape1 (12, 1). Likewise, axis 0 and axis 1 of shape2 are fused into one axis, giving the refined shape2 (12, 5).
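The fusion rule in these cases comes down to giving every axis one of three labels: unbroadcast, broadcast from shape1 to shape2, or broadcast from shape2 to shape1. The helper broadcast_kinds below is hypothetical (not part of tbe) and assumes two shapes of equal rank in which axes that are 1 in both shapes have already been removed. Adjacent axes are fused only when their labels match, which is why shape1' and shape2' are left unchanged.

```python
from typing import List, Sequence


def broadcast_kinds(shape1: Sequence[int], shape2: Sequence[int]) -> List[str]:
    """Label each axis of two equal-rank shapes (hypothetical helper)."""
    if len(shape1) != len(shape2):
        raise ValueError("shapes must have the same rank")
    kinds = []
    for a, b in zip(shape1, shape2):
        if a == b:
            kinds.append("unbroadcast")
        elif a == 1:
            kinds.append("shape1->shape2")  # shape1 is expanded along this axis
        elif b == 1:
            kinds.append("shape2->shape1")  # shape2 is expanded along this axis
        else:
            raise ValueError(f"dimensions {a} and {b} are not broadcastable")
    return kinds


# Axes 1 and 2 carry different labels, so nothing is fused.
print(broadcast_kinds((3, 1, 5), (3, 4, 1)))
# ['unbroadcast', 'shape1->shape2', 'shape2->shape1']
```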

Note: shape1 and shape2 may have different lengths (ranks). However, after the shorter shape is padded to the same rank, the two shapes must meet the broadcasting requirement: each pair of corresponding dimensions must either be equal, or one of them must be 1.
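The requirement in this note is the usual broadcasting rule, the same one NumPy applies. A minimal sketch, assuming shapes are sequences of positive ints; can_broadcast is a hypothetical helper, not part of tbe:

```python
from typing import Sequence


def can_broadcast(shape1: Sequence[int], shape2: Sequence[int]) -> bool:
    """Check the broadcasting rule stated in the note (hypothetical helper)."""
    rank = max(len(shape1), len(shape2))
    # Prepend 1s to the shorter shape so both have the same rank.
    s1 = (1,) * (rank - len(shape1)) + tuple(shape1)
    s2 = (1,) * (rank - len(shape2)) + tuple(shape2)
    # Each aligned pair must be equal, or one of the two must be 1.
    return all(a == b or a == 1 or b == 1 for a, b in zip(s1, s2))


print(can_broadcast((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7)))  # True
print(can_broadcast((3, 4), (3, 5)))  # False
```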

Prototype

def refine_shapes_for_broadcast(shape1, shape2)

Parameters

  • shape1: first input shape to refine, for example a tuple of ints such as (3, 1, 5).
  • shape2: second input shape to refine, in the same form as shape1.

Returns

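The refined shape1 and shape2, returned together so that they can be unpacked as shape1, shape2 = refine_shapes_for_broadcast(...).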

Restrictions

None

Example

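from tbe.common.utils import shape_util

# shape1 has rank 8 and shape2 has rank 7, so shape2 is padded before refinement.
shape1, shape2 = shape_util.refine_shapes_for_broadcast((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7))
# Expected refined shapes: shape1 (2, 12, 1, 5, 42), shape2 (2, 1, 2, 1, 42)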

shape2 is padded with a leading 1 to the same rank as shape1, giving shape2 (1, 2, 1, 1, 2, 1, 6, 7).
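The padding step amounts to prepending 1s to the shorter shape. A minimal sketch with a hypothetical helper pad_to_rank (not part of tbe) reproduces the padded shape2 from this example:

```python
from typing import Sequence, Tuple


def pad_to_rank(shape: Sequence[int], rank: int) -> Tuple[int, ...]:
    """Prepend 1s until the shape has the requested rank (hypothetical helper)."""
    return (1,) * (rank - len(shape)) + tuple(shape)


shape1 = (1, 2, 3, 4, 1, 5, 6, 7)
shape2 = (2, 1, 1, 2, 1, 6, 7)
print(pad_to_rank(shape2, len(shape1)))  # (1, 2, 1, 1, 2, 1, 6, 7)
```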

After padding, the dimension of axis 0 is 1 in both shape1 and shape2, so that axis is removed, giving shape1 (2, 3, 4, 1, 5, 6, 7) and shape2 (2, 1, 1, 2, 1, 6, 7).
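Removing the axes that are 1 in both padded shapes can be sketched as a simple filter over aligned dimension pairs. drop_common_ones is a hypothetical helper, not part of tbe; note that axis 4 (1 in shape1, 2 in shape2) is kept because only one side is 1.

```python
from typing import Sequence, Tuple


def drop_common_ones(shape1: Sequence[int],
                     shape2: Sequence[int]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Remove axes whose dimension is 1 in both equal-rank shapes (hypothetical helper)."""
    kept = [(a, b) for a, b in zip(shape1, shape2) if not (a == 1 and b == 1)]
    return tuple(a for a, _ in kept), tuple(b for _, b in kept)


# Shapes after padding; only axis 0 is 1 in both.
new1, new2 = drop_common_ones((1, 2, 3, 4, 1, 5, 6, 7), (1, 2, 1, 1, 2, 1, 6, 7))
print(new1)  # (2, 3, 4, 1, 5, 6, 7)
print(new2)  # (2, 1, 1, 2, 1, 6, 7)
```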

For shape1, adjacent broadcast axes 1 and 2, which share the same broadcast direction (from shape2 to shape1), are fused, and adjacent unbroadcast axes 5 and 6 are fused, giving shape1 (2, 12, 1, 5, 42). Axis 3 is not fused with either neighbor: it broadcasts from shape1 to shape2, whereas axes 2 and 4 broadcast from shape2 to shape1.

For shape2, the same axes are fused (broadcast axes 1 and 2, and unbroadcast axes 5 and 6), giving shape2 (2, 1, 2, 1, 42).
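The fuse step on the shapes left after axis removal can be sketched as a single left-to-right pass that multiplies an axis into its left neighbor whenever both share a category. fuse_axes is a hypothetical helper, not part of tbe; it assumes equal ranks, broadcast-compatible dimensions, and no axis that is 1 in both shapes.

```python
from typing import Sequence, Tuple


def fuse_axes(shape1: Sequence[int],
              shape2: Sequence[int]) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Fuse adjacent axes that share a broadcast category (hypothetical helper)."""
    out1, out2 = [], []
    prev = None
    for a, b in zip(shape1, shape2):
        # Category: unbroadcast, shape1 broadcast to shape2, or the reverse.
        kind = "same" if a == b else ("1to2" if a == 1 else "2to1")
        if kind == prev:
            out1[-1] *= a  # merge into the previous axis
            out2[-1] *= b
        else:
            out1.append(a)
            out2.append(b)
        prev = kind
    return tuple(out1), tuple(out2)


refined1, refined2 = fuse_axes((2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7))
print(refined1)  # (2, 12, 1, 5, 42)
print(refined2)  # (2, 1, 2, 1, 42)
```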