DSL功能调试

功能介绍

使用DSL方式开发TBE算子的用户只需要关注算法逻辑,直接调用TBE DSL提供的auto_schedule接口即可完成算子的自动调度。因此,算子如果运行出错,用户仅需要验证算法逻辑描述是否正确即可,如果确认算法描述逻辑正确,但算子运行依然出错,则可把错误归结为TBE内部错误,则您可在Ascend开源仓中通过issue进行问题反馈。

若算子运行出错,开发者可参见本章节描述的方法在CPU中进行DSL算子的功能调试,验证算子的算法逻辑是否正确。

使用方法

TBE DSL提供了在CPU上验证算子功能正确性的调试框架,方便开发者快速验证算子功能的正确性,具体流程如下:

图1 DSL功能调试流程
  1. 进入调试模式。

    调用tbe.common.testing.debug接口,并配合python with语句进入调试模式。

  2. 调用tbe.common.testing.get_ctx接口,选择CPU作为DSL算子的运行平台,并获取算子运行的上下文。
  3. 使用placeholder接口对输入Tensor进行占位。
  4. 进行DSL算子的计算逻辑实现。
  5. 中间Tensor数据验证。

    开发者可调用tbe.common.testing.print_tensor接口,将中间Tensor的数据存入文件;并可使用numpy定义输入golden数据,调用tbe.common.testing.assert_allclose接口进行中间Tensor的数据校验。

  6. 调用TVM的create_schedule接口,为算子创建调度实例对象。
  7. 调用tbe.common.testing.build接口,编译生成在CPU上运行的DSL算子。
  8. 调用tbe.common.testing.run接口执行算子。
  9. 验证输出数据的正确性。

    至此,DSL算子的调试代码编写完成。开发者可编写入口函数,调用算子接口,进行算子的功能调试。

调试示例

from tbe import tvm
from tbe import dsl
from tbe.common.utils import para_check
from tbe.common.utils import shape_util
# 引入testing模块相关接口
from tbe.common.testing.testing import *
import numpy as np

@para_check.check_input_type(dict, dict, dict, str)
def addtest(input_a, input_b, output_d, kernel_name="addtest"):
    # 进入DSL调试模式
    with debug(): 
        # 选择CPU作为DSL的运行平台
        ctx = get_ctx()
        
        # DSL算子实现的计算逻辑
        shape_a = shape_util.scalar2tensor_one(input_a.get("shape"))
        shape_b = shape_util.scalar2tensor_one(input_b.get("shape"))
        data_type = input_a.get("dtype").lower()
        # 调用TVM的placeholder接口对输入tensor进行占位,并返回一个tensor对象
        data_a = tvm.placeholder(shape_a, name="data_1", dtype=data_type)
        data_b = tvm.placeholder(shape_b, name="data_2", dtype=data_type)
        # 调用DSL计算接口实现data_a + data_b
        data_c = dsl.vadd(data_a, data_b)
	
        # 中间Tensor数据验证
        # 打印中间tensor data_c并存入文件samplefile.txt
        sample = open('samplefile.txt', 'w')
        print_tensor(data_c, ofile=sample)  # 支持任意tensor

        # 使用numpy定义输入golden数据大小
        a = tvm.nd.array(np.random.uniform(size=shape_a).astype(data_type), ctx)
        b = tvm.nd.array(np.random.uniform(size=shape_b).astype(data_type), ctx)

        # 检查中间tensor data_c的值是否正确
        assert_allclose(data_c, desired=a.asnumpy() + b.asnumpy(), tol=[1e-7, 1e-7])  # 支持任意tensor,tol=[rtol, atol]
			
					
	# 继续自定义DSL的逻辑撰写,调用DSL接口实现:data_d = data_c + data_b
        data_d = dsl.vadd(data_c, data_b)
        # 调用TVM的create_schedule接口,为算子创建调度实例对象,入参为输出tensor的OP列表。
        s = tvm.create_schedule(data_d.op)

        # 编译生成算子,data_a,data_b,data_d是占位的输入输出列表,AddTest是我们自定义算子的名称
        build(s, [data_a, data_b, data_d], name="AddTest")           

        # 使用numpy给输出初始化为全0
        d = tvm.nd.array(np.zeros(shape_a, dtype=data_type), ctx)

        # 执行算子,将a,b,d按顺序代入我们编译出来的DSL算子AddTest
        run(a, b, d)  # AddTest(a, b, d)

        # 将输出d的值打印出来,并预期结果进行比较,看是否相符
        print("d:", d.asnumpy)
        tvm.testing.assert_allclose(d.asnumpy(), a.asnumpy() + b.asnumpy() + b.asnumpy())
        print("The actual output is the same as the expected output. ")

# 编写入口函数,调用addtest函数
if __name__ == "__main__":
    input_output_dict = {"shape": (5, 6, 7),"format": "ND","ori_shape": (5, 6, 7),"ori_format": "ND", "dtype":"float32"}
    addtest(input_output_dict, input_output_dict, input_output_dict, kernel_name="addtest")

