代码结构介绍
TBE DSL方式实现的算子代码结构如下所示:
# 导入依赖的Python模块
from tbe import dsl
from tbe import tvm
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
# 若有其他的python依赖,请自行导入
# 算子计算函数
# 装饰器函数tbe.common.register.register_op_compute可选,若算子实现逻辑中涉及reshape操作,不可使用此装饰器函数
@tbe.common.register.register_op_compute("add",op_mode="static")
def add_compute(input_x, input_y, output_z, kernel_name="add"):
"""
算子计算逻辑实现
"""
# 算子定义函数
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT,para_check.KERNEL_NAME)
def add(input_x, input_y, output_z, kernel_name="add"):
"""
算子校验(可选)
为输入tensor占位
"""
res = add_compute(data_x, data_y, output_z, kernel_name) # 调用算子计算函数
# 自动调度
with tvm.target.cce():
schedule = dsl.auto_schedule(res)
# 算子编译
config = {"name": kernel_name,
"tensor_list": (data_x, data_y, res)}
dsl.build(schedule, config)
# 可选,若实现此函数,且算子信息库中的input的dtype与format的dynamicFormat.flag配置为true,则可在算子融合阶段调用此函数实现dtype与format的推导。
def op_select_format(input_x, input_y, output_y, kernel_name="add"):
...
...
# 可选,若实现此函数,且算子信息库中的needCheckSupport的flag参数配置为true,则可在算子融合阶段调用此函数实现算子的dtype与shape的校验。
def check_supported(input_x, input_y, output_y, kernel_name="add"):
...
...
算子实现代码总体结构包含依赖Python模块的导入,算子定义函数实现,算子计算函数实现。
其中:
- 算子定义函数包含算子的校验,计算函数的调用以及调度与编译。
- 算子计算函数是对算子计算逻辑的实现。
下面详细介绍每个代码块的实现。
父主题: 算子代码实现(TBE DSL)