通用矩阵乘,计算表达式如下:
A' = transpose(A) if transA else A
B' = transpose(B) if transB else B
Y = alpha * A' * B' + beta * C。
Tensor A'与Tensor B'的shape后两维(经过对应转置)需要满足矩阵乘(M, K) * (K, N) = (M, N),支持多维度。
当alpha=1.0且beta=1.0时,计算结果与矩阵乘Matmul计算接口结果相同。
gemm(tensor_a, tensor_b, para_dict)
输入tensor支持的数据类型:
Atlas 200/300/500 推理产品:支持的数据类型有float16,float32,int8,uint8,int32。其中int8,uint8,int32会被转换为float16。
Atlas 训练系列产品:支持的数据类型有float16,float32,int8,uint8,int32。其中int8,uint8,int32会被转换为float16。
Atlas 推理系列产品:支持的数据类型有float16,float32,int8,uint8。其中int8,uint8会被转换为float16。
Atlas A2训练系列产品:支持的数据类型有float16,float32,int8,uint8,int32。其中int8,uint8,int32会被转换为float16。
y:根据关系运算计算后得到的tensor,tvm.tensor类型。
此接口暂不支持与其他TBE DSL计算接口混合使用。
Atlas 200/300/500 推理产品
Atlas 训练系列产品
Atlas 推理系列产品
Atlas A2训练系列产品
from tbe import tvm from tbe import dsl a_shape = (1024, 256) b_shape = (256, 512) bias_shape = (512, ) in_dtype = "float16" dst_dtype = "float32" tensor_a = tvm.placeholder(a_shape, name='tensor_a', dtype=in_dtype) tensor_b = tvm.placeholder(b_shape, name='tensor_b', dtype=in_dtype) tensor_bias = tvm.placeholder(bias_shape, name='tensor_bias', dtype=dst_dtype) para_dict = { "tensor_bias":tensor_bias, "dst_dtype": dst_dtype } res = dsl.gemm(tensor_a, tensor_b, para_dict)