refine_shapes_for_broadcast
函数功能
在满足广播规则的前提下,对输入shape1和shape2补维至相同长度,如果补维后的两个shape在某一维度上的值均为1,则舍弃该维度,然后对新的输入shape1和shape2在连续且广播方向相同的轴上执行合并操作,在连续非广播轴上执行合并操作。
- 连续广播轴
如上图所示:
- shape1与shape2的连续广播轴为1、2,且两个轴的广播方向相同,都为从shape1广播到shape2,则shape1与shape2分别在轴1、轴2上执行fuse操作,得到shape1=(3, 1),shape2=(3,20)。
- shape1ˊ与shape2ˊ的连续广播轴为1、2,但轴1与轴2的广播方向不同,则无法对轴1、轴2执行fuse操作,执行refine_shapes_for_broadcast操作后,shape1ˊ仍为(3, 1, 5),shape2ˊ仍为(3, 4, 1)。
- 连续非广播轴
如上图所示,shape1与shape2的连续非广播轴为0、1,则shape1与shape2分别在轴0、轴1上执行fuse操作,得到shape1=(12, 1),shape2=(12, 5)。
说明:输入shape1与shape2的长度可不相同,但经过补维度至相同长度后每一个维度需要满足广播操作的要求,即相同轴的维度值或者相同,或者其中一个值为1。
函数原型
def refine_shapes_for_broadcast(shape1, shape2)
参数说明
参数 |
说明 |
---|---|
shape1 |
需要优化的shape1 |
shape2 |
需要优化的shape2 |
返回值说明
返回优化后的shape1与shape2。
约束说明
无
调用示例
from tbe.common.utils import shape_util shape1, shape2= shape_util.refine_shapes_for_broadcast((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7))
对shape2高维补1至和shape1长度相同,得到shape2为(1, 2, 1, 1, 2, 1, 6, 7)。
补维后,shape1和shape2的第0维均为1,则舍弃该维度,得到shape1为(2, 3, 4, 1, 5, 6, 7),shape2为(2, 1, 1, 2, 1, 6, 7)。
针对shape1,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape1 = (2, 12, 1, 5, 42)。
针对shape2,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape2 = (2, 1, 2, 1, 42)。
父主题: shape相关工具