MatmulPolicy
功能说明
定义Matmul可拓展模块策略。目前支持设置以下三种Matmul内置模板策略。
- MatmulPolicy(默认模板策略)
- TrianUpperMatmulPolicy(上三角模板策略)
一次矩阵乘指令计算的结果为baseM * baseN大小的矩阵块,称该矩阵块为基本块。若Matmul结果矩阵C中的基本块位于下三角位置,则Matmul内部做数据计算和数据搬出时,将忽略该基本块,最后得到的矩阵C为一个上三角矩阵。上三角模板策略如下图所示,图示中矩阵形状的相关大小为M=N=512,K=256,baseM=baseN=baseK=32。
图1 上三角模板策略示意图 - TrianLowerMatmulPolicy(下三角模板策略)
一次矩阵乘指令计算的结果为baseM * baseN大小的矩阵块,称该矩阵块为基本块。若Matmul结果矩阵C中的基本块位于上三角位置,则Matmul内部做数据计算和数据搬出时,将忽略该基本块,最后得到的矩阵C为一个下三角矩阵。下三角模板策略如下图所示,图示中矩阵形状的相关大小为M=N=512,K=256,baseM=baseN=baseK=32。
图2 下三角模板策略示意图
支持的型号
使用示例
默认模板策略MatmulPolicy为模板参数的默认值,下面主要介绍TrianUpperMatmulPolicy(上三角模板策略)和TrianLowerMatmulPolicy(下三角模板策略)的使用方式。
- 上三角模板策略使用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#include "lib/matmul_intf.h" typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> aType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> bType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> cType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType; // Matmul定义时传入TrianUpperMatmulPolicy AscendC::Matmul<aType, bType, cType, biasType, CFG_NORM, MatmulCallBackFunc<nullptr, nullptr, nullptr>, AscendC::Impl::Detail::TrianUpperMatmulPolicy> mm; // 常规Matmul计算,最后输出上三角形式的结果 TPipe pipe; TCubeTiling tiling; REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); mm.Init(&tiling); mm.SetTensorA(gmA, isTransposeA); mm.SetTensorB(gmB, isTransposeB); if (tiling.isBias) { mm.SetBias(gmBias); } mm.IterateAll(gmC);
- 下三角模板策略使用示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#include "lib/matmul_intf.h" typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> aType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, half> bType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> cType; typedef AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, float> biasType; // Matmul定义时传入TrianLowerMatmulPolicy AscendC::Matmul<aType, bType, cType, biasType, CFG_NORM, MatmulCallBackFunc<nullptr, nullptr, nullptr>, AscendC::Impl::Detail::TrianLowerMatmulPolicy> mm; // 常规Matmul计算,最后输出下三角形式的结果 TPipe pipe; TCubeTiling tiling; REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); mm.Init(&tiling); mm.SetTensorA(gmA, isTransposeA); mm.SetTensorB(gmB, isTransposeB); if (tiling.isBias) { mm.SetBias(gmBias); } mm.IterateAll(gmC);
父主题: Matmul