Code Illustration

The following code skeleton shows how a TBE operator is implemented in DSL mode.

# Import the dependent Python modules.
from tbe import dsl
from tbe import tvm
from tbe.common.register import register_op_compute
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
# You can import more Python dependencies as needed.

# Operator compute function
# (Optional) The register_op_compute decorator (tbe.common.register.register_op_compute). It cannot be used if the operator implementation logic involves a reshape operation.
@register_op_compute("add", op_mode="static")
def add_compute(input_x, input_y, output_z, kernel_name="add"):
    """
    Implementation of the operator compute logic
    """

# Operator definition function
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_INPUT,
                            para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
def add(input_x, input_y, output_z, kernel_name="add"):
    """
    (Optional) Operator verification
    Insert placeholders for the input tensors (data_x and data_y in the code below)
    """

    res = add_compute(data_x, data_y, output_z, kernel_name)  # Call the operator compute function.

    # Auto schedule
    with tvm.target.cce():
        schedule = dsl.auto_schedule(res)
    # Operator build
    config = {"name": kernel_name,
              "tensor_list": (data_x, data_y, res)}
    dsl.build(schedule, config)

# (Optional) If this function is implemented and dynamicFormat.flag is set to true in the operator information library, this function can be called during operator fusion to infer the dtype and format arguments of the inputs.
def op_select_format(input_x, input_y, output_y, kernel_name="add"):
    ...
    ...
# (Optional) If this function is implemented and needCheckSupport.flag is set to true in the operator information library, this function can be called during operator fusion to verify the dtype and shape arguments.
def check_supported(input_x, input_y, output_y, kernel_name="add"):
    ...
    ...

The operator implementation code consists of three parts: the Python module imports, the operator compute function, and the operator definition function. The op_select_format and check_supported functions are optional.

In the preceding code:

  • The operator definition function covers operator verification, the call to the compute function, scheduling, and build.
  • The operator compute function implements the operator compute logic.
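
As a rough illustration of the optional check_supported hook, the sketch below verifies dtype and shape arguments in plain Python, so it runs without a TBE installation. It assumes that each input is passed as a dict carrying "shape" and "dtype" keys and that the hook reports its verdict as a bool. The SUPPORTED_DTYPES tuple and the broadcast_shape helper are hypothetical names introduced for this example and are not part of the TBE API.

```python
# Plain-Python sketch of a check_supported hook for an "add" operator.
# Assumptions (not taken from the TBE API): the dtype list below, the
# broadcast_shape helper, and the bool return value.
SUPPORTED_DTYPES = ("float16", "float32", "int32")  # hypothetical list


def broadcast_shape(shape_x, shape_y):
    """Return the broadcast result of two shapes, or raise ValueError."""
    rank = max(len(shape_x), len(shape_y))
    # Left-pad the shorter shape with 1s so that both have the same rank.
    padded_x = (1,) * (rank - len(shape_x)) + tuple(shape_x)
    padded_y = (1,) * (rank - len(shape_y)) + tuple(shape_y)
    result = []
    for dim_x, dim_y in zip(padded_x, padded_y):
        # Two dimensions are compatible if they are equal or one of them is 1.
        if dim_x != dim_y and 1 not in (dim_x, dim_y):
            raise ValueError(
                f"shapes {tuple(shape_x)} and {tuple(shape_y)} cannot be broadcast"
            )
        result.append(dim_y if dim_x == 1 else dim_x)
    return tuple(result)


def check_supported(input_x, input_y, output_z, kernel_name="add"):
    """Return True if the dtype and shape arguments look valid for add."""
    dtype_x = input_x["dtype"].lower()
    dtype_y = input_y["dtype"].lower()
    # Both inputs must share one of the supported dtypes.
    if dtype_x != dtype_y or dtype_x not in SUPPORTED_DTYPES:
        return False
    # The input shapes must be broadcast-compatible.
    try:
        broadcast_shape(input_x["shape"], input_y["shape"])
    except ValueError:
        return False
    return True
```

For example, under these assumptions, inputs with shapes (4, 1) and (3,) and a shared dtype of float16 pass the check, while shapes (4, 2) and (3,) are rejected because the trailing dimensions 2 and 3 cannot be broadcast.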

The following section describes the implementation of each code block in detail.