shape_refine
函数功能
返回非1维度值组成的shape及排序后的reduce轴列表,格式为(shape,axis)的元组。
函数原型
def shape_refine(shape, reduce_axis=None, keep_dims=True):
参数说明
参数 |
说明 |
---|---|
shape |
输入数据shape |
reduce_axis |
需要进行reduce的轴,可以是list、tuple或者int类型数值。 |
keep_dims |
是否保持维度数,bool型。 True代表降维。 |
返回值说明
- 如果输入只有shape,则输出也只有shape。
- 如果输入同时有shape和reduce_axis,则输出是元素组(shape,axis)。
约束说明
reduce轴对应的shape值为1,则不能进行优化。
调用示例
from te.utils import shape_util shape_util.shape_refine((32, 64, 64, 1), reduce_axis=None)
返回(32, 64, 64)。
父主题: shape相关工具