Compute Implementation
Compute implementation includes importing the dependent Python modules, declaring the operator functions, verifying operator arguments, implementing the compute logic, and scheduling and building. Several measures for optimizing operator accuracy and performance can be applied during this phase.
Python Module Import
Before developing a TBE DSL operator, import the Python modules provided by the Ascend AI Software Stack into the operator implementation file, as shown in the following example. For details about the naming rules of the operator implementation file, see Naming Rules for Operator Definition File.
from tbe import dsl
from tbe import tvm
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
In the preceding code:
- tbe.dsl: imports the DSL APIs supported by TBE, including common compute APIs, schedule APIs, and build APIs. For details about how to use the DSL APIs, see TBE DSL API.
- tbe.tvm: imports the TVM backend code generation mechanism. For details, see the TVM documentation.
- tbe.common.utils.para_check: provides common argument verification APIs. For details about the API definitions, see Operator Argument Verification.
- tbe.common.utils.shape_util: provides common APIs for processing operator shapes. For details about the API definitions, see Shape-related Tools.
You can import more Python modules as needed.
Operator Function Declaration
The operator implementation consists of two functions: the operator definition function and the operator compute function. The operator definition function calls the operator compute function.
The declaration rules for these two functions are as follows.
- Declaring the operator definition function
As shown in the following code block, an operator definition function contains the operator input information, operator output information, and kernel name. The declaration information of the function must be consistent with that in the Operator Prototype Definition.
def operationname(input_x1, input_x2, output_y, attribute1=None, attribute2=None,..., kernel_name="KernelName", impl_mode="high_performance")
- The operator definition function name (operationname) must be the same as the operator implementation file name in the current version. For details about the naming rules, see Naming Rules for Operator Definition File.
- input_x1, input_x2: input tensors of the operator. A tensor must be defined in dictionary format, including the shape, ori_shape, format, ori_format, and dtype information. See the following example:
input_x1 = {'shape': (2, 2), 'ori_shape': (2, 2), 'format': 'ND', 'ori_format': 'ND', 'dtype': 'float16'}
The sequence and number of input tensors must be the same as those in Operator Prototype Definition. Optional inputs must also be declared here; the compute logic then checks whether data is actually passed for them and processes it accordingly.
- output_y: reserved. A dictionary for the output tensor of the operator, including the shape and dtype information.
The sequence and number of output tensors must be the same as those in Operator Prototype Definition. Optional outputs must also be declared here.
- attribute1, attribute2...: operator attributes. The sequence and number of operator attributes must be the same as those in Operator Prototype Definition.
Omit these parameters if the operator has no attributes, and assign default values to optional attributes.
- kernel_name: unique name of the operator kernel, which is also the name of the generated binary file and operator description file. The value must start with a letter or underscore (_), may contain only letters, digits, and underscores (_), and can be at most 200 characters long.
- impl_mode: (optional) a string that specifies the implementation mode. It affects accuracy and performance only when the input is of type float32.
The value can be high_precision or high_performance (default).
If the input values do not exceed 65,504 (the maximum value of float16), accuracy is not a concern. Otherwise, the exp computation may overflow and produce inaccurate results. In this case, set this field to high_precision, although performance might be compromised.
For details about the impact of this configuration on operator execution on a network, see the op_select_implmode option in ATC Instructions.
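To see why values beyond 65,504 are a problem, the following plain-Python sketch emulates float16 storage with the standard struct module's half-precision "e" format. It runs on the host and calls no TBE API; fits_float16 is a hypothetical helper, and using struct as a stand-in for on-device float16 arithmetic is an assumption made for illustration.

```python
import math
import struct

FLOAT16_MAX = 65504.0  # Largest finite float16 value.

def fits_float16(value):
    """Return True if value can be stored as a finite IEEE 754 half-precision number."""
    try:
        struct.pack("<e", float(value))
    except OverflowError:
        return False
    return True

# exp(11.0) is about 59874, which still fits in float16.
print(fits_float16(math.exp(11.0)))  # True
# exp(11.1) is about 66171, which exceeds 65504 and overflows.
print(fits_float16(math.exp(11.1)))  # False
```

Even a moderate input such as 11.1 overflows after exp, which is why float32 inputs that rely on this range may need high_precision.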
For example, the definition function of the Sqrt operator, which has no attributes, is declared as follows:
def sqrt(input_x, output_y, kernel_name="sqrt"):
The definition function of the reduce_sum operator, which has attributes, is declared as follows:
def reduce_sum(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum"):
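The kernel_name rules listed above can be expressed as a regular expression plus a length check. The following is a minimal plain-Python sketch; check_kernel_name and MAX_KERNEL_NAME_LENGTH are hypothetical names written for illustration, not para_check APIs.

```python
import re

# Starts with a letter or underscore; only ASCII letters, digits, and underscores follow.
_KERNEL_NAME_PATTERN = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
MAX_KERNEL_NAME_LENGTH = 200

def check_kernel_name(kernel_name):
    """Hypothetical helper: raise ValueError if kernel_name breaks the naming rules."""
    if not isinstance(kernel_name, str):
        raise ValueError("kernel_name must be a string")
    if len(kernel_name) > MAX_KERNEL_NAME_LENGTH:
        raise ValueError(f"kernel_name exceeds {MAX_KERNEL_NAME_LENGTH} characters")
    if _KERNEL_NAME_PATTERN.fullmatch(kernel_name) is None:
        raise ValueError(f"invalid kernel_name: {kernel_name!r}")
    return kernel_name

check_kernel_name("reduce_sum")     # passes
check_kernel_name("_sqrt_fp16_v2")  # passes
```

Names such as "1sqrt" (leading digit) or "sqrt-v2" (hyphen) are rejected by this sketch.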
When declaring the operator definition function, you can use the decorators check_op_params and check_input_type to perform basic verification on operator arguments.
check_op_params checks whether the operator inputs and outputs follow the rules of required and optional inputs and outputs. check_input_type validates the data types of the operator arguments.
See the following examples:
@para_check.check_op_params(para_check.REQUIRED_INPUT, para_check.REQUIRED_OUTPUT, para_check.KERNEL_NAME)
def sqrt(input_x, output_y, kernel_name="sqrt"):
@para_check.check_input_type(dict, dict, dict, int, bool, str)
def sort(x, y1, y2, axis=-1, descending=False, kernel_name="sort"):
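Conceptually, a decorator such as check_input_type compares each argument with the type declared at the same position. The following plain-Python sketch illustrates that idea; check_types is a hypothetical stand-in written for this example, and the real para_check implementation and its error messages may differ.

```python
import functools
import inspect

def check_types(*expected_types):
    """Hypothetical decorator: verify each argument, by position, against expected_types."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()  # Include defaulted parameters, in signature order.
            for (name, value), expected in zip(bound.arguments.items(), expected_types):
                # bool is a subclass of int, so reject bool where int is expected.
                wrong_bool = expected is int and isinstance(value, bool)
                if not isinstance(value, expected) or wrong_bool:
                    raise TypeError(
                        f"{name} should be {expected.__name__}, got {type(value).__name__}")
            return func(*args, **kwargs)
        return wrapper
    return decorator

@check_types(dict, dict, dict, int, bool, str)
def sort(x, y1, y2, axis=-1, descending=False, kernel_name="sort"):
    return kernel_name

sort({}, {}, {}, axis=0, descending=True)  # passes
```

Checking after apply_defaults means default values are validated too, which catches a wrongly typed default in the declaration.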
- Declaring the compute function
@tbe.common.register.register_op_compute("KernelName", op_mode="static")
def operationname_compute(input_x1, input_x2, output_y, attribute1=None, attribute2=None, ..., kernel_name="KernelName"):
- The @tbe.common.register.register_op_compute decorator enables automatic UB fusion within the network at run time: the custom operator can then be fused automatically according to the UB fusion patterns, improving execution efficiency. For details, see the API description in register_op_compute.
If the operator implementation logic involves a reshape operation, automatic UB fusion is not supported. In that case, do not use this decorator when declaring the operator compute function.
- input_x1, input_x2: inputs of the compute function, that is, the placeholders for the input tensors declared in the operator definition function, carrying information such as shape and data type.
- output_y, attribute1, and the remaining arguments are passed through unchanged from the operator definition function. Keep them consistent with the parameters declared there.
For example, for the Sqrt operator, the compute function is defined as follows.
@tbe.common.register.register_op_compute("sqrt", op_mode="static")
def sqrt_compute(input_data, output_data, kernel_name="sqrt"):
For the reduce_sum operator, the compute function is declared as follows:
@tbe.common.register.register_op_compute("reduce_sum", op_mode="static")
def reduce_sum_compute(x, y, axis=None, keep_dims=None, kernel_name="reduce_sum"):
Operator Function Implementation
After the operator function is declared, the operator definition function and compute function need to be implemented.
- Obtain the shape and data type of the operator input tensors from the arguments of the operator definition function (operationname), and implement argument verification.
- Obtain the shape and data type of the operator input tensors, which are used to define the input tensor placeholders.
def add(input_x, input_y, output_z, kernel_name="add"):
    shape_x = input_x.get("shape")
    shape_y = input_y.get("shape")
    input_data_type = input_x.get("dtype").lower()
    input_data_type_y = input_y.get("dtype").lower()
- (Optional) Add verification of the operator inputs, outputs, and basic attributes to the operator implementation function. This helps find problems at operator build time.
Take the Add operator as an example. Check that the two inputs have the same data type and that this data type is in the list of supported data types. The code implementation is as follows:
if input_data_type != input_data_type_y:
    raise RuntimeError(
        "the input_x and input_y should have the same data type.")
check_tuple = ("float16", "float32", "int32")
para_check.check_dtype(input_data_type, check_tuple, param_name="input_x")
You can use the common verification functions provided in TBE Utils API to verify operator arguments.
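The same checks can be prototyped on the host without the TBE package. The following plain-Python sketch applies them to tensor description dictionaries; verify_add_inputs and SUPPORTED_ADD_DTYPES are hypothetical names, and in a real operator para_check.check_dtype remains the API to use.

```python
SUPPORTED_ADD_DTYPES = ("float16", "float32", "int32")

def verify_add_inputs(input_x, input_y):
    """Hypothetical helper mirroring the Add checks above with built-in Python only.

    input_x and input_y are tensor description dictionaries, for example
    {'shape': (2, 2), 'dtype': 'float16', ...}.
    Returns (shape_x, shape_y, dtype) when the checks pass.
    """
    dtype_x = input_x.get("dtype").lower()
    dtype_y = input_y.get("dtype").lower()
    if dtype_x != dtype_y:
        raise RuntimeError("the input_x and input_y should have the same data type.")
    if dtype_x not in SUPPORTED_ADD_DTYPES:
        raise RuntimeError(f"input_x dtype {dtype_x} is not in {SUPPORTED_ADD_DTYPES}")
    return tuple(input_x.get("shape")), tuple(input_y.get("shape")), dtype_x

x = {"shape": (2, 2), "ori_shape": (2, 2), "format": "ND", "ori_format": "ND", "dtype": "Float16"}
y = {"shape": (2, 2), "ori_shape": (2, 2), "format": "ND", "ori_format": "ND", "dtype": "float16"}
print(verify_add_inputs(x, y))  # ((2, 2), (2, 2), 'float16')
```

Lower-casing the dtype first, as the original code does, makes "Float16" and "float16" compare equal.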
- Define a placeholder for each input tensor based on its shape and dtype. See the following example:
data_input = tvm.placeholder(shape, name="data_input", dtype=dtype)
The TVM placeholder API creates a placeholder for the input tensor and returns a tensor object. The actual data is supplied only at run time.
For an optional input, check whether the input is empty before creating a placeholder. If it is empty, no placeholder is required.
The tensor_list described in Scheduling and Building must contain the tensor objects returned by the tvm.placeholder calls. Therefore, these objects must not be overwritten in subsequent computation.
See the following example:
# Return the data_input placeholder.
data_input = tvm.placeholder(shape, name='data', dtype=dtype)
if dtype == "float16":
    # Cast data_input to float32 and rebind the name data_input to the result.
    data_input = dsl.cast_to(data_input, "float32")
...
with tvm.target.cce():
    schedule = dsl.auto_schedule(res)
    config = {"need_build": need_build,
              "name": kernel_name,
              "tensor_list": [data_input, res]}
    dsl.build(schedule, config)
In the preceding code, data_input = dsl.cast_to(data_input, "float32") overwrites the data_input object returned by the placeholder. As a result, the data_input in the build configuration tensor_list is no longer the object returned by the original placeholder API, and an error is reported when the operator implementation is built.

To avoid this, store the cast result in a new tensor and use that tensor in the subsequent computation:
data_input1 = dsl.cast_to(data_input, "float32")
Alternatively, perform the computation in the compute function, as described in the next step. The input tensor returned by the placeholder API is passed to the compute function as a formal parameter, and the computation there produces new tensor objects, so the object returned by the placeholder API is not overwritten.
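The underlying issue is ordinary Python name rebinding: after data_input = dsl.cast_to(...), the name refers to the cast result, and nothing in tensor_list refers to the placeholder any more. The following plain-Python analogy shows this with hypothetical FakeTensor and fake_cast_to stand-ins; it does not use TBE and only models object identity.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor object; it only records a name and dtype."""
    def __init__(self, name, dtype):
        self.name = name
        self.dtype = dtype

def fake_cast_to(tensor, dtype):
    # Like a cast, this returns a NEW object instead of modifying its input.
    return FakeTensor(tensor.name + "_cast", dtype)

placeholder = FakeTensor("data", "float16")

# Problematic pattern: rebinding the name loses the placeholder object.
data_input = placeholder
data_input = fake_cast_to(data_input, "float32")
res_bad = fake_cast_to(data_input, "float16")  # Stands in for the compute result.
bad_tensor_list = [data_input, res_bad]

# Recommended pattern: keep the placeholder and store the cast result separately.
data_input = placeholder
data_input1 = fake_cast_to(data_input, "float32")
res = fake_cast_to(data_input1, "float16")  # Stands in for the compute result.
good_tensor_list = [data_input, res]

print(bad_tensor_list[0] is placeholder)   # False
print(good_tensor_list[0] is placeholder)  # True
```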
- In the operator definition function, call the compute function to describe the compute process.
See the following example:
res = add_compute(data_x, data_y, output_z, kernel_name)
The input tensors are placeholder tensors defined by tvm.placeholder. The other arguments are passed through from the operator definition function.
- Implement the compute function of the operator.
The compute function implements the operator's compute process, using the TBE DSL APIs identified in Operator Analysis.
Quick Start shows the implementation of the Add operator. The following uses a more advanced operator, ReLU, as an example to describe the implementation and some precautions for using DSL APIs.
Assume that the compute formula of the ReLU operator is as follows:
relu(x) = max(x, 0)
The compute implementation code is as follows.
CONST_ZERO = 0  # Module-level constant used by relu_compute.

@fusion_manager.register("relu")
def relu_compute(x, y, kernel_name="relu"):
    inp_dtype = x.dtype  # Obtain the input data type.
    shape = x.shape  # Obtain the shape of the input data.
    # If the data type is float32 or int32, perform the vmax operation to avoid accuracy loss.
    if inp_dtype in ("float32", "int32"):
        # A tensor with the same shape and data type as the input, with every element 0.
        tensor_zero = dsl.broadcast(tvm.const(CONST_ZERO, inp_dtype), shape)
        # Obtain the larger value between x and tensor_zero.
        data_res = dsl.vmax(x, tensor_zero)
    else:
        # Perform the ReLU operation when the data type is float16 or int8.
        data_res = dsl.vrelu(x)
    data_res = dsl.cast_to(data_res, inp_dtype)
    return data_res
- If the tbe.dsl.vrelu() API is used, int8, uint8, int32, and float32 data is cast to float16, which causes accuracy loss for int32 and float32. To avoid this, when the input is of type int32 or float32, call tbe.dsl.vmax() to obtain the larger of the input data and 0.
- In the TBE DSL, the vmax API requires its two input tensors to have the same shape. Therefore, the tbe.dsl.broadcast API is used to expand the scalar constant 0 into a tensor with the same shape as the input x. In general, broadcasting produces a shape in which each dimension is the larger of the corresponding dimensions of the two input shapes.
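The accuracy concern can be reproduced on the host. The following plain-Python sketch compares an exact reference ReLU with one whose input is first rounded to float16, emulated with struct's "e" format. relu_through_float16 is a hypothetical model of the float16 cast described above, not a simulation of tbe.dsl.vrelu: the int32 value 2049 has no exact float16 representation and rounds to 2048.

```python
import struct

def to_float16(value):
    """Round a Python number to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", float(value)))[0]

def relu_reference(values):
    """Reference ReLU, max(x, 0), computed exactly in Python."""
    return [max(v, 0) for v in values]

def relu_through_float16(values):
    """Hypothetical model of a ReLU that first casts its input to float16."""
    return [max(to_float16(v), 0.0) for v in values]

data = [-3, 0, 2049, 65504]
print(relu_reference(data))        # [0, 0, 2049, 65504]
print(relu_through_float16(data))  # [0.0, 0.0, 2048.0, 65504.0]
```

This is why the vmax path, which keeps the original int32 or float32 data type, is preferred for those inputs.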
More tips for operator compute function implementation:
- If the data type of the input tensor is not float32, you can cast it to float32 for computation to improve the accuracy of the intermediate computation results. When the final result is output, the data type needs to be cast back to the original data type.
- When the compute process of the operator is complex, extract parts of it into internal helper functions to keep each module simple and readable.
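The benefit of computing intermediate results in float32 can be illustrated on the host. The sketch below sums 4,096 ones while rounding the running total to float16 or float32 after each addition, emulating each precision with struct's "e" and "f" formats. It is a plain-Python model of rounding behavior under that assumption, not TBE code; round_to and accumulate are hypothetical helpers.

```python
import struct

def round_to(fmt, value):
    """Round value to the precision of a struct float format: 'e' (float16) or 'f' (float32)."""
    return struct.unpack("<" + fmt, struct.pack("<" + fmt, value))[0]

def accumulate(values, fmt):
    """Sum values, rounding the running total to the given precision after every step."""
    total = 0.0
    for v in values:
        total = round_to(fmt, total + v)
    return total

ones = [1.0] * 4096
print(accumulate(ones, "e"))  # 2048.0: 2048 + 1 rounds back to 2048 in float16
print(accumulate(ones, "f"))  # 4096.0
```

Once the float16 total reaches 2048, adding 1 no longer changes it, whereas float32 stays exact for integers up to 2**24.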
Broadcast Operators
Broadcast operators take multiple inputs whose shapes may differ, while DSL binary compute APIs such as vmax require operands of the same shape, so broadcast operations are needed. When implementing the compute function of such an operator, call the broadcast_shapes or unify_broadcast_shapes API in tbe.common.utils.shape_util to compute the output shape, and then call the tbe.dsl.broadcast API to broadcast each related input to that output shape. Otherwise, an error similar to the following may occur:
'Compile operator failed, cause: Parameters check failed, detailed information: The lhs shape[(dim 0 0, dim 1 0)] must be equal to the rhs[(dim 0 0, dim 1 1)].'
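The output-shape computation that precedes tbe.dsl.broadcast follows the usual broadcasting rule: a dimension of 1 can be expanded, and each output dimension is the larger of the corresponding input dimensions. The following plain-Python sketch implements NumPy-style broadcasting to make the rule concrete. broadcast_output_shape is hypothetical, and the exact behavior of shape_util.broadcast_shapes and unify_broadcast_shapes (return values, rank padding) may differ, so consult their API reference.

```python
def broadcast_output_shape(*shapes):
    """Hypothetical helper: compute a NumPy-style broadcast shape.

    Shapes are right-aligned and padded with leading 1s; in each position the
    dimensions must be equal or 1, and the output takes the non-1 value.
    """
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)

print(broadcast_output_shape((2, 1), (1, 3)))   # (2, 3)
print(broadcast_output_shape((4, 2, 2), (2,)))  # (4, 2, 2)
```

Each input is then broadcast to this output shape before being passed to a binary DSL API.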
vcmp and vsel Usage Restrictions
With certain shape configurations, the outputs of the tbe.dsl.vcmp and tbe.dsl.vsel APIs are inaccurate. The cause is that the mode argument of vcmp defaults to bool, which stores data in 8-bit mode, whereas the bool_storage_as_1bit parameter in the build configuration of the operator definition function defaults to True when not specified, which stores data in 1-bit mode. The two storage bit widths conflict. Therefore, add "bool_storage_as_1bit": False to the build configuration as follows.
with tvm.target.cce():
    schedule = dsl.auto_schedule(res)
    config = {"name": kernel_name,
              "tensor_list": [data_x, data_y, res],
              "bool_storage_as_1bit": False}
    dsl.build(schedule, config)
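The conflict comes down to buffer size and layout. The following plain-Python sketch packs the same 16 comparison results both ways: 8-bit storage needs 16 bytes, while 1-bit storage needs only 2, so reading a buffer in one layout after writing it in the other misinterprets the data. Both helpers are hypothetical, and the least-significant-bit-first order in pack_bools_1bit is an arbitrary choice for illustration, not the device layout.

```python
def pack_bools_8bit(flags):
    """One byte per boolean, as in 8-bit bool storage."""
    return bytes(1 if f else 0 for f in flags)

def pack_bools_1bit(flags):
    """One bit per boolean, least significant bit first (illustrative bit order)."""
    packed = bytearray((len(flags) + 7) // 8)
    for i, f in enumerate(flags):
        if f:
            packed[i // 8] |= 1 << (i % 8)
    return bytes(packed)

flags = [True, False, True, True] * 4  # 16 comparison results
print(len(pack_bools_8bit(flags)))  # 16
print(len(pack_bools_1bit(flags)))  # 2
```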