matmul使能爱因斯坦乘
产品支持情况
| 
          硬件型号  | 
        
          是否支持  | 
       
|---|---|
| 
          | 
        
          √  | 
       
| 
          | 
        
          √  | 
       
| 
          | 
        
          x  | 
       
| 
          | 
        
          x  | 
       
| 
          | 
        
          x  | 
       
功能说明
矩阵乘matmul使能爱因斯坦乘。
计算公式
矩阵乘输入两个张量
和
,输出张量为
:

参数配置
| 
          成员名称  | 
        
          取值范围  | 
       
|---|---|
| 
          transposeA  | 
        
          false  | 
       
| 
          transposeB  | 
        
          false/true  | 
       
| 
          hasBias  | 
        
          false  | 
       
| 
          outDataType  | 
        
          ACL_DT_UNDEFINED  | 
       
| 
          enAccum  | 
        
          false  | 
       
| 
          matmulType  | 
        
          MATMUL_EIN_SUM  | 
       
| 
          quantMode  | 
        
          QUANT_UNDEFINED  | 
       
输入
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          x  | 
        
          [m, batch, k]  | 
        
          float16/bf16  | 
        
          ND  | 
        
          矩阵乘的A矩阵。  | 
       
| 
          weight  | 
        
          ND:[batch, k, n] NZ:[batch, n/16, k, 16]  | 
        
          float16/bf16  | 
        
          ND/NZ  | 
        
          矩阵乘的B矩阵,权重。 维度为4维时,k和n的值均为16的整数倍。  | 
       
输出
| 
          参数  | 
        
          维度  | 
        
          数据类型  | 
        
          格式  | 
        
          描述  | 
       
|---|---|---|---|---|
| 
          output  | 
        
          [m, batch, n]  | 
        
          float16/bf16  | 
        
          ND  | 
        
          矩阵乘计算结果。  | 
       
OP使用与典型场景
     OP使用时,可参考算子使用指导(C++ API)中的使用流程部分,其中,单算子构造Operation参数的构造方法参考下列各场景的参数构造部分。 
     
   // 参数构造 atb::infer::LinearParam param; param.transposeA = false; param.transposeB = false; param.hasBias = false; param.outDataType = ACL_DT_UNDEFINED; param.enAccum = false; param.matmulType = MATMUL_EIN_SUM; param.quantMode = QUANT_UNDEFINED;
# 计算示例
>>> x
tensor([[[1, 2]],
        [[3, 4]]])
>>> weight
tensor([[[1, 2, 3],
         [4, 5, 6]]])
>>> output
tensor([[[9, 12, 15],
       [19, 26, 33]]])
# 9 = 1 * 1 + 2 * 4
# 12 = 1 * 2 + 2 * 5
# 15 = 1 * 3 + 2 * 6 
# 19 = 3 * 1 + 4 * 4
# 26 = 3 * 2 + 4 * 5 
# 33 = 3 * 3 + 4 * 6
    
     父主题: 功能列表