gemm
Description
Performs general matrix multiplication (GEMM) according to the following formula:
A' = transpose(A) if trans_a else A
B' = transpose(B) if trans_b else B
Y = alpha x A' x B' + beta x C
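To make the formula concrete, here is a minimal host-side reference model in pure Python. It is a sketch for checking results on small 2-D inputs, not the TBE implementation. `gemm_reference` and `transpose` are hypothetical helpers, and treating a 1-D C as a row vector repeated across all M rows is an assumption.

```python
from typing import List, Optional, Union

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Swap rows and columns of a 2-D list-of-lists matrix."""
    return [list(row) for row in zip(*m)]


def gemm_reference(
    a: Matrix,
    b: Matrix,
    alpha: float = 1.0,
    beta: float = 1.0,
    c: Optional[Union[Matrix, List[float]]] = None,
    trans_a: bool = False,
    trans_b: bool = False,
) -> Matrix:
    """Hypothetical host-side model of Y = alpha x A' x B' + beta x C (2-D only)."""
    a_p = transpose(a) if trans_a else a  # A' = transpose(A) if trans_a else A
    b_p = transpose(b) if trans_b else b  # B' = transpose(B) if trans_b else B
    m, k = len(a_p), len(a_p[0])
    k2, n = len(b_p), len(b_p[0])
    if k != k2:
        raise ValueError(f"inner dimensions differ: (M, K)=({m}, {k}) vs (K, N)=({k2}, {n})")
    # alpha x A' x B', one output element at a time.
    y = [[alpha * sum(a_p[i][p] * b_p[p][j] for p in range(k)) for j in range(n)]
         for i in range(m)]
    if c is not None:
        # Assumption: a 1-D C of length N is repeated across all M rows;
        # a 2-D C is expected to be (M, N) already.
        rows = c if isinstance(c[0], list) else [c] * m
        for i in range(m):
            for j in range(n):
                y[i][j] += beta * rows[i][j]
    return y
```

For example, `gemm_reference([[1, 2], [3, 4]], [[5, 6], [7, 8]], alpha=2.0, beta=0.5, c=[10, 20])` yields `[[43.0, 54.0], [91.0, 110.0]]`.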
After the optional transpositions, the last two dimensions of A' and B' must satisfy (M, K) x (K, N) = (M, N). Inputs with more than two dimensions are also supported.
When alpha = 1.0 and beta = 1.0, the result is the same as that computed by the Matmul API.
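The shape rule can be sketched as a small helper that applies the optional transposes to the last two dimensions and checks that the K dimensions agree. `gemm_output_shape` is hypothetical, and the handling of leading dimensions (taken from A; B may be 2-D or must carry identical leading dimensions) is an assumption, since the rule above only states that multiple dimensions are supported.

```python
from typing import Sequence, Tuple


def gemm_output_shape(
    a_shape: Sequence[int],
    b_shape: Sequence[int],
    trans_a: bool = False,
    trans_b: bool = False,
) -> Tuple[int, ...]:
    """Hypothetical helper: infer the shape of Y from the ND shapes of A and B."""
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("A and B need at least two dimensions")
    # Transposition only swaps the last two dimensions.
    m, k = (a_shape[-1], a_shape[-2]) if trans_a else (a_shape[-2], a_shape[-1])
    k2, n = (b_shape[-1], b_shape[-2]) if trans_b else (b_shape[-2], b_shape[-1])
    if k != k2:
        raise ValueError(f"K mismatch: A' has K={k}, B' has K={k2}")
    # Assumption: leading (batch) dimensions come from A; if B has any,
    # they must be identical to A's.
    batch_a, batch_b = tuple(a_shape[:-2]), tuple(b_shape[:-2])
    if batch_b and batch_b != batch_a:
        raise ValueError(f"leading dimensions differ: {batch_a} vs {batch_b}")
    return batch_a + (m, n)
```

With the shapes used in the example, `gemm_output_shape((1024, 256), (256, 512))` returns `(1024, 512)`.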
Prototype
gemm(tensor_a, tensor_b, para_dict)
Parameters
- tensor_a: a tvm.tensor for matrix A
- tensor_b: a tvm.tensor for matrix B
- para_dict: a dictionary of extended arguments, including:
- tensor_alpha: a tvm.tensor or float for the scaling factor alpha applied to the product A' x B'.
- tensor_beta: a tvm.tensor or float for the scaling factor beta applied to matrix C.
- trans_a: a bool specifying whether to transpose matrix A.
- trans_b: a bool specifying whether to transpose matrix B.
- format_a: a string for the data format of input matrix A, either ND or fractal.
- format_b: a string for the data format of input matrix B, either ND or fractal.
- dst_dtype: a string for the output data type, either float16 or float32.
- tensor_c: input matrix C. Defaults to None. If it is not None, the API returns alpha x A' x B' + beta x C. Matrix C may have any shape that can be broadcast to the output shape, and its data type must match dst_dtype.
- format_out: a string for the format of the output tensor.
- compress_index: a tvm.tensor for the compression index of the weight matrix, which is used only when both alpha and beta are set to 1.0.
- offset_a: offset of tensor_a, which is used only when alpha = beta = 1.0.
- offset_b: offset of tensor_b, which is used only when alpha = beta = 1.0.
- kernel_name: name of the kernel operator, which is also used as the name of the generated binary file and of the operator description file.
- quantize_params: quantization parameters. This parameter will be deprecated in a future release. Do not use it to develop any new operator.
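The documented constraints on para_dict can be collected into a pre-flight check before calling the API. The sketch below is a hypothetical helper (`check_gemm_para_dict` is not part of the tbe package); it only encodes the rules listed above, relies on tensors exposing a `.dtype` string as tvm tensors do, and assumes alpha and beta are 1.0 when absent, which the parameter list does not state.

```python
from numbers import Real
from typing import Any, Dict, List

_FORMATS = {"ND", "fractal"}
_DST_DTYPES = {"float16", "float32"}
_UNIT_SCALE_ONLY = ("compress_index", "offset_a", "offset_b")


def check_gemm_para_dict(para_dict: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return the documented rules that para_dict violates."""
    problems = []
    for key in ("trans_a", "trans_b"):
        if key in para_dict and not isinstance(para_dict[key], bool):
            problems.append(f"{key} must be a bool")
    for key in ("format_a", "format_b"):
        if key in para_dict and para_dict[key] not in _FORMATS:
            problems.append(f"{key} must be one of {sorted(_FORMATS)}")
    dst_dtype = para_dict.get("dst_dtype")
    if dst_dtype is not None and dst_dtype not in _DST_DTYPES:
        problems.append(f"dst_dtype must be one of {sorted(_DST_DTYPES)}")
    tensor_c = para_dict.get("tensor_c")
    # The data type of matrix C must match dst_dtype.
    if tensor_c is not None and dst_dtype is not None \
            and getattr(tensor_c, "dtype", None) != dst_dtype:
        problems.append("tensor_c dtype must equal dst_dtype")
    # Assumption: alpha and beta default to 1.0 when absent.
    alpha = para_dict.get("tensor_alpha", 1.0)
    beta = para_dict.get("tensor_beta", 1.0)
    # Only plain numbers can be checked here; tensor-valued factors are skipped.
    if isinstance(alpha, Real) and isinstance(beta, Real) and (alpha, beta) != (1.0, 1.0):
        for key in _UNIT_SCALE_ONLY:
            if para_dict.get(key) is not None:
                problems.append(f"{key} is used only when alpha = beta = 1.0")
    return problems
```

An empty list means no documented rule is violated; it does not guarantee that the API accepts the combination.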
Data types supported by the input tensor:
Returns
y: a tvm.tensor for the result tensor
Restrictions
This API cannot be used in conjunction with other TBE DSL APIs.
Applicability
Example
from tbe import tvm
from tbe import dsl
a_shape = (1024, 256)
b_shape = (256, 512)
bias_shape = (512, )
in_dtype = "float16"
dst_dtype = "float32"
tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype)
tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype)
tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype)
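# The bias is passed as matrix C; its (512,) shape is broadcast across the 1024 output rows.
para_dict = {
    "tensor_c": tensor_bias,
    "dst_dtype": dst_dtype,
}
# res is a tvm.tensor for Y with shape (1024, 512) and data type float32.
res = dsl.gemm(tensor_a, tensor_b, para_dict)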
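In the example, the (512,) bias has fewer dimensions than the (1024, 512) output, so it relies on the broadcasting allowed for matrix C. The sketch below checks compatibility under NumPy-style broadcasting rules, which is an assumption about how the API broadcasts C; `broadcast_to_output` is a hypothetical helper.

```python
from typing import Sequence, Tuple


def broadcast_to_output(c_shape: Sequence[int], out_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: check that C broadcasts to Y's shape.

    Assumes NumPy-style rules: align trailing dimensions; each C dimension
    must be 1 or equal to the matching output dimension.
    """
    if len(c_shape) > len(out_shape):
        raise ValueError(f"C has more dimensions than Y: {tuple(c_shape)} vs {tuple(out_shape)}")
    for c_dim, y_dim in zip(reversed(c_shape), reversed(out_shape)):
        if c_dim not in (1, y_dim):
            raise ValueError(f"cannot broadcast C {tuple(c_shape)} to Y {tuple(out_shape)}")
    return tuple(out_shape)


# Shapes from the example: A (1024, 256) x B (256, 512) gives Y (1024, 512).
print(broadcast_to_output((512,), (1024, 512)))  # prints (1024, 512)
```

Under the same assumption, a C of shape (1024,) would be rejected, because its single dimension is aligned with N = 512 rather than M.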