文档
注册
评分
提单
论坛
小AI

refine_shape_axes

函数功能

对axes(reduce轴列表)进行负值的优化,并将shape按照如下原则进行fuse操作:

对连续的reduce轴,与连续的非reduce轴分别进行fuse操作。

axes中每个axis的值的范围需要在[-rank, rank)之间,否则将抛出RuntimeError:Axis must between [-%d, %d).

函数原型

def refine_shape_axes(shape, axes)

参数说明

参数

说明

shape

需要优化的shape

axes

Reduce轴的列表。

返回值说明

  • shape:进行fuse后的shape。
  • axes:进行负值优化后的axes。

约束说明

无。

调用示例

from tbe.common.utils import shape_util
shape, axes= shape_util.refine_shape_axes((2, 3, 4, 5, 6),(1,-3)) 
  1. 首先对axes(1,-3)进行负值优化,由于shape(2, 3, 4, 5, 6)共有5维,所以-3为倒数第三个维度,即2,则优化后的reduce列表axes=(1, 2)。
  2. 对shape(2, 3, 4, 5, 6)进行fuse操作:

    对连续的reduce轴1、轴2进行fuse操作,对连续的非reduce轴3、轴4进行fuse操作,操作完成后shape=(2, 12,30),axes=(1, 2) fuse成了轴1。

    所以返回值的shape与axes如下:

    shape=(2, 12, 30)

    axes=(1)

搜索结果
找到“0”个结果

当前产品无相关内容

未找到相关内容,请尝试其他搜索词