shape_refine

函数功能

返回非1维度值组成的shape及排序后的reduce轴列表,格式为(shape,axis)的元组。

函数原型

def shape_refine(shape, reduce_axis=None, keep_dims=True):

参数说明

参数

说明

shape

输入数据shape

reduce_axis

需要进行reduce的轴,可以是list、tuple或者int类型数值。

keep_dims

是否保持维度数,bool型。

True代表降维。

返回值说明

约束说明

reduce轴对应的shape值为1,则不能进行优化。

调用示例

from te.utils import shape_util 
shape_util.shape_refine((32, 64, 64, 1), reduce_axis=None) 

返回(32, 64, 64)。