另外,对于支持任意维度输入的tensor,如果搬运或者计算指令的地址偏移量需要每一个维度都参与计算,我们可以使用递归函数逐维叠加循环,实现对任意维度的遍历处理,并在遍历过程中计算出搬运偏移量。
例如,对于reverse算子,需要实现的功能是对一个最高为8维的输入tensor的任意一个或多个轴做逆序。在做shape泛化的时候,就要考虑两个变量:一个是输入tensor的维度不定,一个是需要逆序的轴的位置和个数不定。因此,如果直接使用普通的for循环嵌套,可能需要为每一种维度数分别编写分支来处理。在实际开发过程中,我们可以通过递归的方式来处理多个维度:每一层递归处理一个轴。例如下面的代码中,终止条件是"是否遍历到最后一轴":如果没有遍历到最后一轴,就继续叠加for循环,并计算搬运索引;如果已经遍历到最后一轴,就进行相应的指令计算。
# Recursive outer-axis traversal: reverse_big_shape
def reverse_big_shape(self, outer_loop_shape, move_in_index, move_out_index, loop_axis):
    """
    Traverse the outer loop of tensor, one axis per recursion level.

    Opens a single ``for_range`` loop over ``outer_loop_shape[0]``, folds the
    loop variable into both running offsets through ``get_move_index``, and
    then either recurses into the remaining axes or, when this is the last
    outer axis, calls ``reverse_last_axis`` with the final offsets.

    Parameters
    ----------
    outer_loop_shape: the shape of outer loop still to be traversed
        (must contain at least one entry)
    move_in_index: index for moving input data from gm to ub
    move_out_index: index for moving output data from ub to gm
    loop_axis: loop index currently traversed

    Returns
    -------
    None
    """
    # Only the outermost loop (loop_axis == 0) may receive ``block_num``.
    # The inner element count is needed for that decision alone, so it is
    # computed only at that level instead of at every recursion step.
    # NOTE(review): ``block_num`` presumably spreads the loop over several
    # cores; the thresholds 32 and 65536 are kept unchanged from the
    # original code - confirm their rationale against the TIK constraints.
    loop_kwargs = {}
    if loop_axis == 0:
        inner_data_num = functools_reduce(lambda x, y: x * y, self.inner_shape)
        if inner_data_num > 32 and self.shape_x[0] < 65536:
            loop_kwargs["block_num"] = self.outer_shape[0]

    with self.tik_instance.for_range(0, outer_loop_shape[0], **loop_kwargs) as index:
        # Refresh the offsets with the loop variable of the current axis.
        # Fresh names are used so the incoming parameters are not rebound.
        next_in_index, next_out_index = self.get_move_index(
            loop_axis, move_in_index, move_out_index, outer_loop_shape, index)

        if len(outer_loop_shape) > 1:
            # Not yet at the last outer axis: stack one more loop level.
            self.reverse_big_shape(outer_loop_shape[1:], next_in_index,
                                   next_out_index, loop_axis + 1)
        else:
            # Last outer axis reached: issue the actual instructions.
            self.reverse_last_axis(next_in_index, next_out_index)


# Offset refresh helper: get_move_index
def get_move_index(self, loop_axis, move_in_index, move_out_index, outer_loop_shape, index):
    """
    Fold the loop variable of the current axis into the read/write offsets.

    Both offsets are extended as ``offset * axis_length + position``. The
    write position is always ``index``. The read position is ``index`` for a
    regular axis and ``axis_length - 1 - index`` (mirrored) when
    ``loop_axis`` is listed in ``self.axis``.

    NOTE(review): the original documentation describes ``move_out_index`` once
    as a UB offset and once as a GM offset - confirm which memory each offset
    addresses against ``reverse_last_axis``.

    Parameters
    ----------
    loop_axis: the number of the axis currently traversed
    move_in_index: the offset accumulated so far for reading the input data
    move_out_index: the offset accumulated so far for writing the output data
    outer_loop_shape: the outer loop shape of the current traversal; only its
        first entry (the length of the current axis) is used
    index: current traversed index

    Returns
    -------
    move_in_index: the updated offset for reading the input data
    move_out_index: the updated offset for writing the output data
    """
    axis_length = outer_loop_shape[0]

    if loop_axis in self.axis:
        # Axis to be reversed: read from the mirrored position.
        move_in_index = move_in_index * axis_length + axis_length - 1 - index
    else:
        move_in_index = move_in_index * axis_length + index

    # Output is always written in forward order.
    move_out_index = move_out_index * axis_length + index

    return move_in_index, move_out_index