matmul
Description
Multiplies matrices. The formula is tensor_c = trans_a(tensor_a) x trans_b(tensor_b) + tensor_bias.
After the optional transposes, the last two dimensions of tensor_a and tensor_b must satisfy (M, K) x (K, N) = (M, N).
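The shape rule and formula above can be sketched with a NumPy reference model (a hedged illustration only, not the TBE implementation; the function name matmul_ref is invented for this sketch):

```python
import numpy as np

def matmul_ref(a, b, trans_a=False, trans_b=False, bias=None):
    """Reference semantics: c = trans_a(a) @ trans_b(b) + bias."""
    if trans_a:
        a = a.T  # after the transpose, a has shape (M, K)
    if trans_b:
        b = b.T  # after the transpose, b has shape (K, N)
    c = a @ b    # (M, K) x (K, N) -> (M, N)
    if bias is not None:
        c = c + bias  # bias broadcasts over the output
    return c

# (M, K) x (K, N) = (M, N): (4, 3) x (3, 2) -> (4, 2)
c = matmul_ref(np.ones((4, 3)), np.ones((3, 2)))
print(c.shape)  # (4, 2)
```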
Prototype
matmul(tensor_a, tensor_b, trans_a=False, trans_b=False, format_a="ND", format_b="ND", alpha_num=1.0, beta_num=1.0, dst_dtype="float16", tensor_bias=None, quantize_params=None, format_out=None, compress_index=None, attrs={ }, kernel_name="Matmul")
Parameters
- tensor_a: a tvm.tensor for matrix A
- tensor_b: a tvm.tensor for matrix B
The two Tensors support the following data types:
Atlas 200/300/500 Inference Product: supports float16, float32, and int32.
Atlas Training Series Product: supports float16, float32, and int32.
- trans_a: a bool specifying whether to transpose matrix A.
- trans_b: a bool specifying whether to transpose matrix B.
- format_a: format of matrix A, either ND (default), FRACTAL_NZ, or FRACTAL_Z.
- format_b: format of matrix B, either ND (default), FRACTAL_NZ, or FRACTAL_Z.
- alpha_num: an extended parameter, which is reserved currently. Defaults to 1.0.
- beta_num: an extended parameter, which is reserved currently. Defaults to 1.0.
- dst_dtype: output data type, either float16 or float32.
- tensor_bias: defaults to None. If not None, tensor_bias is added to the product of matrix A and matrix B. tensor_bias must have the same data type as dst_dtype; its shape is broadcast to the shape of the output.
- quantize_params: quantization parameters. This parameter will be deprecated in a future release. Do not use it to develop any new operator.
- format_out: format of the result tensor, selected from ND, FRACTAL_NZ, or FRACTAL_Z.
- compress_index: compression index of the weight matrix.
- attrs: a dictionary of extended arguments.
- kernel_name: name of the operator in the kernel (name of the generated binary file and name of the operator description file).
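As a hedged illustration of the trans_a flag and the bias broadcast described above (a NumPy stand-in for the semantics, not a call into the DSL):

```python
import numpy as np

# With trans_a=True, tensor_a is supplied as (K, M) and transposed to (M, K).
a = np.arange(6, dtype=np.float32).reshape(3, 2)   # (K=3, M=2)
b = np.arange(12, dtype=np.float32).reshape(3, 4)  # (K=3, N=4)
bias = np.ones(4, dtype=np.float32)                # (N,) broadcasts over each row

c = a.T @ b + bias                                 # (M, N) = (2, 4)
print(c.shape)  # (2, 4)
```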
Returns
tensor_c: a tvm.tensor for the result tensor.
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
Applicability
Example
from tbe import tvm
from tbe import dsl

a_shape = (1024, 256)
b_shape = (256, 512)
bias_shape = (512, )
in_dtype = "float16"
dst_dtype = "float32"
tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype)
tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype)
tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype)
res = dsl.matmul(tensor_a, tensor_b, False, False, dst_dtype=dst_dtype, tensor_bias=tensor_bias)