昇腾社区首页
中文
注册

MatmulOperation

功能

矩阵乘。

图1 MatmulOperation

约束

输入x / y矩阵维度,通过transposeA / transposeB配置要求满足矩阵乘的维度关系。

  • 当transposeA = true时,不接受输入维度x: [batch, m, k],y: [k, n]的情况。
  • 当输入x,y的维度都大于2时,要求两者第0维,即batch相等。

定义

struct MatmulParam {
	bool transposeA	= false;
	bool transposeB	= true;
};

成员

成员名称

描述

transposeA

是否转置A矩阵,默认不转置。

transposeB

是否转置B矩阵。默认转置。

输入

参数

维度

数据类型

格式

描述

x

当y为NDNZ格式时支持以下维度输入:

  1. x: [batch, m, k]

    y: [batch, k, n]

  2. x: [batch, m, k]

    y: [k, n]

  3. x: [m,k]

    y: [k, n]

当y为NZ格式时,额外支持以下维度输入:
  • x: [m, k]

    y: [1, n/16, k, 16]

  • x: [batch, m, k]

    y: [batch, n/16, k, 16]

float16

ND

通过transposeA、transposeB保证满足矩阵乘的维度关系。

y为NZ时,倒数第二维为16整数倍。

  1. 当y为二维或三维输入时,最后一维大小为16整数倍。
  2. 当y为四维输入时,最后一维大小为16。

y

float16

ND/NZ

输出

参数

维度

数据类型

格式

output

根据以上输入维度,输出维度为:

  1. output: [batch, m, n]
  2. output: [batch, m, n]
  3. output: [m ,n]
当y为NZ格式时,对应以下维度输出。
  1. output: [m, n]
  2. output: [batch, m, n]

float16

ND