下载
EN
注册

DSL算子泛化

如果开发者想支持“一类” 算子,要能适合任何合法的数据类型、数据形状,甚至适合多种昇腾AI处理器型号,这种场景,称之为算子的泛化。

算子泛化的基本思想就是:算子的数据类型和shape都从算子的输入中获取,然后进行基本的校验,然后根据算子的特殊要求,进行针对性的处理。

下面以一个相对复杂的Less为例,讲解如何进行算子的泛化。

算子分析

  1. 首先分析Less算子的算法原理。

    Less算子的作用是对两个输入张量(A,B)进行逐元素比较大小,如果张量A的第i个元素比张量B对应位置的元素小,则结果张量的对应位置取“1”,否则结果张量的对应位置取“0”。

    例如:

    输入张量A:[1,2,3,4,5]
    输入张量B:[5,4,3,2,1]
    结果张量C:[1,1,0,0,0]
    定义域:(-∞, +∞),值域:{0,1}
  2. 运算逻辑实现分析。

    为简化分析,用shape为(0,)、类型为FP16的单个数来进行分析。

    计算C=B-A的结果:
    • 如果C>0,最终结果取1。
    • 如果C<=0,最终结果取0。
    由于当前未能实现将C和数据0进行比较的接口,为实现以上计算逻辑,解决方法如下:
    1. 将C与FP16的最小正值(即2-24)做比较,结果取较小的一个,命名为D。
      • 如果C > 0,结果D = 2-24
      • 如果C <= 0,结果D = C。
    2. 将D与0进行比较,结果取较大的一个,命名为E。
      • 如果D > 0,结果E = D,即E取值为2-24
      • 如果D <= 0,结果 E = 0。
    3. 通过以上两个步骤的分析,最终结果如下:
      • 如果B > A,结果E的取值为最小正数。
      • 如果B <= A,结果E的取值为0。

    到此为止,已经可以正确获取B <= A的结果了,下一步还需要将D取最小正数时,将其变为1,方法为:将其与最大正数相乘。

    例如FP16的最大正数为224,将FP16的最小正数与最大正数相乘,如下所示:

    2-24 * 224 = 2-24+24= 20 = 1

    结果为1,至此,已经能够获得所有场景的正确结果。

  3. 计算接口分析。
    • 两数按元素相减,用vsub实现。
    • 两数取小值,用vmin实现。
    • 两数取大值,用vmax实现。
    • 向量与标量相乘,用vmuls实现。
    • 操作向量形状和数据类型,分别用broadcast接口和cast_to接口实现。

下面就可以进行Less算子的计算实现了。

计算实现

