MatrixMultiplyLayer *AddMatrixMultiply(Tensor *input0, MatrixOperation type0, Tensor *input1, MatrixOperation type1) noexcept;
函数功能
将矩阵乘法层添加到网络中。
函数原型
MatrixMultiplyLayer *AddMatrixMultiply(Tensor *input0, MatrixOperation type0, Tensor *input1, MatrixOperation type1) noexcept;
约束说明
- input0和input1需为有效输入,维度范围[1,8],每一维不可为0,也不可小于-1。
- 当input0和input1为矩阵或向量时,将会计算input0与input1的内积,方法中将会校验给定输入是否满足内积运算的基本法则。(如:转置后input0的最后一维需要与input1的倒数第二维取值相同)
- 暂不支持一个输入是一维向量,另一个输入是多维向量(维度大于2)的情况,例如:[4] * [4, 5, 6] 或[6, 5, 4] * [4]。
参数说明
参数名 |
输入/输出 |
说明 |
---|---|---|
input0 |
输入 |
第一个输入张量(通常为A)。 |
type0 |
输入 |
应用于 input0 的操作。矩阵操作类型,详见enum class MatrixOperation。 |
input1 |
输入 |
第二个输入张量(通常为B)。 |
type1 |
输入 |
应用于 input1 的操作。矩阵操作类型,详见enum class MatrixOperation。 |
返回值说明
返回新的矩阵乘法层,如果添加失败则返回nullptr或抛出异常。
父主题: class Network