把连续的reduce轴进行合并,并把对应的shape的维度也进行合并。
def simplify_axis_shape(shape, axis)
参数 |
说明 |
---|---|
shape |
数据shape |
axis |
待reduce的轴列表 |
二元组,包含合并轴后的shape和轴列表。
无
from te.utils import shape_util shape_util.simplify_axis_shape((32, 64, 64, 1), [1,2])
把第二、三轴合并,返回(32,4096,1),[1]