【CANN文档速递06期】初识TBE DSL算子开发
发表于 2022/05/13
昇腾开发者可以基于TBE(Tensor Boost Engine)提供的python接口进行自定义算子开发,TBE算子开发有DSL ( Domain-Specific Language )和TIK ( Tensor Iterator Kernel )两种方式,两种方式的适用场景及优缺点如下所示:
本期我们主要介绍较容易上手的DSL算子开发方式。
DSL功能框架
1. 开发者调用DSL接口进行计算逻辑的描述,指明算子的计算方法和步骤。
2. 计算逻辑开发完后,开发者可调用Auto Schedule接口启动自动调度,自动调度时TBE根据计算类型自动选择合适的调度模板,完成数据切块和数据流向的划分,确保在硬件执行上达到最优。调度完成后,会生成类似于TVM的IR(Intermediate Representation)中间表示。
3. IR生成后,Pass会自动对生成的IR进行编译优化,优化的方式有双缓冲(Double Buffer)、流水线(Pipeline)同步、内存分配管理、指令映射等。
4. 算子经Pass处理后,会自动由CodeGen生成类C代码的临时文件,这个临时文件再通过编译器生成算子的二进制文件,可被网络模型直接加载调用。
DSL算子代码实现
在实现算子代码之前需要分析算子的数学表达式,输入、输出,明确需要调用的DSL接口,然后进行算子实现代码的开发。基于DSL的代码实现流程如下图所示:
算子实现的代码结构如下所示:
下面我们以实现两输入shape相同、数据类型为“float32”的Add算子为例,讲述DSL算子实现的代码流程。
1. 首先引入开发时依赖的Python模块
常用的Python模块如下所示:
· “tbe.dsl”:包含TBE DSL的计算接口、调度接口以及编译接口
· “tbe.tvm”:TBE是基于TVM框架扩展而来的,开发者在实现算子的时候可以使用TVM接口
· “tbe.common.utils.para_check”:TBE提供的算子参数校验接口
· “tbe.common.utils.shape_util”:TBE提供的算子shape处理接口
2. 声明算子接口
算子接口定义函数中包含算子的输入输出信息以及内核名称。
如下所示为Add算子的定义:
上述示例中,add为算子的type,input_x、input_y为算子的输入输出tensor,采用字典的形式定义,包含shape、ori_shape、format、ori_format与dtype信息,kernel_name为算子在内核中的名称,与算子type保持一致即可。
开发者在定义算子接口函数时可以使用TBE提供的参数校验接口check_input_type校验算子的参数类型是否合法,check_input_type为装饰器函数,使用方法如下所示:
当然,您也可以自定义实现相关参数的校验功能,基本的参数校验有助于在算子编译阶段提前发现问题。
3. 对输入tensor进行占位
获取输入数据的shape、dtype(此示例为float32的固定数据类型),使用TVM的placeholder接口对输入tensor进行占位,返回一个tensor对象,此位置中的数据在程序运行时才被指定。
4. 进行计算逻辑的实现
算子的计算逻辑可以通过TBE的DSL计算接口实现,例如Add算子可以通过DSL的vadd接口实现input_x与input_y的相加操作。
5. 调度与编译
计算逻辑实现完成后,需要调用auto_schedule接口,自动生成相应的调度;然后调用build接口进行算子的编译,编译出算子专用内核。
其中config为编译参数配置的map,配置信息包括是否需要打印IR、算子内核名称以及输入、输出张量。
恭喜您,至此您已经完成了Add算子实现代码的开发。
更多介绍
以上仅对DSL算子开发的关键代码进行了简要介绍,更多算子实现时的细节及技巧可登录昇腾社区,阅读相关文档:https://www.hiascend.com/
昇腾CANN文档中心致力于为开发者提供更优质的内容和更便捷的开发体验,助力CANN开发者共建AI生态。任何意见和建议都可以在昇腾社区反馈,您的每一份关注都是我们前进的动力。
本页内容