basic_matmul_autotune.py
import numpy as np
from ctypes import Structure, c_uint32, c_int32, c_int64
import mskpp


def get_kernel():
    kernel_file = "./basic_matmul.cpp"
    kernel_name = "BasicMatmul"
    build_script = "./jit_build.sh"  # kernel compile script
    config = mskpp.KernelInvokeConfig(kernel_file, kernel_name)
    gen_file = mskpp.Launcher(config).code_gen()
    kernel = mskpp.compile(build_script=build_script, launch_src_file=gen_file)
    return kernel


"""
To enable the autotune feature, the tunable lines in "basic_matmul.cpp" must
be annotated with the "// tunable" marker, e.g.
...
51 using L1TileShape = GemmShape<128, 256, 256>; // tunable
52 using L0TileShape = GemmShape<128, 256, 64>;  // tunable
"""
@mskpp.autotune(configs=[
    {'L1TileShape': 'GemmShape<128, 256, 256>', 'L0TileShape': 'GemmShape<128, 256, 64>'},   # 0: same config as in basic_matmul.cpp
    {'L1TileShape': 'GemmShape<128, 256, 128>', 'L0TileShape': 'GemmShape<128, 256, 64>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 64>'},
    {'L1TileShape': 'GemmShape<64, 128, 128>',  'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 256>',  'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 512>',  'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 128>',   'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 256>',   'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 512>',   'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 128>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 512>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
], warmup=1000, repeat=10, device_ids=[1])
def basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c):
    # This function's input arguments must exactly match the kernel function's.
    kernel = get_kernel()
    blockdim = 20  # use the AIC (AI Core) count that matches your hardware
    # invoke the kernel
    return kernel[blockdim](problem_shape, a, layout_a, b, layout_b, c, layout_c,
                            device_id=1)


class GemmCoord(Structure):
    _fields_ = [("m", c_uint32), ("n", c_uint32), ("k", c_uint32)]

    def __init__(self, m, n, k):
        super().__init__()
        self.m = c_uint32(m)
        self.n = c_uint32(n)
        self.k = c_uint32(k)

    @staticmethod
    def get_namespace():
        return "Catlass::"


class RowMajor(Structure):
    _fields_ = [("shape", c_int32 * 2), ("stride", c_int64 * 2)]

    def __init__(self, rows: int = 0, cols: int = 0, ldm: int = None):
        super().__init__()
        self.shape = (c_int32 * 2)(rows, cols)
        # the leading dimension defaults to the column count (dense rows)
        if ldm is None:
            self.stride = (c_int64 * 2)(cols, 1)
        else:
            self.stride = (c_int64 * 2)(c_int64(ldm), 1)

    @staticmethod
    def get_namespace():
        return "Catlass::layout::"


if __name__ == "__main__":
    # prepare kernel input/output
    m = 256
    n = 512
    k = 1024
    problem_shape = GemmCoord(m, n, k)
    layout_a = RowMajor(m, k)
    layout_b = RowMajor(k, n)
    layout_c = RowMajor(m, n)
    # randint(1, 2) fills a and b with ones, so the fp16 result is exact
    # and a bitwise comparison against the golden data is safe
    a = np.random.randint(1, 2, [m, k]).astype(np.half)
    b = np.random.randint(1, 2, [k, n]).astype(np.half)
    c = np.zeros([m, n]).astype(np.half)

    # invoke kernel
    basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c)

    # check if the output tensor c is consistent with the golden data
    golden = np.matmul(a, b)
    is_equal = np.array_equal(c, golden)
    result = "success" if is_equal else "failed"
    print("compare {}.".format(result))
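Each autotune config substitutes its values into the lines marked "// tunable", so the tuner rebuilds and benchmarks the kernel once per entry (warmup=1000 untimed runs, then repeat=10 timed runs on the listed devices) and keeps the fastest variant. To make the tile parameters concrete: reading GemmShape<M, N, K> as tile extents along the m, n, and k dimensions (the conventional interpretation; the exact semantics are defined by basic_matmul.cpp), config 0 splits the output into 128x256 blocks and accumulates over 256-wide slices of k. The pure-NumPy sketch below is illustrative only, not mskpp API:

import numpy as np

def tiled_matmul(a, b, tile_m=128, tile_n=256, tile_k=256):
    # Host-side illustration of L1 tiling with the config-0 tile sizes:
    # each (tile_m x tile_n) output block accumulates partial products
    # over tile_k-wide slices of the reduction dimension.
    m, k = a.shape
    _, n = b.shape
    c = np.zeros((m, n), dtype=np.float32)
    for i0 in range(0, m, tile_m):
        for j0 in range(0, n, tile_n):
            for k0 in range(0, k, tile_k):
                c[i0:i0 + tile_m, j0:j0 + tile_n] += (
                    a[i0:i0 + tile_m, k0:k0 + tile_k].astype(np.float32)
                    @ b[k0:k0 + tile_k, j0:j0 + tile_n].astype(np.float32)
                )
    return c

Note that every tile size in the config list divides the m=256, n=512, k=1024 problem evenly, so none of the candidate configs produces partial edge tiles.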