矩阵乘matmul使能爱因斯坦乘。
矩阵乘输入两个张量和
,输出张量为
:
硬件型号 |
支持情况 |
特殊说明 |
---|---|---|
|
不支持 |
- |
|
支持 |
- |
|
支持 |
- |
成员名称 |
取值范围 |
特殊说明 |
---|---|---|
transposeA |
false |
- |
transposeB |
false/true |
- |
hasBias |
false |
- |
outDataType |
ACL_DT_UNDEFINED |
- |
enAccum |
false |
- |
matmulType |
MATMUL_EIN_SUM |
- |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
x |
[m, batch, k] |
float16 |
ND |
矩阵乘的A矩阵。 |
weight |
[batch, k, n] |
float16 |
ND |
矩阵乘的B矩阵,权重。 |
参数 |
维度 |
数据类型 |
格式 |
描述 |
---|---|---|---|---|
output |
[m, batch, n] |
float16 |
ND |
矩阵乘计算结果。 |
// 参数构造 atb::infer::LinearParam param; param.transposeA = false; param.transposeB = false; param.hasBias = false; param.outDataType = ACL_DT_UNDEFINED; param.enAccum = false; param.matmulType = MATMUL_EIN_SUM;
# 计算示例 >>> x tensor([[[1, 2]], [[3, 4]]]) >>> weight tensor([[[1, 2, 3], [4, 5, 6]]]) >>> output tensor([[[9, 12, 15], [19, 26, 33]]]) # 9 = 1 * 1 + 2 * 4 # 12 = 1 * 2 + 2 * 5 # 15 = 1 * 3 + 2 * 6 # 19 = 3 * 1 + 4 * 4 # 26 = 3 * 2 + 4 * 5 # 33 = 3 * 3 + 4 * 6