在屏幕上的输出为:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
======================== debug enter =======================
Tensor add_0 is saved to file samplefile.txt.
d: <bound method NDArrayBase.asnumpy of <tvm.NDArray shape=(5, 6, 7), cpu(0)>
array([[[1.9765025 , 1.3063627 , 1.5455024 , 1.8914425 , 0.75562   ,
         1.4493686 , 1.3199302 ],
        [1.3553491 , 2.2258139 , 1.945476  , 2.168409  , 1.1769401 ,
         1.9729862 , 0.8166871 ],
        [1.3076797 , 2.7369254 , 2.0644572 , 1.6179233 , 1.8307312 ,
         1.3352482 , 1.2801952 ],
        [0.7681772 , 2.4768236 , 1.1291559 , 2.0766551 , 0.41154432,
         1.7451818 , 0.66318583],
        [1.9156246 , 1.9579055 , 1.4429919 , 2.7537508 , 1.6746674 ,
         1.8326821 , 0.78217393],
        [2.7599869 , 1.3651894 , 1.3471396 , 1.1100632 , 1.7197586 ,
         0.53973526, 1.6119102 ]],

       [[2.8537135 , 1.846461  , 1.3941739 , 1.8292844 , 2.6258516 ,
         1.7189883 , 1.2285931 ],
        [2.5283246 , 2.784803  , 1.125393  , 1.5372462 , 1.3298836 ,
         1.9428086 , 1.7388952 ],
        [1.3924108 , 2.4534335 , 1.0272032 , 1.0579281 , 2.3483698 ,
         1.7413938 , 1.645398  ],
        [1.2009311 , 2.1751819 , 1.864238  , 0.44047868, 2.2192657 ,
         2.4648168 , 1.0090718 ],
        [0.6559963 , 2.546305  , 1.0522286 , 1.3295491 , 0.7785244 ,
         1.6488949 , 0.94041246],
        [2.3107693 , 1.7979652 , 1.409431  , 1.8151562 , 1.3728262 ,
         1.2684295 , 1.9496129 ]],

       [[1.5620731 , 1.5133004 , 1.4744021 , 0.44049683, 2.3037994 ,
         2.4848747 , 1.4918609 ],
        [1.0406991 , 2.3053305 , 1.7351038 , 0.84725004, 1.4539167 ,
         0.6062449 , 1.2351246 ],
        [1.9449794 , 2.7748094 , 0.6507897 , 1.2971216 , 1.4513849 ,
         2.1400795 , 2.5021868 ],
        [1.2449852 , 2.0492396 , 2.3701015 , 1.967829  , 1.3607856 ,
         1.3443347 , 1.8483088 ],
        [0.9416573 , 2.5379946 , 0.9037132 , 0.98067534, 2.7267451 ,
         1.412893  , 1.4104416 ],
        [2.0888834 , 1.7988434 , 2.274027  , 0.65598696, 1.6114297 ,
         2.8206928 , 0.78396904]],

       [[0.0070636 , 0.85229254, 1.0938224 , 1.9194655 , 1.0623909 ,
         2.8294506 , 0.90953755],
        [0.7315122 , 1.7403553 , 1.8028924 , 1.9337978 , 0.89290977,
         1.0474195 , 1.488183  ],
        [1.1293429 , 2.1118681 , 0.82159084, 2.7052598 , 1.1781758 ,
         1.1849467 , 1.4357327 ],
        [2.455081  , 2.360027  , 1.4727659 , 0.9091361 , 0.26764303,
         2.060249  , 1.3079873 ],
        [1.2675905 , 0.7459479 , 1.302169  , 1.0312747 , 2.5122142 ,
         2.6101115 , 1.6051782 ],
        [1.5052669 , 0.81929123, 2.428927  , 2.7529616 , 1.1843456 ,
         2.1128232 , 1.758671  ]],

       [[0.9602847 , 2.1775813 , 2.435278  , 0.57753366, 1.7766705 ,
         0.62516516, 1.5479352 ],
        [1.3457404 , 1.7887821 , 1.3990772 , 1.6418769 , 1.8924582 ,
         1.8680563 , 1.2213963 ],
        [0.4563806 , 0.4444568 , 0.9639152 , 1.0983156 , 1.464408  ,
         1.4074727 , 2.0831294 ],
        [1.9724479 , 2.2760797 , 0.72440803, 1.5530026 , 1.3526099 ,
         2.628211  , 1.1004283 ],
        [0.7632239 , 2.1213415 , 1.8780222 , 0.5869153 , 1.9058757 ,
         0.31886303, 0.9498419 ],
        [1.824633  , 2.0133598 , 2.1071515 , 1.1702616 , 1.3623409 ,
         0.6014643 , 1.2018499 ]]], dtype=float32)>
The actual output is the same as the expected output.
======================== debug exit ========================

tvm.testing.assert_allclose( )接口无异常抛出,输出数据与期望数据校验一致,算子的功能正常。