文档
注册
评分
提单
论坛
小AI

broadcast_shapes

函数功能

将两个输入shape:shape1与shape2进行广播,并返回shape1、shape2与broadcast shape。

broadcast shape取shape1与shape2每个维度的大值。

若shape1与shape2的维度不同,则首先会将维度小的shape进行高维度补1操作,然后再将两个维度相等的shape进行broadcast操作。

如果两个shape中存在维度值不相等的轴,且其中一个不为1,即不满足broadcast的原则,则抛出RuntimeError "input shapes not match!"

函数原型

def broadcast_shapes(shape1, shape2, op_name=OP_NAME, param_name_input1='', param_name_input2='')

参数说明

参数

说明

shape1

需要执行broadcast操作的shape。

shape2

需要执行broadcast操作的shape。

op_name

算子名称,用于打印信息时辅助提示,默认值OP_NAME为NULL。

param_name_input1

输入参数1的名字,用于打印信息时辅助提示,默认值为NULL。

param_name_input2

输入参数2的名字,用于打印信息时辅助提示,默认值为NULL。

返回值说明

返回广播后的shape1、shape2与broadcast shape。

约束说明

无。

调用示例

from tbe.common.utils import shape_util
shape1, shape2, shape3= shape_util.broadcast_shapes((3, 4, 1), (1,5)) 

调用完成后:

shape1 = (3, 4, 1)

shape2 = (1, 1, 5)

shape3 = (3, 4, 5)

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词