# basic_matmul_autotune.py
import numpy as np
from ctypes import Structure, c_uint32, c_int32, c_int64
import mskpp
def get_kernel():
    """Generate the launcher for the ``BasicMatmul`` kernel and JIT-compile it.

    The kernel source and build script paths are relative; presumably they are
    resolved against the current working directory (TODO confirm in mskpp docs).

    Returns:
        The object returned by ``mskpp.compile``. Callers index it with a block
        count and then call it with the kernel arguments to launch.
    """
    invoke_config = mskpp.KernelInvokeConfig("./basic_matmul.cpp", "BasicMatmul")
    launcher_src = mskpp.Launcher(invoke_config).code_gen()
    # "./jit_build.sh" is the kernel compile script; it is handed the generated
    # launcher source.
    return mskpp.compile(build_script="./jit_build.sh", launch_src_file=launcher_src)
# To enable the autotune feature, add a "// tunable" marker to each line of
# "basic_matmul.cpp" whose value should be tuned, e.g. (the leading numbers are
# line numbers inside basic_matmul.cpp):
#
#     ...
#     51 using L1TileShape = GemmShape<128, 256, 256>; // tunable
#     52 using L0TileShape = GemmShape<128, 256, 64>; // tunable
#
# Each dict in `configs` below pairs a tunable alias (L1TileShape, L0TileShape)
# with a candidate GemmShape<...> value. Presumably mskpp substitutes these
# into the marked lines for each trial; see the mskpp autotune documentation.

# Device index used both by the autotuner (`device_ids`) and by the kernel
# launch (`device_id`). The example uses the same device for both, so it is
# defined once here to keep the two from drifting apart.
DEVICE_ID = 1

# Block count passed to the kernel launch. Set it to the AIC count of your
# hardware.
BLOCK_DIM = 20


@mskpp.autotune(configs=[
    {'L1TileShape': 'GemmShape<128, 256, 256>', 'L0TileShape': 'GemmShape<128, 256, 64>'},  # 0: same config as in basic_matmul.cpp
    {'L1TileShape': 'GemmShape<128, 256, 128>', 'L0TileShape': 'GemmShape<128, 256, 64>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 64>'},
    {'L1TileShape': 'GemmShape<64, 128, 128>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 256>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 512>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 128>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 256>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 512>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 128>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 512>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
], warmup=1000, repeat=10, device_ids=[DEVICE_ID])
def basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c):
    """Build the BasicMatmul kernel and launch it once with the given arguments.

    This function's input arguments must exactly match the kernel function's
    parameters (names and order). Do not add or reorder parameters.

    Returns:
        Whatever the kernel launch returns.
    """
    # get_kernel() regenerates and recompiles the launcher on every call.
    # Presumably this is what lets autotune rebuild for each config, so do not
    # hoist or cache it without checking the mskpp docs.
    kernel = get_kernel()
    return kernel[BLOCK_DIM](problem_shape, a, layout_a, b, layout_b, c, layout_c, device_id=DEVICE_ID)
class GemmCoord(Structure):
    """ctypes mirror of a GEMM problem size: three unsigned 32-bit fields m, n, k.

    ``get_namespace`` returns the C++ namespace prefix "Catlass::". Presumably
    mskpp uses it to name the matching C++ type in generated code; TODO
    confirm against the mskpp docs.
    """

    _fields_ = [
        ("m", c_uint32),
        ("n", c_uint32),
        ("k", c_uint32),
    ]

    def __init__(self, m, n, k):
        # Convert each dimension to c_uint32 first. ctypes does no overflow
        # checking, so out-of-range ints wrap modulo 2**32. Structure.__init__
        # then assigns the converted values to the fields in _fields_ order.
        super().__init__(*(c_uint32(dim) for dim in (m, n, k)))

    @staticmethod
    def get_namespace():
        return "Catlass::"
class RowMajor(Structure):
    """ctypes mirror of a row-major layout.

    Fields:
        ``shape``: two int32 values (rows, cols).
        ``stride``: two int64 values set to (ldm, 1).

    ``ldm`` defaults to ``cols`` when omitted, which under the usual
    leading-dimension reading means densely packed rows. ``get_namespace``
    returns the C++ namespace prefix "Catlass::layout::".
    """

    _fields_ = [
        ("shape", c_int32 * 2),
        ("stride", c_int64 * 2),
    ]

    def __init__(self, rows: int = 0, cols: int = 0, ldm: "int | None" = None):
        super().__init__()
        self.shape = (c_int32 * 2)(rows, cols)
        # An explicit ldm is converted to c_int64 before being stored.
        row_stride = cols if ldm is None else c_int64(ldm)
        self.stride = (c_int64 * 2)(row_stride, 1)

    @staticmethod
    def get_namespace():
        return "Catlass::layout::"
if __name__ == "__main__":
    # Prepare kernel inputs/outputs for C (m x n) = A (m x k) @ B (k x n).
    m = 256
    n = 512
    k = 1024
    problem_shape = GemmCoord(m, n, k)
    layout_a = RowMajor(m, k)
    layout_b = RowMajor(k, n)
    layout_c = RowMajor(m, n)

    # All-ones float16 inputs. Every element of A @ B is then exactly k, and
    # float16 represents integers exactly up to 2048, so the exact comparison
    # below is valid for this problem size. Random data would need a
    # tolerance-based check instead.
    a = np.ones((m, k), dtype=np.half)
    b = np.ones((k, n), dtype=np.half)
    # Allocate the output directly as float16 (no float64 temporary).
    c = np.zeros((m, n), dtype=np.half)

    # Invoke the kernel. The check below assumes the kernel writes its result
    # into `c` in place.
    basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c)

    # Check that the output tensor c matches the host-computed golden data.
    golden = np.matmul(a, b)
    is_equal = np.array_equal(c, golden)
    result = "success" if is_equal else "failed"
    print("compare {}.".format(result))
# Parent topic: Appendix