Quick Start
When developing a TBE operator in DSL mode, you only need to focus on the operator's compute logic. You do not need to design a scheduling strategy, which makes this development mode simple and convenient.
Objectives
This section describes the method of writing the implementation code of a TBE operator in DSL mode by using the Add operator as an example.
The Add operator returns the sum of its two inputs, as shown in the following figure.
Operator Analysis
Before developing the Add operator with the TBE DSL, determine its functionality, inputs, and outputs, select an operator development mode, and name the operator type (OpType) and the implementation function.
- Specify the operator function and mathematical expression.
Specify the mathematical expression of the Add operator as follows:
y=x1+x2
The Add operator adds two inputs and returns the result.
- Specify the inputs and output.
- The Add operator has two inputs x1 and x2, and one output y.
- The supported input data types are float16, float32, and int32. The output has the same data type as the inputs.
- The inputs support all shapes. The output shape is the broadcast shape of the two inputs, which equals the input shape when both inputs have the same shape.
- The inputs support the following formats: NCHW, NC1HWC0, NHWC, and ND.
- Determine the operator development mode and the compute API.
- The compute process involves only addition. For details, see TBE DSL API. Preliminary analysis shows that the tbe.dsl.vadd(lhs, rhs) API can implement x1 + x2.
- The tbe.dsl.vadd(lhs, rhs) API requires the two input tensors to have the same shape. Therefore, first take the larger value of each dimension of the two input shapes, and then call the tbe.dsl.broadcast(var, shape, output_dtype=None) API to broadcast each input tensor to that shape.
- Specify the operator implementation file name, operator implementation function name, and OpType.
- Name OpType in upper camel case, that is, capitalize the first letter of each word.
- Name the operator implementation file and operator definition function in either of the following ways:
- To use custom names, configure opFile.value and opInterface.value in the TBE Operator Information Library.
- If opFile.value and opInterface.value are not configured in the TBE Operator Information Library, FE derives the operator file name and function name from OpType by applying the following rules:
- Replace the first uppercase letter with a lowercase letter.
- Replace each uppercase letter following lowercase letters with an underscore (_) and the corresponding lowercase letter.
- A run of uppercase letters that follows a digit or another uppercase letter is treated as one semantic string. If a lowercase letter follows this string, replace the last uppercase letter of the string with an underscore (_) and its lowercase form, and convert the other uppercase letters to lowercase. If no lowercase letter follows the string, convert the whole string to lowercase.
Examples: ABCDef -> abc_def; Abc2DEf -> abc2d_ef; Abc2DEF -> abc2def; ABC2dEF -> abc2d_ef
In this example, OpType is defined as Add. Lowercasing the first letter gives both the operator implementation file name and the implementation function name: add.
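The three renaming rules can be sketched in Python as shown below. This is an illustrative reimplementation of the rules as stated in this section, not FE's actual source code; the function name op_type_to_file_name is hypothetical, and names that combine the rules in unusual ways may be handled differently by FE. The examples listed with the rules double as checks.

```python
def op_type_to_file_name(op_type: str) -> str:
    """Hypothetical sketch: derive the implementation file/function name from OpType."""
    out = []
    n = len(op_type)
    i = 0
    while i < n:
        c = op_type[i]
        prev = op_type[i - 1] if i > 0 else ""
        if i == 0:
            # Rule 1: lowercase the first letter.
            out.append(c.lower())
            i += 1
        elif c.isupper() and prev.islower():
            # Rule 2: an uppercase letter after a lowercase letter becomes "_" + lowercase.
            out.append("_" + c.lower())
            i += 1
        elif c.isupper() and (prev.isupper() or prev.isdigit()):
            # Rule 3: collect the run of uppercase letters (the "semantic string").
            j = i
            while j < n and op_type[j].isupper():
                j += 1
            run = op_type[i:j].lower()
            if j < n and op_type[j].islower():
                # A lowercase letter follows: put "_" before the run's last letter.
                out.append(run[:-1] + "_" + run[-1])
            else:
                out.append(run)
            i = j
        else:
            # Lowercase letters, digits, and other characters are kept unchanged.
            out.append(c)
            i += 1
    return "".join(out)


for name in ("Add", "ABCDef", "Abc2DEf", "Abc2DEF", "ABC2dEF"):
    print(name, "->", op_type_to_file_name(name))
```

Scanning left to right and consuming a whole uppercase run at once keeps rule 3 simple: the decision about the underscore depends only on the character right after the run.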
Based on the preceding analysis, the design specifications of the Add operator are as follows.
Table 1 Add operator design specifications
OpType: Add
| Item | Name | Shape | Data Type | Format |
| Operator input | x1 | all | float16, float32, int32 | NCHW, NC1HWC0, NHWC, ND |
| Operator input | x2 | all | float16, float32, int32 | NCHW, NC1HWC0, NHWC, ND |
| Operator output | y | all | float16, float32, int32 | NCHW, NC1HWC0, NHWC, ND |
Main DSL APIs for operator implementation: tbe.dsl.broadcast(var, shape, output_dtype=None), tbe.dsl.vadd(lhs, rhs)
Operator implementation file/function name: add
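To check operator results, a host-side reference model can be written against this specification. The sketch below assumes NumPy is available and uses it only as a golden model, not as TBE code; add_reference and SUPPORTED_DTYPES are hypothetical names. NumPy broadcasting plays the role of tbe.dsl.broadcast, and the formats in the table have no NumPy counterpart, so they are not modeled.

```python
import numpy as np

# Data types supported by the Add operator (Table 1).
SUPPORTED_DTYPES = ("float16", "float32", "int32")


def add_reference(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy reference for the Add operator: y = x1 + x2."""
    for name, x in (("x1", x1), ("x2", x2)):
        if x.dtype.name not in SUPPORTED_DTYPES:
            raise TypeError("%s: only %s are supported, got %s"
                            % (name, ", ".join(SUPPORTED_DTYPES), x.dtype.name))
    if x1.dtype != x2.dtype:
        raise TypeError("x1 and x2 must have the same data type")
    # NumPy broadcasting stands in for tbe.dsl.broadcast; the result keeps the input dtype.
    return np.add(x1, x2)


x1 = np.ones((5, 6, 7), dtype=np.float16)
x2 = np.ones((1, 6, 7), dtype=np.float16)
y = add_reference(x1, x2)
print(y.shape, y.dtype)  # (5, 6, 7) float16
```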
Operator Code Implementation
The Add operator supports only three data types: float16, float32, and int32, so the input data type must be verified. In addition, the two inputs may have different shapes. The Add operator supports this, but the compute API tbe.dsl.vadd() does not, so the two input shapes must be verified and broadcast to a common shape. The operator implementation code is as follows.
import tbe
from tbe import tvm
from tbe import dsl
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
from functools import reduce
SHAPE_SIZE_LIMIT = 2147483648
# Implement the compute logic of the Add operator.
@tbe.common.register.register_op_compute("add", op_mode="static")
def add_compute(input_x, input_y, output_z, kernel_name="add"):
    shape_x = shape_util.shape_to_list(input_x.shape)  # Convert the shape to a list.
    shape_y = shape_util.shape_to_list(input_y.shape)  # Convert the shape to a list.
    # Assign the larger value of each dimension of shape_x and shape_y to shape_max.
    shape_x, shape_y, shape_max = shape_util.broadcast_shapes(
        shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y")
    # Reject shapes whose total number of elements exceeds SHAPE_SIZE_LIMIT.
    shape_size = reduce(lambda x, y: x * y, shape_max[:])
    if shape_size > SHAPE_SIZE_LIMIT:
        raise RuntimeError("the shape is too large to calculate")
    input_x = dsl.broadcast(input_x, shape_max)  # Broadcast the shape of input_x to shape_max.
    input_y = dsl.broadcast(input_y, shape_max)  # Broadcast the shape of input_y to shape_max.
    res = dsl.vadd(input_x, input_y)  # Compute input_x + input_y.
    return res  # Return the output tensor.
# Operator definition function
def add(input_x, input_y, output_z, kernel_name="add"):
    # Obtain the shapes and data type of the operator input tensors.
    shape_x = input_x.get("shape")
    shape_y = input_y.get("shape")
    check_tuple = ("float16", "float32", "int32")
    # Only the dtype of input_x is checked; input_y is assumed to have the same dtype.
    input_data_type = input_x.get("dtype").lower()
    if input_data_type not in check_tuple:
        raise RuntimeError("only support %s while dtype is %s" %
                           (",".join(check_tuple), input_data_type))
    # Assign the larger value of each dimension of shape_x and shape_y to shape_max.
    shape_x, shape_y, shape_max = shape_util.broadcast_shapes(
        shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y")
    if shape_x[-1] == 1 and shape_y[-1] == 1 and shape_max[-1] == 1:
        # If the last dimension of all three shapes is 1, remove it, unless the shape
        # has only one dimension. A trailing dimension of 1 does not change the data
        # layout (for example, 2 x 3 and 2 x 3 x 1 describe the same data), so
        # removing it improves schedule efficiency.
        shape_x = shape_x if len(shape_x) == 1 else shape_x[:-1]
        shape_y = shape_y if len(shape_y) == 1 else shape_y[:-1]
        shape_max = shape_max if len(shape_max) == 1 else shape_max[:-1]
    # Call the TVM placeholder API for the first input tensor; it returns a tensor object.
    data_x = tvm.placeholder(shape_x, name="data_1", dtype=input_data_type)
    # Call the TVM placeholder API for the second input tensor; it returns a tensor object.
    data_y = tvm.placeholder(shape_y, name="data_2", dtype=input_data_type)
    # Call the compute implementation function.
    res = add_compute(data_x, data_y, output_z, kernel_name)
    # Auto schedule
    with tvm.target.cce():
        schedule = dsl.auto_schedule(res)
    # Build configuration
    config = {"name": kernel_name,
              "tensor_list": (data_x, data_y, res)}
    dsl.build(schedule, config)
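The shape handling in add_compute() and add() can be traced on the host without the TBE toolchain. The sketch below is a plain-Python approximation: broadcast_shape_sketch, check_shape_size, and drop_trailing_one are hypothetical helpers, and padding the shorter shape with leading 1s (NumPy style) is an assumption about how shape_util.broadcast_shapes aligns shapes of different lengths.

```python
from functools import reduce
from operator import mul

SHAPE_SIZE_LIMIT = 2147483648  # Same limit as in the operator code.


def broadcast_shape_sketch(shape_x, shape_y):
    """Hypothetical stand-in for the shape logic of shape_util.broadcast_shapes."""
    shape_x, shape_y = list(shape_x), list(shape_y)
    rank = max(len(shape_x), len(shape_y))
    # Assumption: the shorter shape is left-padded with 1s (NumPy style).
    shape_x = [1] * (rank - len(shape_x)) + shape_x
    shape_y = [1] * (rank - len(shape_y)) + shape_y
    shape_max = []
    for dx, dy in zip(shape_x, shape_y):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError("shapes %s and %s cannot be broadcast" % (shape_x, shape_y))
        # Take the non-1 dimension, which is the larger one for non-zero sizes.
        shape_max.append(dy if dx == 1 else dx)
    return shape_x, shape_y, shape_max


def check_shape_size(shape_max):
    """Mirror the element-count check in add_compute()."""
    shape_size = reduce(mul, shape_max, 1)
    if shape_size > SHAPE_SIZE_LIMIT:
        raise RuntimeError("the shape is too large to calculate")
    return shape_size


def drop_trailing_one(shape_x, shape_y, shape_max):
    """Mirror add(): drop a trailing 1 shared by all three shapes, keeping 1-D shapes."""
    if shape_x[-1] == 1 and shape_y[-1] == 1 and shape_max[-1] == 1:
        shape_x = shape_x if len(shape_x) == 1 else shape_x[:-1]
        shape_y = shape_y if len(shape_y) == 1 else shape_y[:-1]
        shape_max = shape_max if len(shape_max) == 1 else shape_max[:-1]
    return shape_x, shape_y, shape_max


sx, sy, smax = broadcast_shape_sketch((5, 6, 7), (6, 1))
print(sx, sy, smax, check_shape_size(smax))  # [5, 6, 7] [1, 6, 1] [5, 6, 7] 210
print(drop_trailing_one([2, 3, 1], [2, 3, 1], [2, 3, 1]))  # ([2, 3], [2, 3], [2, 3])
```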
Operator Building Verification
- Append the main function to the operator Python file to call the operator. The code example is as follows:
# Call the operator.
if __name__ == '__main__':
    input_output_dict = {"shape": (5, 6, 7), "format": "ND", "ori_shape": (5, 6, 7),
                         "ori_format": "ND", "dtype": "float16"}
    add(input_output_dict, input_output_dict, input_output_dict, kernel_name="add")
- Run the following command to build the operator implementation file and verify the syntax of the implementation code:
python3 add.py
If no build error is reported and a kernel_meta folder containing the following files is generated in the current directory, the operator code has been built successfully:
- Operator binary file add.o
- Operator description file add.json, which defines the operator attributes and runtime resources.
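The file check above can be scripted. The sketch below is a hypothetical helper, check_kernel_meta, that only confirms the two files exist and that the description file parses as JSON; it does not assume any particular fields inside add.json.

```python
import json
import os


def check_kernel_meta(kernel_name="add", directory="kernel_meta"):
    """Hypothetical post-build check for the .o and .json files in kernel_meta."""
    obj_path = os.path.join(directory, kernel_name + ".o")
    json_path = os.path.join(directory, kernel_name + ".json")
    missing = [p for p in (obj_path, json_path) if not os.path.isfile(p)]
    if missing:
        raise FileNotFoundError("missing build outputs: " + ", ".join(missing))
    # Only confirm that the description file is valid JSON; its fields are not checked.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
```

For example, after running python3 add.py, calling check_kernel_meta("add") from the same directory raises FileNotFoundError if either file is missing and otherwise returns the parsed contents of add.json.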