DSL Operator Generalization
Operator generalization means making an operator support any valid data type, any valid data shape, and even multiple versions of the Ascend AI Processor.
The basic principle is to obtain the data type and shape from the operator inputs, perform basic verification on them, and then apply whatever specific processing the operator requires.
The following uses a Less operator as an example to describe how to generalize an operator in a relatively complex scenario.
Operator Analysis
- Analyze the algorithm principle of the Less operator.
The Less operator compares two input tensors (A and B) element-wise. If the ith element of tensor A is smaller than the ith element of tensor B, the operator returns 1 as the corresponding element of the result tensor. Otherwise, it returns 0.
See the following example:
Input tensor A: [1, 2, 3, 4, 5]
Input tensor B: [5, 4, 3, 2, 1]
Result tensor C: [1, 1, 0, 0, 0]
Domain: (–∞, +∞). Range: {0, 1}
- Analyze the implementation of the compute logic.
For clarity, the following uses float16 data with a shape of (0,) as an example.
Compute the difference C = B – A:
- If C > 0, the final result is 1.
- If C <= 0, the final result is 0.
Currently, no API can compare the tensor C with 0 directly. To implement the compute logic, perform the following steps:
- Compare C with the minimum positive value of float16, 2^(-24), select the smaller value, and name it D.
  - If C > 0, D = 2^(-24).
  - If C <= 0, D = C.
- Compare D with 0, select the larger value, and name it E.
  - If D > 0, E = D = 2^(-24).
  - If D <= 0, E = 0.
- Based on the preceding analysis, the intermediate result is as follows:
  - If B > A, E is the minimum positive value, 2^(-24).
  - If B <= A, E = 0.
At this point, E already indicates whether B > A. To turn a nonzero E into 1, multiply E by 2^24, the reciprocal of the minimum positive value (referred to in this document as the maximum positive number).
For float16, the multiplication is as follows:
2^(-24) * 2^24 = 2^(-24+24) = 2^0 = 1
If E is 0, the product remains 0. The result is therefore 1 when B > A and 0 otherwise, which covers all scenarios.
- Analyze the compute API.
- vsub: returns the subtraction of two inputs element-wise.
- vmin: returns the minimum values of two inputs element-wise.
- vmax: returns the maximum values of two inputs element-wise.
- vmuls: returns the multiplication of a vector and a scalar.
- broadcast: broadcasts a scalar or tensor to a target shape.
- cast_to: changes the data type of a vector.
With these APIs, you can implement the compute logic of the Less operator.
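The subtract, min, max, and multiply steps from the compute logic analysis can be sketched in plain Python. This is an illustration only: it uses no TBE APIs, the names `MIN_POS`, `SCALE`, and `less_sketch` are hypothetical, and it runs on Python double-precision floats, so it matches the float16 behavior only when each difference is either zero or at least 2^(-24) in magnitude.

```python
# Plain-Python sketch of the Less compute logic (no TBE APIs are used).
MIN_POS = 2.0 ** -24   # minimum positive float16 value used in the analysis
SCALE = 2.0 ** 24      # reciprocal of MIN_POS


def less_sketch(a, b):
    """Return 1.0 where a[i] < b[i], else 0.0, using only sub/min/max/mul."""
    result = []
    for x, y in zip(a, b):
        c = y - x                 # vsub: C = B - A
        d = min(c, MIN_POS)       # vmin: D = MIN_POS if C > 0, else C
        e = max(d, 0.0)           # vmax: E = MIN_POS if C > 0, else 0
        result.append(e * SCALE)  # vmuls: MIN_POS * 2**24 = 1
    return result


tensor_a = [1.0, 2.0, 3.0, 4.0, 5.0]
tensor_b = [5.0, 4.0, 3.0, 2.0, 1.0]
print(less_sketch(tensor_a, tensor_b))  # [1.0, 1.0, 0.0, 0.0, 0.0]
```

The output agrees with the example tensors A, B, and C given in the algorithm analysis.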
Compute Implementation
- Main function of the operator definition
- Compute function
- Comparison function
- Implement the main function of the operator definition. The complete code is as follows:
def less(input_x, input_y, output_z, kernel_name="less"):
    # Verify the shapes.
    shape_x = shape_util.scalar2tensor_one(input_x.get("shape"))
    shape_y = shape_util.scalar2tensor_one(input_y.get("shape"))
    para_check.check_shape(shape_x, param_name="input_x")
    para_check.check_shape(shape_y, param_name="input_y")

    # Verify the data type.
    check_list = ("float16", "float32", "int32", "int8", "uint8")
    input_dtype = input_x.get("dtype").lower()
    para_check.check_dtype(input_dtype, check_list, param_name="input_x")

    # Broadcast the shapes and fuse axes.
    shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y,
                                                              param_name_input1="input_x",
                                                              param_name_input2="input_y")
    shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y)

    # Insert placeholders for the input tensors.
    data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x")
    data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y")

    # Describe the computation.
    res = less_compute(data_x, data_y, output_z, kernel_name="less")

    # Schedule and build.
    with tvm.target.cce():
        sch = tbe.auto_schedule(res)

    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [data_x, data_y, res]}
    tbe.cce_build_code(sch, config)
- In the declaration of the operator definition function, the operator arguments include two inputs and one output, all of which are tensor metadata of the dictionary type.
def less(input_x, input_y, output_z, kernel_name="less"):
Take input_x as an example. This dictionary variable contains the shape, dtype, format, ori_shape, and ori_format information of the input tensor x.
- Verify the arguments.
- Check the shape.
shape_x = shape_util.scalar2tensor_one(input_x.get("shape"))
shape_y = shape_util.scalar2tensor_one(input_y.get("shape"))
para_check.check_shape(shape_x, param_name="input_x")
para_check.check_shape(shape_y, param_name="input_y")
If the input is a scalar, its shape information cannot be represented as a tensor shape and is treated as 0. In this case, scalar2tensor_one sets the shape to [1]. If the input is a tensor, the API returns the shape unchanged.
- Check the data type.
check_list = ("float16", "float32", "int32", "int8", "uint8")
input_dtype = input_x.get("dtype").lower()
para_check.check_dtype(input_dtype, check_list, param_name="input_x")
The supported data types include float16, float32, int32, int8, and uint8.
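The two verification steps can be mimicked with a few lines of plain Python. The helpers below are hypothetical stand-ins written for illustration. They are not the TBE shape_util or para_check implementations, and they assume that a scalar input carries an empty shape and that a failed check raises ValueError.

```python
# Hypothetical stand-ins that mimic the behavior described above;
# they are not the TBE shape_util/para_check implementations.
SUPPORTED_DTYPES = ("float16", "float32", "int32", "int8", "uint8")


def scalar_to_one(shape):
    """Return [1] for a scalar (assumed to have an empty shape); otherwise return the shape unchanged."""
    if isinstance(shape, (list, tuple)) and len(shape) == 0:
        return [1]
    return shape


def check_dtype_sketch(dtype, supported, param_name):
    """Raise ValueError if the lowercased dtype is not in the supported list."""
    if dtype.lower() not in supported:
        raise ValueError(f"{param_name}: dtype {dtype!r} is not in {supported}")


input_x = {"shape": (), "dtype": "Float16"}  # metadata of a scalar input
shape_x = scalar_to_one(input_x.get("shape"))
check_dtype_sketch(input_x.get("dtype"), SUPPORTED_DTYPES, "input_x")
print(shape_x)  # [1]
```

Failing early in this way keeps an unsupported data type from reaching the compute function.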
- Perform the reshape operation.
shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y,
                                                          param_name_input1="input_x",
                                                          param_name_input2="input_y")
shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y)
The analysis of the compute logic did not consider the shapes of the two tensors. Before comparing two tensors element-wise, ensure that they have the same shape or can be broadcast to a common shape. TBE provides a common API, broadcast_shapes, for this purpose. It works as follows:
- Check whether the two inputs are broadcastable. If they are not, an error is reported and no operation is performed.
- Compute the target shape (shape_max), which is composed of the larger value of each axis.
- If the two input shapes have different numbers of dimensions, align them from the trailing dimension and pad the lower-rank shape with leading 1s. The padded shapes of both inputs are returned.
Then, the refine_shapes_for_broadcast API fuses adjacent axes that are broadcast in the same direction, as well as adjacent axes that are not broadcast. This improves performance without changing the compute logic.
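The broadcast and fusion steps above can be sketched in plain Python. Both functions are hypothetical re-implementations for illustration. The real shape_util.broadcast_shapes and shape_util.refine_shapes_for_broadcast may differ in details such as error types, axis order, and the handling of size-1 axes.

```python
# Hypothetical re-implementations for illustration only.
def broadcast_shapes_sketch(shape_x, shape_y):
    """Return (padded_x, padded_y, shape_max) following the three steps above."""
    shape_x, shape_y = list(shape_x), list(shape_y)
    rank = max(len(shape_x), len(shape_y))
    # Pad the lower-rank shape with leading 1s so that trailing axes line up.
    shape_x = [1] * (rank - len(shape_x)) + shape_x
    shape_y = [1] * (rank - len(shape_y)) + shape_y
    shape_max = []
    for axis, (dim_x, dim_y) in enumerate(zip(shape_x, shape_y)):
        # Two axes are compatible if they are equal or one of them is 1.
        if dim_x != dim_y and 1 not in (dim_x, dim_y):
            raise ValueError(f"axis {axis}: {dim_x} and {dim_y} cannot be broadcast")
        shape_max.append(max(dim_x, dim_y))
    return shape_x, shape_y, shape_max


def fuse_axes_sketch(shape_x, shape_y):
    """Fuse adjacent axes that share the same broadcast direction."""
    def direction(dim_x, dim_y):
        if dim_x == dim_y:
            return "none"                    # axis is not broadcast
        return "x" if dim_x == 1 else "y"    # which input is broadcast

    fused_x, fused_y, previous = [], [], None
    for dim_x, dim_y in zip(shape_x, shape_y):
        current = direction(dim_x, dim_y)
        if current == previous:
            fused_x[-1] *= dim_x
            fused_y[-1] *= dim_y
        else:
            fused_x.append(dim_x)
            fused_y.append(dim_y)
        previous = current
    return fused_x, fused_y


print(broadcast_shapes_sketch((3, 1), (2, 1, 4)))
# ([1, 3, 1], [2, 1, 4], [2, 3, 4])
print(fuse_axes_sketch([2, 3, 1, 1], [2, 3, 4, 5]))
# ([6, 1], [6, 20])
```

In the second call, the two leading axes are not broadcast and the two trailing axes broadcast the first input, so four axes collapse into two while the element count of each tensor stays the same.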
- After the basic information verification and reshape operations, you can insert placeholders for the input tensors.
data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x")
data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y")
- After the tensor object is determined, you can perform computation on the input tensors.
res = less_compute(data_x, data_y, output_z, kernel_name="less")
For details about the less_compute function, see "Implement the compute function of the Less operator."
- Perform scheduling and building.
with tvm.target.cce():
    sch = tbe.auto_schedule(res)

config = {"print_ir": False,
          "name": kernel_name,
          "tensor_list": [data_x, data_y, res]}
tbe.cce_build_code(sch, config)
- Implement the compute function of the Less operator.
The code in totality is as follows:
def less_compute(input_x, input_y, output_z, kernel_name="less"):
    # Organize shape information.
    shape_x = shape_util.shape_to_list(input_x.shape)
    shape_y = shape_util.shape_to_list(input_y.shape)
    shape_x, shape_y, shape = shape_util.broadcast_shapes(shape_x, shape_y,
                                                          param_name_input1="input_x",
                                                          param_name_input2="input_y")
    # Obtain the AI Processor model.
    soc_v = get_soc_spec("SOC_VERSION")
    # Obtain the data type.
    dtype = input_x.dtype
    # Compute the minimum tensor based on different data types and AI Processor models.
    if dtype in ("uint8", "int8"):
        input_x = dsl.cast_to(input_x, "float16")
        input_y = dsl.cast_to(input_y, "float16")
        dtype = "float16"

    if dtype == "float32":
        # minimum num of float32 2**(-126)
        data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
    elif dtype == "float16" and soc_v not in "Ascend910":
        # minimum num of float16 2**(-24)
        data_min = dsl.broadcast(tvm.const(2**(-24), dtype=dtype), shape, dtype)
    elif dtype == "float16" and soc_v in "Ascend910":
        input_x = dsl.cast_to(input_x, "float32")
        input_y = dsl.cast_to(input_y, "float32")
        dtype = "float32"
        data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
    elif dtype == "int32" and soc_v not in "Ascend910":
        data_min = dsl.broadcast(tvm.const(1, dtype=dtype), shape, dtype)
    else:
        input_x = dsl.cast_to(input_x, "float32")
        input_y = dsl.cast_to(input_y, "float32")
        dtype = "float32"
        data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
    # Broadcast the tensor.
    input_x = dsl.broadcast(input_x, shape)
    input_y = dsl.broadcast(input_y, shape)
    # Perform the comparison.
    return _less_compare((input_x, input_y), shape, dtype, data_min)
- Compute the result shape according to the broadcast rule for the subsequent computation. Alternatively, the shape obtained from the implementation of the main function can be directly passed to the compute function as an argument.
- Obtain the current AI Processor model because the data processing mode varies depending on the AI Processor model.
- Obtain the input data type for subsequent computation.
- Compute the minimum tensor based on different data types and AI Processor models.
The shape of the minimum tensor is the result shape computed in the first step of the compute function. Every element of this tensor has the same value, namely the minimum positive value of the corresponding data type: 1 for the integer type, 2^(-24) for float16, and 2^(-126) for float32.
The given example adapts to different AI Processor models. In practice, you only need to develop the operator for the AI Processor that you use.
- Then, broadcast the two input tensors to the result shape.
At this point, the minimum tensor is created, and the two input tensors are broadcast into compatible shapes. The next step is the comparison implementation.
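The data type and AI Processor dispatch in less_compute can be summarized as a small decision function. The following is a sketch with a hypothetical name, `select_compute_dtype`, that returns only the compute data type and the minimum value instead of building tensors. As a simplification, it compares the version string with `==`, whereas the sample code uses a substring test (`in`).

```python
def select_compute_dtype(dtype, soc_version):
    """Return (compute_dtype, minimum_value), mirroring the branches of less_compute."""
    if dtype in ("uint8", "int8"):
        dtype = "float16"                 # small integers are first cast to float16
    on_910 = soc_version == "Ascend910"   # simplification of the substring test
    if dtype == "float32":
        return "float32", 2.0 ** -126
    if dtype == "float16" and not on_910:
        return "float16", 2.0 ** -24
    if dtype == "int32" and not on_910:
        return "int32", 1
    # float16 or int32 on Ascend910: compute in float32.
    return "float32", 2.0 ** -126


for name in ("int8", "float16", "float32", "int32"):
    print(name,
          select_compute_dtype(name, "Ascend310")[0],
          select_compute_dtype(name, "Ascend910")[0])
# int8 float16 float32
# float16 float16 float32
# float32 float32 float32
# int32 int32 float32
```

The version string "Ascend310" is used here only as an example of a model other than Ascend910.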
- Implement the comparison operation.
def _less_compare(data, shape, dtype, data_min):
    # Define the zero tensor.
    data_zero = dsl.broadcast(tvm.const(0, dtype), shape, dtype)
    # Perform computation and comparison.
    res_sub = dsl.vsub(data[1], data[0])
    res_min = dsl.vmin(res_sub, data_min)
    res_max = dsl.vmax(res_min, data_zero)
    # Multiply res_max by the maximum value of the corresponding data type to obtain 1.
    if dtype == "float32":
        # max num of float32 is 2**126
        # but cce can only support 2**62, so use 62/62/2 to adapt to 126
        res_mul1 = dsl.vmuls(res_max, tvm.const(2**62, dtype=dtype))
        res_mul2 = dsl.vmuls(res_mul1, tvm.const(2**62, dtype=dtype))
        res = dsl.vmuls(res_mul2, tvm.const(2**2, dtype=dtype))
    elif dtype == "float16":
        # max num of float16 is 2**24
        # but cce can only support 2**12, so use 12/12 to adapt to 24
        res_mul1 = dsl.vmuls(res_max, tvm.const(2**12, dtype=dtype))
        res = dsl.vmuls(res_mul1, tvm.const(2**12, dtype=dtype))
    else:
        res = dsl.cast_to(res_max, "float16")
    # Cast the data type to minimize memory occupation.
    return dsl.cast_to(res, "uint8", True)
- According to the operator analysis, the tensor D needs to be compared with 0. Therefore, define a zero tensor, that is, a tensor with all elements set to 0 and with the same shape as the input tensors.
- Perform the computation and comparison.
res_sub = dsl.vsub(data[1], data[0])
res_min = dsl.vmin(res_sub, data_min)
res_max = dsl.vmax(res_min, data_zero)
- res_sub is data[1] minus data[0], element-wise.
- res_min is the element-wise minimum of res_sub and the minimum tensor.
- res_max is the element-wise maximum of res_min and the zero tensor. If an element of data[1] is greater than the element at the same position of data[0], the corresponding element of res_max is the minimum value of the data type. Otherwise, it is 0.
- If an element of res_max is not 0, multiplying it by the maximum value of its data type yields 1. In the final else branch, only the data type is cast: for the integer type, the minimum value is 1, so res_max already holds 0 or 1 and does not need to be multiplied.
- To save the space occupied by the result tensor, cast the result tensor to the uint8 type and return it.
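The float16 branch of the comparison can be emulated with the Python standard library, which rounds to IEEE 754 half precision through the struct format character "e". The sketch below uses hypothetical helper names and handles a single pair of values rather than tensors. It also shows one plausible reason for scaling in two steps: 2**24 exceeds the largest finite float16 value (65504), while 2**12 fits. The source itself only states that cce supports at most 2**12.

```python
import struct


def to_float16(value):
    """Round a Python float to the nearest float16 value (IEEE 754 binary16)."""
    return struct.unpack("e", struct.pack("e", value))[0]


def less_float16_sketch(x, y):
    """Emulate the float16 branch of _less_compare for one pair; returns 0 or 1."""
    data_min = to_float16(2.0 ** -24)
    res_sub = to_float16(to_float16(y) - to_float16(x))  # vsub
    res_min = min(res_sub, data_min)                     # vmin
    res_max = max(res_min, 0.0)                          # vmax
    # 2**24 is not representable in float16, so scale by 2**12 twice.
    res_mul1 = to_float16(res_max * 2.0 ** 12)           # vmuls
    res = to_float16(res_mul1 * 2.0 ** 12)               # vmuls
    return int(res)  # stands in for the final cast to uint8


try:
    to_float16(2.0 ** 24)
except OverflowError:
    print("2**24 does not fit in float16")

pairs = [(1.0, 5.0), (3.0, 3.0), (5.0, 1.0)]
print([less_float16_sketch(x, y) for x, y in pairs])  # [1, 0, 0]
```

Each intermediate product (2^(-12), then 1) is exactly representable in float16, so the two-step scaling loses no precision.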
Wrap-up
Note the following points for operator generalization:
- Verify the arguments to expose problems as soon as possible.
- Check that two input tensors of an operator have compatible shapes if the operator performs element-wise operation.
- Some data types and AI Processor models require specific processing. However, if your operator only needs to run on the AI Processor that you use, for example only for training or only for inference, no extra adaptation is needed.
- Use high precision for intermediate computation where possible. In some cases, balance precision against performance.
- Before returning the result, minimize memory usage, for example by casting the result to a smaller data type.