昇腾社区首页
中文
注册

matmul使能爱因斯坦乘

产品支持情况

硬件型号

是否支持

Atlas A3 推理系列产品 / Atlas A3 训练系列产品

Atlas A2 训练系列产品 / Atlas 800I A2 推理产品

Atlas 训练系列产品

x

Atlas 推理系列产品

x

Atlas 200I/500 A2 推理产品

x

功能说明

矩阵乘matmul使能爱因斯坦乘。

计算公式

矩阵乘输入两个张量,输出张量为

参数配置

成员名称

取值范围

transposeA

false

transposeB

false/true

hasBias

false

outDataType

ACL_DT_UNDEFINED

enAccum

false

matmulType

MATMUL_EIN_SUM

quantMode

QUANT_UNDEFINED

输入

参数

维度

数据类型

格式

描述

x

[m, batch, k]

float16/bf16

ND

矩阵乘的A矩阵。

weight

ND:[batch, k, n]

NZ:[batch, n/16, k, 16]

float16/bf16

ND/NZ

矩阵乘的B矩阵,权重。

维度为4维时,k和n的值均为16的整数倍。

输出

参数

维度

数据类型

格式

描述

output

[m, batch, n]

float16/bf16

ND

矩阵乘计算结果。

OP使用与典型场景

OP使用时,可参考算子使用指导中的使用流程部分,其中,单算子构造Operation参数的构造方法参考下列各场景的参数构造部分。
// 参数构造
atb::infer::LinearParam param;
param.transposeA = false;
param.transposeB = false;
param.hasBias = false;
param.outDataType = ACL_DT_UNDEFINED;
param.enAccum = false;
param.matmulType = MATMUL_EIN_SUM;
param.quantMode = QUANT_UNDEFINED;
# 计算示例
>>> x
tensor([[[1, 2]],
        [[3, 4]]])
>>> weight
tensor([[[1, 2, 3],
         [4, 5, 6]]])
>>> output
tensor([[[9, 12, 15],
       [19, 26, 33]]])
# 9 = 1 * 1 + 2 * 4
# 12 = 1 * 2 + 2 * 5
# 15 = 1 * 3 + 2 * 6 
# 19 = 3 * 1 + 4 * 4
# 26 = 3 * 2 + 4 * 5 
# 33 = 3 * 3 + 4 * 6