broadcast关系
广播概念
broadcast(广播)描述了算子在运算期间如何处理不同形状的张量(或数组)。大部分情况下,允许不同形状的张量(或数组)在进行元素操作时自动扩展其形状,使其维度相互兼容,通常较小的张量(或数组)会“广播”为较大的张量(或数组)。
目前许多CANN算子API参数shape支持广播,可适当提高计算效率、减少内存占用(尤其大规模数据场景),更详细的广播技术介绍请参考NumPy官网。
广播规则
一般进行广播计算时,需要理解以下规则:
规则1:如果数组间维度数不一致,所有数组向最长形状的数组看齐,形状不足的部分在左侧填充1,直至维度数相同。
- 说明1:
维度数(Number of Dimensions)是指张量(或数组)对应shape的维数,比如x.shape=(1,1,2,4),其维度数是4 。
说明2:
比如计算a+b,其中a.shape=(2, 2, 3)、b.shape=(2, 3),那么数组b将被broadcast为b.shape=(1, 2, 3)。
规则2:如果数组间维度数一致,且某个数组的某一维度为1,则该维度为1的数组将被拉伸以匹配另一个数组对应维度形状。
说明:
本场景下,只需保证在某一维度做broadcast即可。比如计算a+b,其中a.shape=(1, 3)、b.shape=(3, 1),那么两个数组会broadcast为a.shape=(3, 3)、b.shape=(3, 3)。
规则3:如果数组间维度数不一致,且均没有等于1的维度,则会报错。
基于上述规则,广播过程一般先按规则1进行扩维,再按规则2进行形状拉伸,具体例子如下:
假设a.shape=(2,2,3),取值形如:
[[[1 2 3],[4 5 6]],
[[1 2 3],[4 5 6]]]
假设b.shape=(2,3),取值形如:
[[1 2 3],
[-1 -2 -3]]
根据规则1扩展维度,b.shape=(1,2,3),取值如下:
[[[1 2 3],
[-1 -2 -3]]]
根据规则2拉伸形状,b.shape=(2,2,3),取值如下:
[[[1 2 3],[-1 -2 -3]],
[[1 2 3],[-1 -2 -3]]]
计算a+b,实际结果如下:
[[[2 4 6],[3 3 3]],
[[2 4 6],[3 3 3]]]
限制
当满足broadcast关系的两个输入a和b的数据类型或推导后的数据类型在COMPLEX64、COMPLEX128、DOUBLE、INT16、UINT16、UINT64中时,除了满足上述广播规则,还需满足如下条件,否则广播会失败,导致算子执行报错。 条件:连续的需要广播的轴和连续的不需要广播的轴合并之后的维度要求小于6。 举例:
- 当a.shape=(5, 1, 5, 1, 5, 1),b.shape=(5, 5, 5, 5, 5, 5),没有需要合并的轴,最后维度为6,广播报错。
- 当a.shape=(5, 1, 5, 5, 1, 1),b.shape=(5, 5, 5, 5, 5, 5),在第2和3维都不需要广播,4和5维都需要广播,分别连续合并,合并后的维度为4,广播成功。