昇腾社区首页
中文
注册

MatrixMultiplyLayer *AddMatrixMultiply(Tensor *input0, MatrixOperation type0, Tensor *input1, MatrixOperation type1) noexcept;

函数功能

将矩阵乘法层添加到网络中。

函数原型

MatrixMultiplyLayer *AddMatrixMultiply(Tensor *input0, MatrixOperation type0, Tensor *input1, MatrixOperation type1) noexcept;

约束说明

  • input0和input1需为有效输入,维度范围[1,8],每一维不可为0,也不可小于-1。
  • 当input0和input1为矩阵或向量时,将会计算input0与input1的内积,方法中将会校验给定输入是否满足内积运算的基本法则。(如:转置后input0的最后一维需要与input1的倒数第二维取值相同)
  • 暂不支持一个输入是一维向量,另一个输入是多维向量(维度大于2)的情况,例如:[4] * [4, 5, 6] 或[6, 5, 4] * [4]。

参数说明

参数名

输入/输出

说明

input0

输入

第一个输入张量(通常为A)。

type0

输入

应用于 input0 的操作。矩阵操作类型,详见enum class MatrixOperation

input1

输入

第二个输入张量(通常为B)。

type1

输入

应用于 input1 的操作。矩阵操作类型,详见enum class MatrixOperation

返回值说明

返回新的矩阵乘法层,如果添加失败则返回nullptr或抛出异常。