gemm

Description

Performs general matrix multiplication (GEMM), computed as follows:

A' = transpose(A) if trans_a else A

B' = transpose(B) if trans_b else B

Y = alpha x A' x B' + beta x C

After any transposition, the last two dimensions of A' and B' must satisfy (M, K) x (K, N) = (M, N). Tensors with more than two dimensions are also supported.
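The shape rule can be sketched as a small shape-inference function. It is plain Python and does not call `tbe`. The section does not say how leading (batch) dimensions interact when there are more than two dimensions. This sketch assumes they must be identical in A and B, and `gemm_output_shape` is a hypothetical helper name.

```python
from typing import Tuple


def gemm_output_shape(
    a_shape: Tuple[int, ...],
    b_shape: Tuple[int, ...],
    trans_a: bool = False,
    trans_b: bool = False,
) -> Tuple[int, ...]:
    """Infer the shape of Y from the shapes of A and B (illustrative sketch)."""
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("A and B need at least two dimensions")
    # Only the last two dimensions take part in the transpose and the product.
    m, k_a = (a_shape[-1], a_shape[-2]) if trans_a else (a_shape[-2], a_shape[-1])
    k_b, n = (b_shape[-1], b_shape[-2]) if trans_b else (b_shape[-2], b_shape[-1])
    if k_a != k_b:
        raise ValueError(f"K mismatch: {k_a} (from A') vs {k_b} (from B')")
    # Assumption: leading (batch) dimensions must match exactly; no broadcasting.
    if a_shape[:-2] != b_shape[:-2]:
        raise ValueError(f"batch dimensions differ: {a_shape[:-2]} vs {b_shape[:-2]}")
    return tuple(a_shape[:-2]) + (m, n)


print(gemm_output_shape((1024, 256), (256, 512)))                # (1024, 512)
print(gemm_output_shape((256, 1024), (256, 512), trans_a=True))  # (1024, 512)
print(gemm_output_shape((4, 16, 32), (4, 32, 8)))                # (4, 16, 8)
```

With the shapes from the example at the end of this page, (1024, 256) x (256, 512), the result is (1024, 512).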

When alpha = 1.0 and beta = 1.0, the result is identical to that of the Matmul API.
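The formula above can be checked on small inputs against a pure-Python reference model. This is a sketch of the math only, not the `dsl.gemm` implementation: it works on 2-D nested lists and ignores data formats and dtypes. `gemm_reference` is a hypothetical name, and treating a flat C of length N as a row broadcast across all M rows is an assumption about how broadcasting applies.

```python
from typing import List, Optional, Sequence, Union

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Return the transpose of a row-major nested list."""
    return [list(row) for row in zip(*m)]


def gemm_reference(
    a: Matrix,
    b: Matrix,
    alpha: float = 1.0,
    beta: float = 1.0,
    c: Optional[Union[Matrix, Sequence[float]]] = None,
    trans_a: bool = False,
    trans_b: bool = False,
) -> Matrix:
    """Compute Y = alpha x A' x B' + beta x C on small 2-D inputs."""
    a_p = transpose(a) if trans_a else a
    b_p = transpose(b) if trans_b else b
    m, k = len(a_p), len(a_p[0])
    if len(b_p) != k:
        raise ValueError(f"inner dimensions differ: {k} vs {len(b_p)}")
    n = len(b_p[0])
    # Assumption: a flat C of length N is broadcast across all M rows.
    c_is_matrix = c is not None and isinstance(c[0], (list, tuple))
    y = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = sum(a_p[i][p] * b_p[p][j] for p in range(k))
            y[i][j] = alpha * acc
            if c is not None:  # with no C, the beta term contributes nothing
                y[i][j] += beta * (c[i][j] if c_is_matrix else c[j])
    return y


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(gemm_reference(a, b))                              # [[19.0, 22.0], [43.0, 50.0]]
print(gemm_reference(a, b, alpha=2.0, c=[1.0, 1.0]))     # [[39.0, 45.0], [87.0, 101.0]]
```

With the default alpha = beta = 1.0, the model reduces to a plain matrix product plus C, which is the case the sentence above relates to the Matmul API.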

Prototype

gemm(tensor_a, tensor_b, para_dict)

Parameters

  • tensor_a: a tvm.tensor for matrix A
  • tensor_b: a tvm.tensor for matrix B
  • para_dict: a dictionary of extended arguments, including:
    • tensor_alpha: a tvm.tensor or float for the scaling factor applied to the product matrix A x matrix B.
    • tensor_beta: a tvm.tensor or float for the scaling factor applied to matrix C.
    • trans_a: a bool specifying whether to transpose matrix A.
    • trans_b: a bool specifying whether to transpose matrix B.
    • format_a: a string for the data format of input matrix A, either ND or fractal.
    • format_b: a string for the data format of input matrix B, either ND or fractal.
    • dst_dtype: a string for the output data type, either float16 or float32.
    • tensor_c: input matrix C. Defaults to None. If it is not None, the API adds matrix C to the product matrix A x matrix B. Matrix C can be broadcast to the output shape, and its data type must be the same as dst_dtype.
    • format_out: a string for the format of the output tensor.
    • compress_index: a tvm.tensor for the compression index of the weight matrix, which is used only when both alpha and beta are set to 1.0.
    • offset_a: offset of tensor_a, which is used only when alpha = beta = 1.0.
    • offset_b: offset of tensor_b, which is used only when alpha = beta = 1.0.
    • kernel_name: kernel name of the operator, used as the name of both the generated binary file and the operator description file.
    • quantize_params: quantization parameters. This parameter will be deprecated in a future release. Do not use it to develop new operators.
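The constraints in the parameter list can be restated as a pre-flight check on `para_dict` before `dsl.gemm` is called. `check_gemm_para_dict` is a hypothetical helper and is not part of `tbe`. It also makes two assumptions that the list does not state:

- A missing tensor_alpha or tensor_beta counts as 1.0.
- A tensor-valued factor cannot be proven to equal 1.0.

For tensor_c, it reads the `dtype` attribute that tvm tensors expose.

```python
from typing import Any, Dict, List

# Keys that the parameter list says are used only when alpha = beta = 1.0.
ONE_ONLY_KEYS = ("compress_index", "offset_a", "offset_b")


def _is_one(value: Any) -> bool:
    # Assumption: a missing factor defaults to 1.0, and a tvm.tensor factor
    # cannot be shown to equal 1.0 when the graph is being built.
    if value is None:
        return True
    return isinstance(value, (int, float)) and float(value) == 1.0


def check_gemm_para_dict(para_dict: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return rule violations (empty list if none)."""
    problems: List[str] = []
    dst_dtype = para_dict.get("dst_dtype")
    if dst_dtype is not None and dst_dtype not in ("float16", "float32"):
        problems.append(f"dst_dtype must be 'float16' or 'float32', got {dst_dtype!r}")
    tensor_c = para_dict.get("tensor_c")
    if tensor_c is not None and dst_dtype is not None:
        if getattr(tensor_c, "dtype", None) != dst_dtype:
            problems.append("tensor_c must have the same data type as dst_dtype")
    unit_factors = _is_one(para_dict.get("tensor_alpha")) and _is_one(
        para_dict.get("tensor_beta")
    )
    for key in ONE_ONLY_KEYS:
        if para_dict.get(key) is not None and not unit_factors:
            problems.append(f"{key} is used only when alpha = beta = 1.0")
    if "quantize_params" in para_dict:
        problems.append("quantize_params will be deprecated; avoid it in new operators")
    return problems


print(check_gemm_para_dict({"dst_dtype": "float32"}))  # []
print(check_gemm_para_dict({"dst_dtype": "float16", "tensor_alpha": 2.0, "offset_a": 0}))
```

The helper returns a list instead of raising, so that every violation is reported in one pass.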

Data types supported by the input tensors:

Atlas 200/300/500 Inference Product: supports float16, float32, int8, uint8, and int32. int8, uint8, and int32 inputs are cast to float16.

Atlas Training Series Product: supports float16, float32, int8, uint8, and int32. int8, uint8, and int32 inputs are cast to float16.
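The cast to float16 is lossless for int8 and uint8 but not for all int32 values. IEEE 754 binary16 has an 11-bit significand, so integers above 2048 are no longer all representable. Magnitudes above 65504, the largest finite float16, cannot be represented at all.

The sketch below shows this with Python's `struct` half-precision format `"e"`, which rounds to nearest, ties to even. It assumes the hardware cast uses the same rounding mode. It illustrates the number format only and does not reproduce the device's cast.

```python
import struct


def round_trip_float16(value: int) -> float:
    """Cast a number to IEEE 754 binary16 and back (round half to even)."""
    return struct.unpack("<e", struct.pack("<e", float(value)))[0]


# Every int8 and uint8 value survives the cast exactly.
assert all(round_trip_float16(v) == v for v in range(-128, 256))

# Above 2048, neighbouring integers start to collapse onto the same float16.
print(round_trip_float16(2048))   # 2048.0
print(round_trip_float16(2049))   # 2048.0 (the tie rounds to the even neighbour)
print(round_trip_float16(65504))  # 65504.0, the largest finite float16
```

In practice, int32 inputs with magnitudes above 2048 may lose precision before the multiplication begins.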

Returns

y: a tvm.tensor for the result tensor

Restrictions

This API cannot be used in conjunction with other TBE DSL APIs.

Applicability

Atlas 200/300/500 Inference Product

Atlas Training Series Product

Example

from tbe import tvm
from tbe import dsl
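
# A is (M, K) = (1024, 256) and B is (K, N) = (256, 512), so the result is (1024, 512).
a_shape = (1024, 256)
b_shape = (256, 512)
# The bias length matches N, the number of output columns.
bias_shape = (512, )
in_dtype = "float16"
dst_dtype = "float32"
tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype)
tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype)
# The bias uses the output data type (dst_dtype).
tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype)
para_dict = {
    "tensor_bias": tensor_bias,
    "dst_dtype": dst_dtype
}
# res is a tvm.tensor describing the GEMM result.
res = dsl.gemm(tensor_a, tensor_b, para_dict)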