昇腾社区首页
EN
注册

检测Triton算子

前提条件

  • 参考LINK,完成Triton及Triton-Ascend插件的安装和配置。
  • 自备Triton算子实现文件。
    若用户尚未准备Triton算子,可参考以下示例。本节将基于此示例来说明Triton算子的检测流程。
    # file name: sample.py
    import triton
    import triton.language as tl
    import torch
    
    def torch_pointwise(x0, x1):
        res = x0 + x1
        return res
    
    
    @triton.jit
    def triton_add(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
        offset = tl.program_id(0) * XBLOCK
        base1 = tl.arange(0, XBLOCK_SUB)
        loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
        for loop1 in range(loops1):
            x0 = offset + (loop1 * XBLOCK_SUB) + base1
            tmp0 = tl.load(in_ptr0 + (x0), None)
            tmp1 = tl.load(in_ptr1 + (x0), None)
            tmp2 = tmp0 + tmp1
            tl.store(out_ptr0 + (x0), tmp2, None)
    
    
    def test_case(dtype, shape, ncore, xblock, xblock_sub):
        x0 = torch.randn(shape, dtype=dtype).npu()
        x1 = torch.randn(shape, dtype=dtype).npu()
        y_ref = torch_pointwise(x0, x1)
        y_cal = torch.zeros(shape, dtype=dtype).npu()
        triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub)
        print("Pass" if torch.equal(y_ref, y_cal) else "Failed")
    
    
    if __name__ == "__main__":
        test_case(torch.float32, (2, 4096, 8), 2, 32768, 1024)

操作步骤

  1. 请参考Triton算子调用场景准备,完成使用前准备。
  2. 关闭内存池。
    样例中使用PyTorch创建Tensor,PyTorch框架内默认使用内存池的方式管理GM内存,会对内存检测产生干扰。因此,在检测前需要额外设置如下环境变量关闭内存池,以保证检测结果准确。
    export PYTORCH_NO_NPU_MEMORY_CACHING=1
  3. 在Triton算子中构造一个非法读写的场景,将第一次load的内存向右偏移100个元素,此时会导致load在GM内存上发生非法读。
    def triton_add(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
        offset = tl.program_id(0) * XBLOCK
        base1 = tl.arange(0, XBLOCK_SUB)
        loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
        for loop1 in range(loops1):
            x0 = offset + (loop1 * XBLOCK_SUB) + base1
            # ERROR: 构造非法读异常
            tmp0 = tl.load(in_ptr0 + (x0) + 100, None)
            tmp1 = tl.load(in_ptr1 + (x0), None)
  4. 使用msSanitizer检测工具拉起Triton算子。具体参数说明请参考表2表3,内存检测请参考内存检测
    mssanitizer -t memcheck -- python sample.py

内存异常报告解析

根据检测工具输出的报告,可以发现在sample.py的18行对GM存在368字节的非法读操作,与构造的异常场景一致。
1
2
3
4
5
6
7
8
$ mssanitizer -t memcheck -- python sample.py
[mssanitizer] logging to file: ./mindstudio_sanitizer_log/mssanitizer_20250522093805_922880.log
Failed
====== ERROR: illegal read of size 368
======    at 0x12c0c0053190 on GM in triton_add
======    in block aiv(1) on device 0
======    code in pc current 0x1b0 (serialNo:524)
======    #0 sample.py:18:45