返回非1维度值组成的shape及排序后的reduce轴列表,格式为(shape,axis)的元组。
def shape_refine(shape, reduce_axis=None, keep_dims=True):
参数 |
说明 |
---|---|
shape |
输入数据shape |
reduce_axis |
需要进行reduce的轴,可以是list、tuple或者int类型数值。 |
keep_dims |
是否保持维度数,bool型。 True代表降维。 |
reduce轴对应的shape值为1,则不能进行优化。
from te.utils import shape_util shape_util.shape_refine((32, 64, 64, 1), reduce_axis=None)
返回(32, 64, 64)。