算子实现
本章节将以Gather类算子为例,介绍SIMT算子实现的基本流程,如下图所示:

- 算子分析与核函数定义:明确算子的输入和输出,分析最大线程数的设置方案。明确算子核函数名、输入输出参数,确定动态参数空间大小,配置最大线程数。
- Host侧线程切分计算:根据输入数据的shape信息,计算和设置gridDim、blockDim等参数。
- Kernel侧算子实现:实现单线程内的计算逻辑。
下文将对上述步骤进行详细介绍。
算子分析与核函数定义
算子分析具体步骤如下:
- 明确算子的功能和计算逻辑。
gather算子的功能为从输入张量中获取指定索引行的数据,即从形状为M * N的二维向量input中获取指定索引的m行数据,这m行的行索引由输入index指定。算子输出output第i行数据的计算公式为:
output[i] = input[index[i]]
- 明确算子的输入和输出。
- gather算子有两个输入:input与index;输出为output。
- 本样例中算子输入input的数据类型支持float、half、int32_t,index的数据类型为uint32_t,算子输出的数据类型与输入的数据类型相同。
- 每个线程处理一行数据,需要传入每行数据的长度in_width以及需要处理的总行数index_total_length,以确保尾部线程不会进行无效操作。
- 在算子实现中,无需使用大量临时变量,为了提高性能,可以在默认最大线程数(1024)的基础上适当增大核函数的最大线程数。
- 明确函数名称和参数。
- 自定义核函数名称,本样例中核函数命名为gather_custom。
- 通过分析算子的输入和输出,使用模板参数来支持不同的输入输出数据类型。
模板参数名
模板参数类型
参数定义
type_data
typename
输入输出的数据类型
type_idx
typename
index的数据类型
函数入参定义如下:
参数名
参数类型
参数定义
input
type_data*
输入数据在Global Memory上的内存地址
index
type_idx*
索引数据在Global Memory上的内存地址
gather_output
type_data*
输出数据在Global Memory上的内存地址
in_width
uint32_t
输入数据第二维的长度(列宽)
index_total_length
uint32_t
index数据的总长度
- 明确SIMT核函数gridDim、blockDim等动态参数设置方案。
- 本样例采用均匀切分方案,根据可用核数和最大线程数的限制,计算和调整gridDim(启用的线程块的个数)、blockDim(一个线程块启用的的线程个数),同时保证gridDim不超过65535、blockDim不超过最大线程数2048。
- 本算子实现逻辑中无需使用动态UB空间。
通过以上分析,得到SIMT Gather算子的设计规格如下:
- 算子类型(OpType):Gather
- 算子输入输出:
表1 Gather算子输入输出规格 name
shape
data type
format
input(输入)
(M, N)
float/half/int32_t
ND
index(输入)
(m), m < M
uint32_t
ND
output(输出)
(m, N)
float/half/int32_t
ND
- 核函数名称:gather_custom
核函数定义如下:
1 2 3 4 5 6 7 8 9 | constexpr uint32_t MAX_THREAD_COUNT = 2048; template <typename type_data, typename type_idx> __global__ __launch_bounds__(MAX_THREAD_COUNT) void gather_custom( type_data* input, type_idx* index, type_data* gather_output, uint32_t in_width, uint32_t index_total_length) |
在定义核函数时,使用__launch_bounds__(MAX_THREAD_COUNT)来指定最大线程数。最大线程数的设置范围为1到2048。设置的最大线程数越大,支持启用的线程越多,性能越好,但每个线程可使用的内部寄存器数量会减少。若未设置,最大线程数默认值为1024。在上述分析中已明确计算不需要过多寄存器,因此设置最大线程数为2048。在实际的算子开发过程中,应根据具体的算子实现来调整该值。
Host侧线程切分计算
本样例以简单的均匀切分方案介绍如何实现动态切分参数的计算。
- 设置初始gridDim。考虑到如果gridDim设置得比实际AIV核数少,会导致空闲核浪费,因此将初始gridDim设置为当前芯片的实际AIV核数量。AIV数量的获取方法如下所示:
1 2 3 4
uint32_t real_core_num = 0; const auto& platformInfoMgr = platform_ascendc::PlatformAscendCManager::GetInstance(); real_core_num = platformInfoMgr->GetCoreNumAiv(); block_num = real_core_num; // block_num为初始gridDim
- 计算blockDim。
根据输入index的长度index_total_length、初始gridDim计算一个线程块启用的的线程个数blockDim。
// thread_num_per_block为blockDim值 thread_num_per_block = (index_total_length + block_num - 1) / block_num;
- 调整blockDim。
若blockDim超出最大线程数限制,调整blockDim值为最大线程数值。
1 2 3
if (thread_num_per_block > MAX_THREAD_COUNT) { thread_num_per_block = MAX_THREAD_COUNT; }
- 调整gridDim。
重新计算gridDim,确保gridDim * blockDim > index_total_length,即确保所有启用的线程能够处理完指定行数的数据。
1block_num = (index_total_length + thread_num_per_block - 1) / thread_num_per_block;
完整的切分计算代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | constexpr uint32_t MAX_THREAD_COUNT = 2048; constexpr uint32_t MAX_BLOCK_COUNT = 65535; bool block_splite(uint32_t index_total_length, uint32_t &block_num, uint32_t &thread_num_per_block) { uint32_t real_core_num = 0; const auto& platformInfoMgr = platform_ascendc::PlatformAscendCManager::GetInstance(); if (platformInfoMgr == nullptr) { std::cout << "[ERROR] Get plateform info failed, please check device status."<< std::endl; return false; } real_core_num = platformInfoMgr->GetCoreNumAiv(); block_num = real_core_num; thread_num_per_block= (index_total_length + block_num -1) / block_num; if (thread_num_per_block > MAX_THREAD_COUNT) { thread_num_per_block = MAX_THREAD_COUNT; block_num = (index_total_length + thread_num_per_block - 1) / thread_num_per_block; if (block_num > MAX_BLOCK_COUNT) { std::cout << "[ERROR] index_total_length: "<< index_total_length << " can not be bigger then " << MAX_THREAD_COUNT * MAX_BLOCK_COUNT<< "."<< std::endl; return false; } } return true; } |
Kernel侧算子实现
- 根据均匀切分算法,获取当前线程的位置偏移量。
在本算子中,仅使用gridDim、blockDim等线程维度的第一维,因此计算偏移量时只需考虑x维信息。如下代码所示,threadIdx表示线程在其所在线程块内的索引,blockDim表示一个线程块中设置的线程数,而blockIdx表示线程块的索引。
1 2
// 计算线程索引 int32_t out_row = blockIdx.x * blockDim.x + threadIdx.x;
- 根据线程索引,获取当前线程需要处理数据的行索引,计算对应的输入、输出位置偏移量,实现整行数据的获取采集。
1 2 3 4 5 6 7 8
uint32_t in_row = index[out_row]; int input_idx = in_row * in_width; int output_idx = out_row * in_width; for (int32_t col = 0; col < in_width; col++) { gather_output[output_idx] = input[input_idx]; input_idx += 1; output_idx += 1; }
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | constexpr uint32_t MAX_THREAD_COUNT = 2048; constexpr uint32_t MAX_BLOCK_COUNT = 65535; template <typename type_data, typename type_idx> __global__ __launch_bounds__(MAX_THREAD_COUNT) void gather_custom( type_data* input, type_idx* index, type_data* gather_output, uint32_t in_width, uint32_t index_total_length) { // Calculate global thread ID int32_t out_row = blockIdx.x * blockDim.x + threadIdx.x; // Maps to the row index of output tensor if (out_row >= index_total_length) { return; } // Single thread processes entire row (all columns) - enables coalesced memory access uint32_t in_row = index[out_row]; int input_idx = in_row * in_width; int output_idx = out_row * in_width; for (int32_t col = 0; col < in_width; col++) { gather_output[output_idx] = input[input_idx]; input_idx += 1; output_idx += 1; } } |
运行验证
核函数即算子Kernel程序开发完成后,即可编写Host侧的核函数调用程序,实现从Host侧的APP程序调用算子,进行运行验证。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | std::vector<float> gather(std::vector<float>& input, const uint32_t* in_shape, std::vector<uint32_t>& index) { ... // 计算切分参数,设置动态UB内存 uint32_t block_num = 0; uint32_t thread_num_per_block = 0; block_splite(index_total_length, block_num, thread_num_per_block)) ... // 计算切分参数,设置动态UB内存 uint32_t dyn_ubuf_size = 0; // No need to alloc dynamic memory. // 用内存调用符<<<...>>>调用核函数完成指定的运算 gather_custom<<<block_num, thread_num_per_block, dyn_ubuf_size, stream>>>( input_device, index_device, output_device, in_shape[1], index_total_length); ... } |