Checking the Triton Operator
Prerequisites
- Triton and the Triton-Ascend plug-in have been installed and configured by referring to link.
- To make sure the operator is recompiled on every run, so that a previously compiled copy does not affect the check, you are advised to set the following environment variable:
export TRITON_ALWAYS_COMPILE=1
- You have prepared the implementation file of the Triton operator. If you have not, use the following example. This section describes the check process for a Triton operator based on this example.
# file name: sample.py
import triton
import triton.language as tl
import torch

def torch_pointwise(x0, x1):
    res = x0 + x1
    return res

@triton.jit
def triton_add(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    offset = tl.program_id(0) * XBLOCK
    base1 = tl.arange(0, XBLOCK_SUB)
    loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
    for loop1 in range(loops1):
        x0 = offset + (loop1 * XBLOCK_SUB) + base1
        tmp0 = tl.load(in_ptr0 + (x0), None)
        tmp1 = tl.load(in_ptr1 + (x0), None)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr0 + (x0), tmp2, None)

def test_case(dtype, shape, ncore, xblock, xblock_sub):
    x0 = torch.randn(shape, dtype=dtype).npu()
    x1 = torch.randn(shape, dtype=dtype).npu()
    y_ref = torch_pointwise(x0, x1)
    y_cal = torch.zeros(shape, dtype=dtype).npu()
    triton_add[ncore, 1, 1](x0, x1, y_cal, xblock, xblock_sub)
    print("Pass" if torch.equal(y_ref, y_cal) else "Failed")

if __name__ == "__main__":
    test_case(torch.float32, (2, 4096, 8), 2, 32768, 1024)
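The loads and stores in triton_add pass None as the mask, so the kernel is only safe if the index arithmetic covers the tensor exactly. The following is a minimal CPU-only sketch of that arithmetic, using the values passed to test_case in the sample. The helper name touched_indices is hypothetical and is not part of Triton or of the sample; the sketch models indices only and does not run on the NPU.

```python
# CPU-only model of the index arithmetic in triton_add (no Triton or NPU needed).
# Values mirror test_case(torch.float32, (2, 4096, 8), 2, 32768, 1024).

def touched_indices(pid, xblock, xblock_sub):
    """Return every element index that program `pid` loads or stores.

    Hypothetical helper for illustration only.
    """
    offset = pid * xblock                              # tl.program_id(0) * XBLOCK
    loops = (xblock + xblock_sub - 1) // xblock_sub    # ceil(XBLOCK / XBLOCK_SUB)
    indices = []
    for loop in range(loops):
        start = offset + loop * xblock_sub
        # Equivalent to offset + loop * XBLOCK_SUB + tl.arange(0, XBLOCK_SUB)
        indices.extend(range(start, start + xblock_sub))
    return indices


numel = 2 * 4096 * 8                  # elements in the (2, 4096, 8) tensor
ncore, xblock, xblock_sub = 2, 32768, 1024

all_indices = []
for pid in range(ncore):
    all_indices.extend(touched_indices(pid, xblock, xblock_sub))

# True when every element is touched exactly once and nothing lies out of range.
covers_exactly_once = sorted(all_indices) == list(range(numel))
print(covers_exactly_once)
```

With these values the two programs together touch each of the 65536 elements once, which is why the unmodified sample needs no mask.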
Procedure
- Complete the preparations by referring to Triton operator calling scenario.
- Disable the memory pool. The sample uses PyTorch to create tensors. By default, PyTorch manages the global memory (GM) through a memory pool, which interferes with the memory check. To ensure that the check result is accurate, set the following environment variable before the check to disable the memory pool:
export PYTORCH_NO_NPU_MEMORY_CACHING=1
- Construct an illegal read scenario in the Triton operator by adding an offset of 100 elements to the address of the first load. The shifted load reads past the end of the input tensor, so an illegal read occurs on the GM.
def triton_add(in_ptr0, in_ptr1, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    offset = tl.program_id(0) * XBLOCK
    base1 = tl.arange(0, XBLOCK_SUB)
    loops1: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
    for loop1 in range(loops1):
        x0 = offset + (loop1 * XBLOCK_SUB) + base1
        # ERROR: Constructs an illegal read exception.
        tmp0 = tl.load(in_ptr0 + (x0) + 100, None)
        tmp1 = tl.load(in_ptr1 + (x0), None)
        # The rest of the kernel is unchanged.
- Run the Triton operator under the msSanitizer tool. For details about the parameters, see Table 2 and Table 3. For details about memory check, see Memory Check.
mssanitizer -t memcheck -- python sample.py
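To see which loads the 100-element shift pushes out of range, the index arithmetic of the modified kernel can be modeled on the CPU. This is a rough sketch under stated assumptions: the function out_of_bounds_loads is hypothetical, it models only reads past the logical end of the tensor, and it does not model the actual device allocation, so it says nothing about the byte count that a checker would report.

```python
# CPU-only model of the shifted load in the modified triton_add.
# Hypothetical helper for illustration; not part of Triton or msSanitizer.

def out_of_bounds_loads(numel, ncore, xblock, xblock_sub, shift):
    """List (pid, loop, overrun) for loads that read past element numel - 1.

    `overrun` is the number of elements of that load lying beyond the tensor.
    """
    loops = (xblock + xblock_sub - 1) // xblock_sub
    hits = []
    for pid in range(ncore):
        for loop in range(loops):
            start = pid * xblock + loop * xblock_sub + shift
            end = start + xblock_sub          # one past the last element read
            overrun = min(xblock_sub, max(0, end - numel))
            if overrun:
                hits.append((pid, loop, overrun))
    return hits


numel = 2 * 4096 * 8   # elements in the (2, 4096, 8) tensor of the sample
hits = out_of_bounds_loads(numel, ncore=2, xblock=32768, xblock_sub=1024, shift=100)
print(hits)  # [(1, 31, 100)]
```

In this model only the last loop iteration of the last program overruns the tensor, by 100 elements. The shifted loads of the first program stay inside the tensor: they read the wrong elements, which is a correctness problem but not an illegal access.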
Memory Exception Report Example
The report generated by the check tool shows that an illegal read of 368 bytes occurs on the GM at line 18 of sample.py, which is consistent with the constructed exception scenario.
$ mssanitizer -t memcheck -- python sample.py
[mssanitizer] logging to file: ./mindstudio_sanitizer_log/mssanitizer_20250522093805_922880.log
Failed
====== ERROR: illegal read of size 368
======    at 0x12c0c0053190 on GM in triton_add
======    in block aiv(1) on device 0
======    code in pc current 0x1b0 (serialNo:524)
======    #0 sample.py:18:45
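When the check runs in a script, the fields of the report can be extracted for further processing. The following is a minimal sketch that parses the first error of a report. The regular expressions are derived only from the sample report above, and the "write" alternative is an assumption; the actual output format of the tool may differ, so treat the patterns and the function parse_first_error as illustrative rather than as an interface of msSanitizer.

```python
import re

# Error lines copied from the sample report in this section.
SAMPLE_REPORT = """\
====== ERROR: illegal read of size 368
======    at 0x12c0c0053190 on GM in triton_add
======    in block aiv(1) on device 0
======    code in pc current 0x1b0 (serialNo:524)
======    #0 sample.py:18:45
"""

# Patterns inferred from the sample only; "write" is an assumed variant.
HEADER = re.compile(r"ERROR: illegal (?P<access>read|write) of size (?P<size>\d+)")
WHERE = re.compile(r"at (?P<address>0x[0-9a-fA-F]+) on (?P<memory>\w+) in (?P<kernel>\w+)")
FRAME = re.compile(r"#\d+ (?P<file>\S+?):(?P<line>\d+):(?P<column>\d+)")


def parse_first_error(text):
    """Return a dict describing the first error, or None if no match is found."""
    result = {}
    for pattern in (HEADER, WHERE, FRAME):
        match = pattern.search(text)
        if match is None:
            return None  # the text does not look like the sample report
        result.update(match.groupdict())
    for key in ("size", "line", "column"):
        result[key] = int(result[key])
    return result


report = parse_first_error(SAMPLE_REPORT)
print(report)
```

For the sample report, the parsed source position is sample.py, line 18, which is the line of the shifted load in the modified operator.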