将Less算子的实现分为如下3个函数:
  • 算子定义主函数
  • 计算函数
  • 比较函数
  1. 算子定义主函数的实现。
    首先看下算子定义主函数的完整代码。
    def less(input_x, input_y, output_z, kernel_name="less"):
        shape_x = shape_util.scalar2tensor_one(input_x.get("shape"))
        shape_y = shape_util.scalar2tensor_one(input_y.get("shape"))
        para_check.check_shape(shape_x, param_name="input_x")
        para_check.check_shape(shape_y, param_name="input_y")
    
        check_list = ("float16", "float32", "int32", "int8", "uint8")
        input_dtype = input_x.get("dtype").lower()
        para_check.check_dtype(input_dtype, check_list, param_name="input_x")
    
        shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y")
    
        shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y)                                                 
        data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x")
        data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y")
    
        res = less_compute(data_x, data_y, output_z, kernel_name="less")
        with tvm.target.cce():
            sch = tbe.auto_schedule(res)
    
        config = {"print_ir": False,
                  "name": kernel_name,
                  "tensor_list": [data_x, data_y, res]}
        tbe.cce_build_code(sch, config)
    1. 算子定义函数声明,算子入参包含两个输入、一个输出,输入输出都是张量的元数据字典。
      def less(input_x, input_y, output_z, kernel_name="less"):

      以“input_x”为例,这个字典变量包含输入张量x的“shape”,“dtype”,“format”,“、ori_shape”,“ori_format”信息。

    2. 对入参进行校验。
      • 检查形状。
        shape_x = shape_util.scalar2tensor_one(input_x.get("shape"))
        shape_y = shape_util.scalar2tensor_one(input_y.get("shape"))
        para_check.check_shape(shape_x, param_name="input_x")
        para_check.check_shape(shape_y, param_name="input_y")

        如果输入是标量,输入的shape信息是“0”,此时无法用一个张量来表示形状的信息,所以需要调用scalar2tensor_one把shape设置为[1]。如果输入本身就是张量,此接口不会做任何操作,直接将入参返回。

      • 检查数据类型。
        check_list = ("float16", "float32", "int32", "int8", "uint8")
        input_dtype = input_x.get("dtype").lower()
        para_check.check_dtype(input_dtype, check_list, param_name="input_x")

        此处支持float16、float32、int32、int8、uint8几种数据类型。

      • 整理形状。
        shape_x, shape_y, shape_max = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y")
        shape_x, shape_y = shape_util.refine_shapes_for_broadcast(shape_x, shape_y) 
        在前面分析计算逻辑的时候,我们忽略了两个张量的形状信息,实际调用算子的时候,两个输入张量的形状不一定一样,此时想要两个张量逐元素比较的话,必须先保证两个张量能够广播成完成同样的形状。TBE提供了通用shape整理接口broadcast_shapes,功能是对两个shape做广播计算,详细功能如下:
        • 检查2个输入shape能否广播,如果不能,直接报错退出。
        • 计算两个输入shape广播后的形状,即每根轴的最大值组成的shape信息,返回值为“shape_max”。
        • 如果两个输入shape的维度数量都不同,那么对维度较小的shape进行高维补1的操作,并将补维之后的shape1、shape2也作为返回值返回。

        之后的refine_shapes_for_broadcast接口是对连续同方向广播轴和连续非广播轴做融合操作,目的是提升性能,和计算逻辑关系不大。

    3. 进行完基本信息校验和形状整理后,就可以对输入张量进行占位。
      data_x = tvm.placeholder(shape_x, dtype=input_dtype, name="data_x")
      data_y = tvm.placeholder(shape_y, dtype=input_dtype, name="data_y")
    4. 有了张量对象,就可以针对输入张量进行计算操作了。
      res = less_compute(data_x, data_y, output_z, kernel_name="less")

      less_compute计算函数,我们将在2进行详细讲解。

    5. 调度与编译。
      with tvm.target.cce():
          sch = tbe.auto_schedule(res)
      
      config = {"print_ir": False,
                "name": kernel_name,
                "tensor_list": [data_x, data_y, res]}
      tbe.cce_build_code(sch, config)
  2. Less算子计算函数实现。

    Less算子计算函数的实现代码如下:

    def less_compute(input_x, input_y, output_z, kernel_name="less"):
        # 整理形状信息
        shape_x = shape_util.shape_to_list(input_x.shape)
        shape_y = shape_util.shape_to_list(input_y.shape)
        shape_x, shape_y, shape = shape_util.broadcast_shapes(shape_x, shape_y, param_name_input1="input_x", param_name_input2="input_y")
        # 获取AI处理器型号
        soc_v = get_soc_spec("SOC_VERSION")
        # 获取数据类型
        dtype = input_x.dtype
        # 针对不同的数据类型和AI处理器型号,分别计算最小值张量
        if dtype in ("uint8", "int8"):
            input_x = dsl.cast_to(input_x, "float16")
            input_y = dsl.cast_to(input_y, "float16")
            dtype = "float16"
        if dtype == "float32":
            # minimum num of float32 2**(-126)
            data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
        elif dtype == "float16" and soc_v not in "Ascend910":
            # minimum num of float16 2**(-24)
            data_min = dsl.broadcast(tvm.const(2**(-24), dtype=dtype), shape, dtype)
        elif dtype == "float16" and soc_v in "Ascend910":
            input_x = dsl.cast_to(input_x, "float32")
            input_y = dsl.cast_to(input_y, "float32")
            dtype = "float32"
            data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
        elif dtype == "int32" and soc_v not in "Ascend910":
            data_min = dsl.broadcast(tvm.const(1, dtype=dtype), shape, dtype)
        else:
            input_x = dsl.cast_to(input_x, "float32")
            input_y = dsl.cast_to(input_y, "float32")
            dtype = "float32"
            data_min = dsl.broadcast(tvm.const(2**(-126), dtype=dtype), shape, dtype)
        #  对张量进行广播操作
        input_x = dsl.broadcast(input_x, shape)
        input_y = dsl.broadcast(input_y, shape)
        # 进行比较计算
        return _less_compare((input_x, input_y), shape, dtype, data_min)
    1. 首先对形状进行广播,主要是为了获得广播后的最终形状,为后续的获取最小值张量做准备。其实主函数中已经进行过形状的计算了,所以也可以直接将广播后的形状作为入参传给此计算函数。
    2. 因为不同的AI处理器型号上数据处理的方式不同,所以需要获取当前AI处理器类型
    3. 获取输入的数据类型,为后续计算最小值张量做准备。
    4. 根据不同的数据类型和AI处理器型号,分别计算最小值张量。

      最小值张量的形状,就是2.a计算的广播终点形状,此张量中每一个数值都是相同的,即对应数据类型的最小值,这个“最小值”,在整型情况下为“1”,在FP16的情况下是“2-24”,在FP32的情况下是“2-126”。

      此处针对不同的AI处理器类型都做了适配,实际算子开发的时候,若只想适配自己的网络,仅开发使用的AI处理器版本的算子即可。

    5. 最后对两个输入张量进行真正的广播,前面的广播操作都是计算输入张量的“形状”应该广播成什么样例,并没有实际对张量进行广播。

    至此,已经创建了最小值张量,也将两个输入张量广播成了形同的形状,下面就到比较计算了。

  3. 比较计算实现。
    def _less_compare(data, shape, dtype, data_min):
        # 定义“零张量”
        data_zero = dsl.broadcast(tvm.const(0, dtype), shape, dtype)
        # 进行计算和比较操作
        res_sub = dsl.vsub(data[1], data[0])
        res_min = dsl.vmin(res_sub, data_min)
        res_max = dsl.vmax(res_min, data_zero)
        # 将res_max乘以数据类型的最大值,得到1
        if dtype == "float32":
            # max num of float32 is 2**126
            # but cce can only support 2**62, so use 62/62/2 to adaptor 126
            res_mul1 = dsl.vmuls(res_max, tvm.const(2**62, dtype=dtype))
            res_mul2 = dsl.vmuls(res_mul1, tvm.const(2**62, dtype=dtype))
            res = dsl.vmuls(res_mul2, tvm.const(2**2, dtype=dtype))
        elif dtype == "float16":
            # max num of float16 is 2**24
            # but cce can only support 2**12, so use 12/12 to adaptor 24
            res_mul1 = dsl.vmuls(res_max, tvm.const(2**12, dtype=dtype))
            res = dsl.vmuls(res_mul1, tvm.const(2**12, dtype=dtype))
        else:
            res = dsl.cast_to(res_max, "float16")
        # 为节省空间,进行数据类型转换
        return dsl.cast_to(res, "uint8", True)
    1. 根据算子分析,结果D需要与0进行比较,所以我们还需要定义一个“零张量”,即匹配输入张量的形状,但每个元素都是0的张量。
    2. 计算与比较操作。
          res_sub = dsl.vsub(data[1], data[0])
          res_min = dsl.vmin(res_sub, data_min)
          res_max = dsl.vmax(res_min, data_zero)
      • res_sub为data1与data0相减的值。
      • res_min为res_sub与最小值张量的比较结果。
      • res_max为取res_min与零张量的大值,即如果对应位置data1比data0大,则取“本数据类型的最小值”;如果对应位置data1不比data0大,则取“0”。
    3. 在res_max不为0的情况下,将res_max乘以其数据类型的最大值,从而得到1。最后一个else分支,仅做了数据类型的转换,因为数据在整形的情况下,最小值已经是1了,所以不需要再跟整形的最大值相乘。
    4. 最后为了节省结果张量的占用空间,将其转换为uint8类型,并返回。

总结

算子的泛化,主要考虑以下几个点:

  1. 对入参进行合理校验,将问题尽早暴露。
  2. 对于双输入的Element-wise类算子,要考虑到两个输入张量形状不同的情况。
  3. 对于不同的数据类型、不同的AI处理器型号,有时会有一些特殊的处理。不过如果开发者编写的算子只面向训练/推理中的一种场景的话,则不必过度设计。
  4. 中间计算时,从精度方面考虑,尽量用高精度进行计算。但不是绝对的,还要和性能进行权衡。
  5. 在给出最终结果时,考虑节省内存空间。