Ascend C提供一组MatMul高阶API,方便用户快速实现MatMul矩阵乘法的运算操作。
MatMul的计算公式如下:C = A * B + Bias。
MatMul的计算示意图如下:
下文中提及的M轴方向,即为A矩阵纵向;K轴方向,即为A矩阵横向或B矩阵纵向;N轴方向,即为B矩阵横向。
实现MatMul矩阵乘运算的具体步骤如下:
创建Matmul对象的示例如下:
typedef MatmulType<TPosition::GM, CubeFormat::ND, half> aType; typedef MatmulType<TPosition::GM, CubeFormat::ND, half> bType; typedef MatmulType<TPosition::GM, CubeFormat::ND, float> cType; typedef MatmulType<TPosition::GM, CubeFormat::ND, float> biasType; Matmul<aType, bType, cType, biasType> mm;
创建对象时需要传入A、B、C、Bias的参数类型信息, 类型信息通过MatmulType来定义,包括:内存逻辑位置、数据格式、数据类型。
template <TPosition POSITION, CubeFormat FORMAT, typename TYPE> struct MatmulType { constexpr static TPosition pos = POSITION; constexpr static CubeFormat format = FORMAT; using T = TYPE; };
参数 |
说明 |
---|---|
POSITION |
内存逻辑位置
|
CubeFormat |
|
TYPE |
针对Atlas A2训练系列产品:
注意:A矩阵和B矩阵数据类型需要一致 |
mm.Init(&tiling); // 初始化
mm.SetTensorA(gm_a); // 设置左矩阵A mm.SetTensorB(gm_b); // 设置右矩阵B mm.SetBias(gm_bias); // 设置Bias
while (mm.Iterate()) { mm.GetTensorC(gm_c); }
mm.IterateAll(gm_c);
mm.End();