文档
注册
评分
提单
论坛
小AI

refine_shapes_for_broadcast

函数功能

在满足广播规则的前提下,对输入shape1和shape2补维至相同长度,如果补维后的两个shape在某一维度上的值均为1,则舍弃该维度,然后对新的输入shape1和shape2在连续且广播方向相同的轴上执行合并操作,在连续非广播轴上执行合并操作。

  • 连续广播轴

    如上图所示:
    • shape1与shape2的连续广播轴为1、2,且两个轴的广播方向相同,都为从shape1广播到shape2,则shape1与shape2分别在轴1、轴2上执行fuse操作,得到shape1=(3, 1),shape2=(3,20)。
    • shape1ˊ与shape2ˊ的连续广播轴为1、2,但轴1与轴2的广播方向不同,则无法对轴1、轴2执行fuse操作,执行refine_shapes_for_broadcast操作后,shape1ˊ仍为(3, 1, 5),shape2ˊ仍为(3, 4, 1)。
  • 连续非广播轴

    如上图所示,shape1与shape2的连续非广播轴为0、1,则shape1与shape2分别在轴0、轴1上执行fuse操作,得到shape1=(12, 1),shape2=(12, 5)。

说明:输入shape1与shape2的长度可不相同,但经过补维度至相同长度后每一个维度需要满足广播操作的要求,即相同轴的维度值或者相同,或者其中一个值为1。

函数原型

def refine_shapes_for_broadcast(shape1, shape2)

参数说明

参数

说明

shape1

需要优化的shape1

shape2

需要优化的shape2

返回值说明

返回优化后的shape1与shape2。

约束说明

调用示例

from tbe.common.utils import shape_util
shape1, shape2= shape_util.refine_shapes_for_broadcast((1, 2, 3, 4, 1, 5, 6, 7), (2, 1, 1, 2, 1, 6, 7)) 

对shape2高维补1至和shape1长度相同,得到shape2为(1, 2, 1, 1, 2, 1, 6, 7)。

补维后,shape1和shape2的第0维均为1,则舍弃该维度,得到shape1为(2, 3, 4, 1, 5, 6, 7),shape2为(2, 1, 1, 2, 1, 6, 7)。

针对shape1,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape1 = (2, 12, 1, 5, 42)。

针对shape2,在连续的同方向广播轴1与2上进行合并,在连续的非广播轴5和6上进行合并,输出shape2 = (2, 1, 2, 1, 42)。

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词