basic_matmul_autotune.py

import numpy as np
from ctypes import Structure, c_uint32, c_int32, c_int64
import mskpp

def get_kernel():
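    # JIT-build flow (as implied by the calls below): KernelInvokeConfig points at the
    # kernel source file and entry function, Launcher(config).code_gen() generates the
    # launch source file, and mskpp.compile() builds it with the given build script and
    # returns a callable kernel object.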
    kernel_file = "./basic_matmul.cpp"
    kernel_name = "BasicMatmul"
    build_script = "./jit_build.sh" # kernel compile script
    config = mskpp.KernelInvokeConfig(kernel_file, kernel_name)
    gen_file = mskpp.Launcher(config).code_gen()
    kernel = mskpp.compile(build_script=build_script, launch_src_file=gen_file)
    return kernel

"""
To enable the autotune feature, append the "// tunable" marker to the lines in
"basic_matmul.cpp" that should be tuned, e.g.
    ...
    51    using L1TileShape = GemmShape<128, 256, 256>; // tunable
    52    using L0TileShape = GemmShape<128, 256, 64>; // tunable
"""
@mskpp.autotune(configs=[
    {'L1TileShape': 'GemmShape<128, 256, 256>', 'L0TileShape': 'GemmShape<128, 256, 64>'},  # 0: the same config as in basic_matmul.cpp
    {'L1TileShape': 'GemmShape<128, 256, 128>', 'L0TileShape': 'GemmShape<128, 256, 64>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 64>'},
    {'L1TileShape': 'GemmShape<64, 128, 128>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 256>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 128, 512>', 'L0TileShape': 'GemmShape<64, 128, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 128>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 256>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<64, 64, 512>', 'L0TileShape': 'GemmShape<64, 64, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 128>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 256>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
    {'L1TileShape': 'GemmShape<128, 128, 512>', 'L0TileShape': 'GemmShape<128, 128, 128>'},
], warmup=1000, repeat=10, device_ids=[1])
def basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c):
    # The arguments of this function must exactly match the kernel function's parameter list.
    kernel = get_kernel()
    blockdim = 20  # set to the number of AI Cores (AICs) on your hardware
    return kernel[blockdim](problem_shape, a, layout_a, b, layout_b, c, layout_c, device_id=1)  # invoke the kernel
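
# The ctypes Structures below mirror the kernel's C++ argument types; get_namespace()
# supplies the C++ namespace prefix, presumably so mskpp can match each structure to
# its C++ counterpart (Catlass::GemmCoord, Catlass::layout::RowMajor).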

class GemmCoord(Structure):
    _fields_ = [("m", c_uint32),
                ("n", c_uint32),
                ("k", c_uint32)]
    def __init__(self, m, n, k):
        super().__init__()
        self.m = m
        self.n = n
        self.k = k
    @staticmethod
    def get_namespace():
        return "Catlass::"

class RowMajor(Structure):
    _fields_ = [("shape", c_int32 * 2),
                ("stride", c_int64 * 2)]
    def __init__(self, rows: int = 0, cols: int = 0, ldm: int = None):
        super().__init__()
        self.shape = (c_int32 * 2)(rows, cols)
        if ldm is None:
            self.stride = (c_int64 * 2)(cols, 1)  # leading dimension defaults to the column count
        else:
            self.stride = (c_int64 * 2)(ldm, 1)
    @staticmethod
    def get_namespace():
        return "Catlass::layout::"

if __name__ == "__main__":
    # prepare kernel input/output
    m = 256
    n = 512
    k = 1024
    problem_shape = GemmCoord(m, n, k)
    layout_a = RowMajor(m, k)
    layout_b = RowMajor(k, n)
    layout_c = RowMajor(m, n)
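    # np.random.randint(1, 2, ...) produces all-ones matrices, so every element of the
    # product equals k (1024), which float16 represents exactly; this allows a bit-exact
    # comparison against the golden data below.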
    a = np.random.randint(1, 2, [m, k]).astype(np.half)
    b = np.random.randint(1, 2, [k, n]).astype(np.half)
    c = np.zeros([m, n]).astype(np.half)

    # invoke kernel
    basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c)

    # check if the output tensor c is consistent with the golden data
    golden = np.matmul(a, b)
    is_equal = np.array_equal(c, golden)
    result = "success" if is_equal else "failed"
    print("compare {}.".format(result))