# [Web page navigation residue] Ascend Community home page | Language: Chinese | Register

# File: sample_actlass_basic_matmul.py

import numpy as np
import mskpp
from mskpp.actlass import MatmulCoord, RowMajor


def get_kernel():
    """Generate launch code for the ``BasicMatmul`` kernel and compile it.

    An ``mskpp.ActlassConfig`` is built from the kernel source file and the
    kernel name, ``mskpp.Launcher(config).code_gen()`` produces the launch
    source file, and ``mskpp.compile`` is invoked with that file and the
    build script.

    Returns:
        The object returned by ``mskpp.compile``.
    """
    launcher_config = mskpp.ActlassConfig(
        "basic_matmul.cpp",  # kernel implementation file
        "BasicMatmul",  # name of the kernel to invoke
    )
    launch_source = mskpp.Launcher(launcher_config).code_gen()
    return mskpp.compile(
        build_script="make.sh",  # kernel build script
        launch_src_file=launch_source,
    )


@mskpp.autotune(
    configs=[
        {'L1TileShape': 'MatmulShape<64, 64, 64>', 'L0Shape': 'MatmulShape<128, 256, 64>'},
        {'L1TileShape': 'MatmulShape<64, 64, 128>', 'L0Shape': 'MatmulShape<128, 256, 64>'},
        {'L1TileShape': 'MatmulShape<64, 128, 128>', 'L0Shape': 'MatmulShape<128, 256, 64>'},
        {'L1TileShape': 'MatmulShape<64, 128, 128>', 'L0Shape': 'MatmulShape<64, 256, 64>'},
        {'L1TileShape': 'MatmulShape<128, 128, 128>', 'L0Shape': 'MatmulShape<128, 256, 64>'},
    ],
    warmup=500,
    repeat=10,
    device_ids=[0],
)
def basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c):
    """Compile the BasicMatmul kernel and launch it on the given operands.

    The function is wrapped by ``mskpp.autotune`` with five candidate
    configurations, each pairing an ``L1TileShape`` string with an
    ``L0Shape`` string.
    NOTE(review): presumably each key names a tunable declaration in
    ``basic_matmul.cpp`` that autotune substitutes -- confirm against the
    kernel source.

    Args:
        problem_shape: Problem-size object forwarded to the kernel.
        a: First input operand.
        layout_a: Layout object describing ``a``.
        b: Second input operand.
        layout_b: Layout object describing ``b``.
        c: Output operand.
        layout_c: Layout object describing ``c``.

    Returns:
        Whatever the kernel launch call returns.
    """
    block_dim = 20
    launch = get_kernel()[block_dim]
    # Kernel launch (the <<<>>> style call).
    # NOTE(review): the launch passes device_id=1 while the autotune decorator
    # is given device_ids=[0] -- confirm which device is actually intended.
    return launch(problem_shape, a, layout_a, b, layout_b, c, layout_c, device_id=1)


def data_compare(a, b, c):
    """Print whether ``c`` exactly equals the NumPy product of ``a`` and ``b``.

    The reference result is ``np.matmul(a, b)``. The comparison uses
    ``np.array_equal``, so shapes and every element must match exactly; no
    tolerance is applied.

    Args:
        a: Left matrix operand.
        b: Right matrix operand.
        c: Result to validate against the reference product.

    Returns:
        None. The verdict is written to standard output.
    """
    matches = np.array_equal(c, np.matmul(a, b))
    report = "compare result: {}"
    print(report.format(matches))


if __name__ == "__main__":
    # Problem size: a is (m, k), b is (k, n), c is (m, n).
    m, n, k = 1024, 768, 1024

    # Problem-shape and layout descriptors handed to the kernel.
    problem_shape = MatmulCoord(m, n, k)
    layout_a = RowMajor(m, k)
    layout_b = RowMajor(k, n)
    layout_c = RowMajor(m, n)

    # randint's upper bound is exclusive, so randint(1, 2, ...) yields all
    # ones; both inputs are then converted to half precision.
    a = np.random.randint(1, 2, (m, k)).astype(np.half)
    b = np.random.randint(1, 2, (k, n)).astype(np.half)
    # Zero-initialised half-precision output; presumably filled in place by
    # the kernel launch, since it is compared against the reference below.
    c = np.zeros((m, n)).astype(np.half)

    basic_matmul(problem_shape, a, layout_a, b, layout_b, c, layout_c)
    data_compare(a, b, c)