refine_shapes_for_broadcast

函数功能

对输入shape1与shape2进行broadcast操作,然后对broadcast操作后的shape1与shape2分别在连续的非broadcast轴和连续的broadcast轴上执行fuse操作。

函数原型

def refine_shapes_for_broadcast(shape1, shape2)

参数说明

参数

说明

shape1

需要优化的shape1

shape2

需要优化的shape2

返回值说明

返回优化后的shape1与shape2。

约束说明

调用示例

from te.utils import shape_util
shape1, shape2= shape_util.refine_shapes_for_broadcast((3, 4, 5, 6, 7), (3, 1, 1, 6, 1)) 

对shape1与shape2执行brodcast操作,执行broadcast操作的轴为(1, 2, 4)。

然后对两个输入shape分别在连续的broadcast轴与连续的非broadcast轴上执行fuse操作。

针对shape1,在连续的轴1与轴2上执行fuse操作,输出shape1=(3, 20, 6, 7)。

针对shape2,在连续的轴1与轴2上执行fuse操作,输出shape2=(3, 2, 6, 1)