refine_shape_axes

Description

Converts negative axes to reduce into their positive equivalents, and fuses the shape axes according to the following rule:

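Each run of neighboring axes to reduce is fused into a single axis, and each run of neighboring axes not to reduce is likewise fused into a single axis. The size of a fused axis is the product of the sizes of the axes it replaces.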
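The fusion rule can be sketched in plain Python as follows. This is a hypothetical illustration, not the shape_util implementation: the name fuse_axes is invented here, and the sketch assumes the axes have already been converted to non-negative indices into shape.

```python
from typing import List, Sequence, Tuple


def fuse_axes(shape: Sequence[int], axes: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: fuse runs of neighboring reduced / non-reduced axes.

    Assumes every entry in `axes` is already a non-negative index into `shape`.
    Returns the fused shape and the reduced axes as indices into the fused shape.
    """
    reduce_set = set(axes)
    fused_shape: List[int] = []
    fused_axes: List[int] = []
    prev_reduced = None  # reduce status of the previous axis; None before the first axis
    for idx, dim in enumerate(shape):
        reduced = idx in reduce_set
        if fused_shape and reduced == prev_reduced:
            # Same category as the previous axis: multiply into the current fused axis.
            fused_shape[-1] *= dim
        else:
            # First axis, or the category changed: start a new fused axis.
            fused_shape.append(dim)
            if reduced:
                fused_axes.append(len(fused_shape) - 1)
        prev_reduced = reduced
    return fused_shape, fused_axes


print(fuse_axes((2, 3, 4, 5, 6), (1, 2)))  # ([2, 12, 30], [1])
```

Tracking only the previous axis's category is enough here, because fusion never merges axes across a change between "reduce" and "not reduce".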

If any element in axes is outside the range [-rank, rank), where rank is the length of shape, the error "RuntimeError: Axis must between [-%d, %d)" is raised.
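A minimal sketch of the negative-axis conversion and the range check, assuming rank is len(shape). The function name wrap_axes is hypothetical, the message only mirrors the documented format (filling both %d placeholders with rank is an assumption), and sorting and de-duplicating the result is also an assumption, made so the axes can be passed straight to a fusion step.

```python
from typing import List, Sequence


def wrap_axes(axes: Sequence[int], rank: int) -> List[int]:
    """Hypothetical sketch: map each axis in [-rank, rank) to [0, rank)."""
    wrapped = []
    for axis in axes:
        if not -rank <= axis < rank:
            # Mirrors the documented message format; the real check may differ.
            raise RuntimeError("Axis must between [-%d, %d)" % (rank, rank))
        # A negative axis counts from the end: -1 is the last axis, and so on.
        wrapped.append(axis + rank if axis < 0 else axis)
    # Assumption: return sorted, unique axes.
    return sorted(set(wrapped))


print(wrap_axes((1, -3), 5))  # [1, 2]
```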

Prototype

def refine_shape_axes(shape, axes)

Parameters

Parameter    Description
shape        Shape to refine.
axes         Axes to reduce. Negative values are allowed within [-rank, rank).

Returns

  • The fused shape
  • The refined axes, expressed as indices into the fused shape

Restrictions

None

Example

from tbe.common.utils import shape_util
shape, axes = shape_util.refine_shape_axes((2, 3, 4, 5, 6), (1, -3))
  1. Converts (1, -3) to positive axes. For the rank-5 shape (2, 3, 4, 5, 6), the negative axis -3 corresponds to axis 5 + (-3) = 2. Therefore, the refined axes are (1, 2).
  2. Fuses the shape (2, 3, 4, 5, 6).

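    The neighboring axes 1 and 2 (sizes 3 and 4) are both reduced, so they are fused into one axis of size 12. The neighboring axes 3 and 4 (sizes 5 and 6) are both not reduced, so they are fused into one axis of size 30. Axis 0 (size 2) is kept as is. The fused shape is (2, 12, 30), and the reduced axes (1, 2) become the single axis 1.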

    The result shape and axes are as follows:

    shape=(2, 12, 30)

    axes=(1,)
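The arithmetic behind the example can be written out as follows. The last line checks that fusion preserves the total element count, which holds because each fused axis size is the product of the sizes it replaces.

```latex
\begin{aligned}
5 + (-3) &= 2 && \text{negative axis } {-3} \text{ for rank } 5\\
3 \times 4 &= 12 && \text{reduced axes 1 and 2 fused}\\
5 \times 6 &= 30 && \text{non-reduced axes 3 and 4 fused}\\
2 \times 3 \times 4 \times 5 \times 6 &= 720 = 2 \times 12 \times 30 && \text{element count unchanged}
\end{aligned}
```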