broadcast_shapes
函数功能
将两个输入shape:shape1与shape2进行广播,并返回shape1、shape2与broadcast shape。
broadcast shape取shape1与shape2每个维度的大值。
若shape1与shape2的维度不同,则首先会将维度小的shape进行高维度补1操作,然后再将两个维度相等的shape进行broadcast操作。
如果两个shape中存在维度值不相等的轴,且其中一个不为1,即不满足broadcast的原则,则抛出RuntimeError "input shapes not match!"
函数原型
def broadcast_shapes(shape1, shape2, op_name=OP_NAME, param_name_input1='', param_name_input2='')
参数说明
参数 |
说明 |
---|---|
shape1 |
需要执行broadcast操作的shape。 |
shape2 |
需要执行broadcast操作的shape。 |
op_name |
算子名称,用于打印信息时辅助提示,默认值OP_NAME为NULL。 |
param_name_input1 |
输入参数1的名字,用于打印信息时辅助提示,默认值为NULL。 |
param_name_input2 |
输入参数2的名字,用于打印信息时辅助提示,默认值为NULL。 |
返回值说明
返回广播后的shape1、shape2与broadcast shape。
约束说明
无。
调用示例
from tbe.common.utils import shape_util shape1, shape2, shape3= shape_util.broadcast_shapes((3, 4, 1), (1,5))
调用完成后:
shape1 = (3, 4, 1)
shape2 = (1, 1, 5)
shape3 = (3, 4, 5)
父主题: shape相关工具