simplify_axis_shape
函数功能
输入shape及待reduce的轴列表,然后将输入shape中连续的reduce轴对应的shape维度进行合并,连续的非reduce轴对应的shape维度也进行合并。
并将输入的reduce轴列表中连续的轴进行合并。
最终返回包含轴合并后的shape及轴列表的二元组。
函数原型
def simplify_axis_shape(shape, axis)
参数说明
参数 |
说明 |
---|---|
shape |
输入shape |
axis |
待reduce的轴列表 |
返回值说明
二元组,包含轴合并后的shape和轴列表。
约束说明
无
调用示例
from tbe.common.utils import shape_util shape_util.simplify_axis_shape((32, 64, 64, 1,4, 5), [1,2])
输入轴列表为[1,2],所以针对输入shape(32, 64, 64, 1,4, 5),将轴1与轴2进行合并;输入shape中,轴3、4、5为连续的非reduce轴,所以将输入shape的轴3、4、5也进行合并,最终输出shape为(32, 4096, 20)。
将输入轴列表中的1和2进行合并,输出的轴列表为[1]。
所以最终返回shape和轴列表的2元组:(32, 4096, 20),[1]。
父主题: shape相关工具