DSL算子泛化
如果开发者想支持“一类” 算子,要能适合任何合法的数据类型、数据形状,甚至适合多种昇腾AI处理器型号,这种场景,称之为算子的泛化。
算子泛化的基本思想就是:算子的数据类型和shape都从算子的输入中获取,然后进行基本的校验,然后根据算子的特殊要求,进行针对性的处理。
下面以一个相对复杂的Less为例,讲解如何进行算子的泛化。
算子分析
- 首先分析Less算子的算法原理。
Less算子的作用是对两个输入张量(A,B)进行逐元素比较大小,如果张量A的第i个元素比张量B对应位置的元素小,则结果张量的对应位置取“1”,否则结果张量的对应位置取“0”。
例如:
输入张量A:[1,2,3,4,5] 输入张量B:[5,4,3,2,1] 结果张量C:[1,1,0,0,0] 定义域:(-∞, +∞),值域:{0,1}
- 运算逻辑实现分析。
为简化分析,用shape为(0,)、类型为FP16的单个数来进行分析。
计算C=B-A的结果:- 如果C>0,最终结果取1。
- 如果C<=0,最终结果取0。
由于当前未能实现将C和数据0进行比较的接口,为实现以上计算逻辑,解决方法如下:- 将C与FP16的最小正值(即2-24)做比较,结果取较小的一个,命名为D。
- 如果C > 0,结果D = 2-24。
- 如果C <= 0,结果D = C。
- 将D与0进行比较,结果取较大的一个,命名为E。
- 如果D > 0,结果E = D,即E取值为2-24。
- 如果D <= 0,结果 E = 0。
- 通过以上两个步骤的分析,最终结果如下:
- 如果B > A,结果E的取值为最小正数。
- 如果B <= A,结果E的取值为0。
到此为止,已经可以正确获取B <= A的结果了,下一步还需要将D取最小正数时,将其变为1,方法为:将其与最大正数相乘。
例如FP16的最大正数为224,将FP16的最小正数与最大正数相乘,如下所示:
2-24 * 224 = 2(-24+24)= 20 = 1
结果为1,至此,已经能够获得所有场景的正确结果。
- 计算接口分析。
- 两数按元素相减,用vsub实现。
- 两数取小值,用vmin实现。
- 两数取大值,用vmax实现。
- 向量与标量相乘,用vmuls实现。
- 操作向量形状和数据类型,分别用broadcast接口和cast_to接口实现。
下面就可以进行Less算子的计算实现了。
计算实现
- 算子定义主函数
- 计算函数
- 比较函数
- 算子定义主函数的实现。首先看下算子定义主函数的完整代码。
def less(input_x, input_y, output_z, kernel_name="less"): shape_x = shape_util.scalar2tensor_one(input_x.get("shape")) shape_y = shape_util.scalar2tensor_one(input_y.get("shape")) para_check.check_shape(shape_x, param_name="input_x") para_check.check_shape(shape_y, param_name="input_y") check_list = ("float16", "float32", "int32", "int8", "uint8") input_dtype = input_x.get("dtype").lower() para_check.check_dtype(input_dtype, check_list, param_name="input_x") shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y") shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y) data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x") data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y") res = less_compute(data_x, data_y, output_z, kernel_name="less") with tvm.target.cce(): sch = tbe.auto_schedule(res) config = {"print_ir": False, "name": kernel_name, "tensor_list": [data_x, data_y, res]} tbe.cce_build_code(sch, config)
- 算子定义函数声明,算子入参包含两个输入、一个输出,输入输出都是张量的元数据字典。
def less(input_x, input_y, output_z, kernel_name="less"):
以“input_x”为例,这个字典变量包含输入张量x的“shape”,“dtype”,“format”,“、ori_shape”,“ori_format”信息。
- 对入参进行校验。
- 检查形状。
shape_x = shape_util.scalar2tensor_one(input_x.get("shape")) shape_y = shape_util.scalar2tensor_one(input_y.get("shape")) para_check.check_shape(shape_x, param_name="input_x") para_check.check_shape(shape_y, param_name="input_y")
如果输入是标量,输入的shape信息是“0”,此时无法用一个张量来表示形状的信息,所以需要调用scalar2tensor_one把shape设置为[1]。如果输入本身就是张量,此接口不会做任何操作,直接将入参返回。
- 检查数据类型。
check_list = ("float16", "float32", "int32", "int8", "uint8") input_dtype = input_x.get("dtype").lower() para_check.check_dtype(input_dtype, check_list, param_name="input_x")
此处支持float16、float32、int32、int8、uint8几种数据类型。
- 整理形状。
shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y") shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y)
在前面分析计算逻辑的时候,我们忽略了两个张量的形状信息,实际调用算子的时候,两个输入张量的形状不一定一样,此时想要两个张量逐元素比较的话,必须先保证两个张量能够广播成完成同样的形状。TBE提供了通用shape整理接口broadcast_shapes,功能是对两个shape做广播计算,详细功能如下:- 检查2个输入shape能否广播,如果不能,直接报错退出。
- 计算两个输入shape广播后的形状,即每根轴的最大值组成的shape信息,返回值为“shape_max”。
- 如果两个输入shape的维度数量都不同,那么对维度较小的shape进行高维补1的操作,并将补维之后的shape1、shape2也作为返回值返回。
之后的refine_shapes_for_broadcast接口是对连续同方向广播轴和连续非广播轴做融合操作,目的是提升性能,和计算逻辑关系不大。
- 检查形状。
- 进行完基本信息校验和形状整理后,就可以对输入张量进行占位。
data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x") data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y")
- 有了张量对象,就可以针对输入张量进行计算操作了。
res = less_compute(data_x, data_y, output_z, kernel_name="less")
less_compute计算函数,我们将在2进行详细讲解。
- 调度与编译。
with tvm.target.cce(): sch = tbe.auto_schedule(res) config = {"print_ir": False, "name": kernel_name, "tensor_list": [data_x, data_y, res]} tbe.cce_build_code(sch, config)
- 算子定义函数声明,算子入参包含两个输入、一个输出,输入输出都是张量的元数据字典。
- Less算子计算函数实现。
Less算子计算函数的实现代码如下:
def less_compute(input_x, input_y, output_z, kernel_name="less"): # 整理形状信息 shape_x = shape_util.shape_to_list(input_x.shape) shape_y = shape_util.shape_to_list(input_y.shape) shape_x, shape_y, shape = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y") # 获取AI处理器型号 soc_v = get_soc_spec("SOC_VERSION") # 获取数据类型 dtype = input_x.dtype # 针对不同的数据类型和AI处理器型号,分别计算最小值张量 if dtype in ("uint8", "int8"): input_x = dsl.cast_to(input_x, "float16") input_y = dsl.cast_to(input_y, "float16") dtype = "float16" if dtype == "float32": # minimum num of float32 2**(-126) data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype) elif dtype == "float16" and soc_v not in "Ascend910": # minimum num of float16 2**(-24) data_min = dsl.broadcast(tvm.const(2**(-24), dtype=dtype), shape, dtype) elif dtype == "float16" and soc_v in "Ascend910": input_x = dsl.cast_to(input_x, "float32") input_y = dsl.cast_to(input_y, "float32") dtype = "float32" data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype) elif dtype == "int32" and soc_v not in "Ascend910": data_min = dsl.broadcast(tvm.const(1, dtype=dtype), shape, dtype) else: input_x = dsl.cast_to(input_x, "float32") input_y = dsl.cast_to(input_y, "float32") dtype = "float32" data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype) # 对张量进行广播操作 input_x = dsl.broadcast(input_x, shape) input_y = dsl.broadcast(input_y, shape) # 进行比较计算 return _less_compare((input_x, input_y), shape, dtype, data_min)
- 首先对形状进行广播,主要是为了获得广播后的最终形状,为后续的获取最小值张量做准备。其实主函数中已经进行过形状的计算了,所以也可以直接将广播后的形状作为入参传给此计算函数。
- 因为不同的AI处理器型号上数据处理的方式不同,所以需要获取当前AI处理器类型。
- 获取输入的数据类型,为后续计算最小值张量做准备。
- 根据不同的数据类型和AI处理器型号,分别计算最小值张量。
最小值张量的形状,就是2.a计算的广播终点形状,此张量中每一个数值都是相同的,即对应数据类型的最小值,这个“最小值”,在整型情况下为“1”,在FP16的情况下是“2-24”,在FP32的情况下是“2-126”。
此处针对不同的AI处理器类型都做了适配,实际算子开发的时候,若只想适配自己的网络,仅开发使用的AI处理器版本的算子即可。
- 最后对两个输入张量进行真正的广播,前面的广播操作都是计算输入张量的“形状”应该广播成什么样例,并没有实际对张量进行广播。
至此,已经创建了最小值张量,也将两个输入张量广播成了形同的形状,下面就到比较计算了。
- 比较计算实现。
def _less_compare(data, shape, dtype, data_min): # 定义“零张量” data_zero = dsl.broadcast(tvm.const(0, dtype), shape, dtype) # 进行计算和比较操作 res_sub = dsl.vsub(data[1], data[0]) res_min = dsl.vmin(res_sub, data_min) res_max = dsl.vmax(res_min, data_zero) # 将res_max乘以数据类型的最大值,得到1 if dtype == "float32": # max num of float32 is 2**126 # but cce can only support 2**62, so use 62/62/2 to adaptor 126 res_mul1 = dsl.vmuls(res_max, tvm.const(2**62, dtype=dtype)) res_mul2 = dsl.vmuls(res_mul1, tvm.const(2**62, dtype=dtype)) res = dsl.vmuls(res_mul2, tvm.const(2**2, dtype=dtype)) elif dtype == "float16": # max num of float16 is 2**24 # but cce can only support 2**12, so use 12/12 to adaptor 24 res_mul1 = dsl.vmuls(res_max, tvm.const(2**12, dtype=dtype)) res = dsl.vmuls(res_mul1, tvm.const(2**12, dtype=dtype)) else: res = dsl.cast_to(res_max, "float16") # 为节省空间,进行数据类型转换 return dsl.cast_to(res, "uint8", True)
- 根据算子分析,结果D需要与0进行比较,所以我们还需要定义一个“零张量”,即匹配输入张量的形状,但每个元素都是0的张量。
- 计算与比较操作。
res_sub = dsl.vsub(data[1], data[0]) res_min = dsl.vmin(res_sub, data_min) res_max = dsl.vmax(res_min, data_zero)
- res_sub为data1与data0相减的值。
- res_min为res_sub与最小值张量的比较结果。
- res_max为取res_min与零张量的大值,即如果对应位置data1比data0大,则取“本数据类型的最小值”;如果对应位置data1不比data0大,则取“0”。
- 在res_max不为0的情况下,将res_max乘以其数据类型的最大值,从而得到1。最后一个else分支,仅做了数据类型的转换,因为数据在整形的情况下,最小值已经是1了,所以不需要再跟整形的最大值相乘。
- 最后为了节省结果张量的占用空间,将其转换为uint8类型,并返回。
总结
算子的泛化,主要考虑以下几个点:
- 对入参进行合理校验,将问题尽早暴露。
- 对于双输入的Element-wise类算子,要考虑到两个输入张量形状不同的情况。
- 对于不同的数据类型、不同的AI处理器型号,有时会有一些特殊的处理。不过如果开发者编写的算子只面向训练/推理中的一种场景的话,则不必过度设计。
- 中间计算时,从精度方面考虑,尽量用高精度进行计算。但不是绝对的,还要和性能进行权衡。
- 在给出最终结果时,考虑节省内存空间。