对输入shape1与shape2进行broadcast操作,然后对broadcast操作后的shape1与shape2分别在连续的非broadcast轴和连续的broadcast轴上执行fuse操作。
def refine_shapes_for_broadcast(shape1, shape2)
参数 |
说明 |
---|---|
shape1 |
需要优化的shape1 |
shape2 |
需要优化的shape2 |
返回优化后的shape1与shape2。
无
from te.utils import shape_util shape1, shape2= shape_util.refine_shapes_for_broadcast((3, 4, 5, 6, 7), (3, 1, 1, 6, 1))
对shape1与shape2执行brodcast操作,执行broadcast操作的轴为(1, 2, 4)。
然后对两个输入shape分别在连续的broadcast轴与连续的非broadcast轴上执行fuse操作。
针对shape1,在连续的轴1与轴2上执行fuse操作,输出shape1=(3, 20, 6, 7)。
针对shape2,在连续的轴1与轴2上执行fuse操作,输出shape2=(3, 2, 6, 1)