在某些场景下,可能会存在两个输入shape不相同的情况。由于接口只支持对shape相同的输入进行计算,因此需要先对输入进行shape变换,再进行Add计算。本节将对满足Broadcast条件的输入在算子实现中的Broadcast处理进行介绍,其他场景可以参考本章节中提供的思路。
[object Object]
本节中将使用接口,因此输入需满足该API相关约束。同时,由于硬件限制,该API的输入地址需满足32字节对齐。本节以输入维度为2、第二个轴(axis = 1)需要Broadcast为例进行说明。完整的样例代码请参见。
与输入shape相同的场景相比,在Tiling结构体中增加相应的成员变量,表示是否需要对输入进行Broadcast、需要对哪个维度进行Broadcast、Broadcast的轴需要扩充的倍数。因此新增四个Tiling结构体成员:
- xLen和yLen:表示两个输入的数据长度。
- axis:表示对输入的哪个维度进行Broadcast。
- coef:表示Broadcast的输入需要扩维的倍数。例如,x shape为(m, 1),y shape为(m, n),则coef = n。如下图所示,图中相同颜色部分为单次计算的数据块。
图 1 axis=1时coef示意图[object Object][object Object]
Tiling结构体定义代码如下所示:
设需要进行Broadcast的输入长度为shorterAxisLen;不需要进行Broadcast的输入长度为totalLength。
使用shorterAxisLen进行分核计算,并使用分核后的长度与coef相乘作为totalLength的分核长度。
进行核内数据切分时,需要计算Unified Buffer数据块的数量向coef和BUFFER_NUM对齐之后的数量ubBlockAligned。
在核函数初始化阶段,根据Tiling结构体传入的参数确定对哪个输入进行Broadcast。由于针对输入的第二个轴(axis = 1)进行Broadcast,可以计算出,对于需要进行Broadcast的输入,每个核搬入数据长度为blockLength / coef。
初始化函数代码如下:
由于数据是向coef对齐的,在数据拷贝的过程中可能会出现地址不满足32字节对齐的场景,因此CopyIn函数、CopyOut函数中使用进行数据拷贝。
CopyIn函数实现代码如下:
CopyOut函数实现代码如下:
在Compute函数中,调用Add接口前需要先对输入进行Broadcast。这里需要计算Broadcast前后的shape。基于前文提到的数据关系,可以计算得出Broadcast前后的shape分别为{tileLength / broadcastCoef, 1}和{tileLength / broadcastCoef, broadcastCoef}。在此基础上对输入进行Broadcast,并将计算结果存入临时空间中,然后进行Add计算。实现代码示例如下所示: