矩阵乘法实现
上文介绍了Matmul矩阵乘的数据切分方案和数据流。Ascend C提供一组Matmul高阶API,封装了这些常用的切分和数据搬运、计算的算法逻辑,方便用户快速实现Matmul矩阵乘法的运算操作。开发者在host侧通过调用API自动获取Tiling参数,该参数传递到kernel侧后,在初始化操作时传入,通过几个简单的API即可完成矩阵乘操作。

kernel侧使用Matmul API矩阵乘运算的具体步骤如下:
- 创建Matmul对象
创建Matmul对象的示例如下:
typedef MatmulType<TPosition::GM, CubeFormat::ND, half> aType; typedef MatmulType<TPosition::GM, CubeFormat::ND, half> bType; typedef MatmulType<TPosition::GM, CubeFormat::ND, float> cType; typedef MatmulType<TPosition::GM, CubeFormat::ND, float> biasType; Matmul<aType, bType, cType, biasType> mm;
创建对象时需要传入A、B、C、Bias的参数类型信息, 类型信息通过MatmulType来定义,包括:内存逻辑位置、数据格式、数据类型。
- 初始化操作。
REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); // 初始化
- 设置左矩阵A、右矩阵B、Bias。
mm.SetTensorA(gm_a); // 设置左矩阵A mm.SetTensorB(gm_b); // 设置右矩阵B mm.SetBias(gm_bias); // 设置Bias
- 完成矩阵乘操作。
- 调用Iterate完成单次迭代计算,叠加while循环完成单核全量数据的计算。Iterate方式,可以自行控制迭代次数,完成所需数据量的计算,方式比较灵活。
while (mm.Iterate()) { mm.GetTensorC(gm_c); } - 调用IterateAll完成单核上所有数据的计算。IterateAll方式,无需循环迭代,使用比较简单。
mm.IterateAll(gm_c);
- 调用Iterate完成单次迭代计算,叠加while循环完成单核全量数据的计算。Iterate方式,可以自行控制迭代次数,完成所需数据量的计算,方式比较灵活。
- 结束矩阵乘操作。
mm.End();
host侧自动获取Tiling参数的关键步骤介绍如下:
- 创建Tiling对象。
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); MatmulApiTiling cubeTiling(ascendcPlatform);
- 设置A、B、Bias的数据类型和格式。
cubeTiling.SetAType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); cubeTiling.SetBType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16); cubeTiling.SetCType(TPosition::LCM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT); cubeTiling.SetBiasType(TPosition::GM, CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
- 设置矩阵shape信息。
cubeTiling.SetShape(M, N, K); cubeTiling.SetOrgShape(M, N, K);
- 设置可用空间大小信息。
cubeTiling.SetBufferSpace(-1, -1, -1);
- 按需设置其他参数,比如设置bias参与计算。
cubeTiling.SetBias(true);
- 获取Tiling参数。
MatmulCustomTilingData tiling; if (cubeTiling.GetTiling(tiling.cubeTilingData) == -1){ return ge::GRAPH_FAILED; } - Tiling参数的序列化保存等其他操作。
注意:Matmul高阶API内部实现时需要使用系统workspace,开发者需要:
- 在host侧Tiling实现时,设置总的workspace的数值大小(包含用户workspace和系统workspace),workspace空间由框架来申请并管理。系统workspace的空间大小通过GetLibApiWorkSpaceSize获取。
size_t userWorkspaceSize = 0; size_t systemWorkspaceSize = ascendcPlatform.GetLibApiWorkSpaceSize(); size_t *currentWorkspace = context->GetWorkspaceSizes(1); currentWorkspace[0] = userWorkspaceSize + systemWorkspaceSize;
- kernel侧需要在Matmul初始化前,通过SetSysWorkSpace设置系统workspace。
// 使用Matmul时必须设置workspace空间 SetSysWorkspace(workspace); if (GetSysWorkSpacePtr() == nullptr) { return; }
父主题: 矩阵编程(高阶API)