refine_shape_axes
函数功能
对axes(reduce轴列表)进行负值的优化,并将shape按照如下原则进行fuse操作:
对连续的reduce轴,与连续的非reduce轴分别进行fuse操作。
axes中每个axis的值的范围需要在[-rank, rank)之间,否则将抛出RuntimeError:Axis must between [-%d, %d).
函数原型
def refine_shape_axes(shape, axes)
参数说明
参数 |
说明 |
---|---|
shape |
需要优化的shape |
axes |
Reduce轴的列表。 |
返回值说明
- shape:进行fuse后的shape。
- axes:进行负值优化后的axes。
约束说明
无。
调用示例
from tbe.common.utils import shape_util shape, axes= shape_util.refine_shape_axes((2, 3, 4, 5, 6),(1,-3))
父主题: shape相关工具