矩阵乘,计算:tensor_c=trans_a(tensor_a) * trans_b(tensor_b) + tensor_bias。
tensor_a与tensor_b的shape后两维(经过对应转置)需要满足矩阵乘(M, K) * (K, N) = (M, N),支持多维度。如果is_fractal置为True的话,tensor_a数据排布要满足L0A的分形结构,tensor_b要满足L0B的分形结构。如果is_fractal置为False,tensor_a, tensor_b都是ND排布结构。
您可以在“te/lang/cce/te_compute/mmad_compute.py”查看接口定义。
此接口暂不支持与其他TBE DSL计算接口混合使用。
输入支持float16,输出支持float16和float32。
te.lang.cce.matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, format_a="ND", format_b="ND", alpha_num=1.0, beta_num=0.0, dst_dtype="float16", tensor_bias=None, quantize_params=None, format_out=None, compress_index=None, attrs={ }, kernel_name="Matmul")
tensor_c:根据关系运算计算后得到的tensor,tvm.tensor类型。
Atlas 200/300/500 推理产品
Atlas 训练系列产品
import tvm import te.lang.cce a_shape = (1024, 256) b_shape = (256, 512) bias_shape = (512, ) in_dtype = "float16" dst_dtype = "float32" tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype) tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype) tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype) res = te.lang.cce.matmul(tensor_a, tensor_b, False, False, False, dst_dtype=dst_dtype, tensor_bias=tensor_bias)