REGIST_MATMUL_OBJ
功能说明
初始化Matmul对象。
函数原型
REGIST_MATMUL_OBJ(tpipe, workspace, ...)
参数说明
参数名 |
输入/输出 |
描述 |
---|---|---|
tpipe |
输入 |
Tpipe对象。 |
workspace |
输入 |
系统workspace指针。 |
... |
输入 |
可变参数,传入matmul对象和与之对应的Tiling结构,要求Tiling结构的数据类型为TCubeTiling结构。 Tiling参数可以通过host侧GetTiling接口获取,并传递到kernel侧使用。 |
参数名称 |
数据类型 |
说明 |
---|---|---|
usedCoreNum |
int |
使用的AI处理器核数,请根据实际情况设置。取值范围为:[1, AI处理器最大核数]。 |
M, N, Ka, Kb |
int |
A、B、C矩阵原始输入的shape大小,以元素为单位。M,Ka为A矩阵原始输入的Shape, Kb, N为B矩阵原始输入的Shape。
|
singleCoreM, singleCoreN, singleCoreK |
int |
A、B、C矩阵单核内shape大小,以元素为单位。 singleCoreK = K,多核处理时不对K进行切分;singleCoreM <= M;singleCoreN <= N。 注意:若A矩阵以NZ格式输入,则singleCoreM需要以16个元素对齐,singleCoreK需要以C0_size * fractal_num对齐;若B矩阵以NZ格式输入,则singleCoreK需要以C0_size * fractal_num对齐,singleCoreN需要以16个元素对齐。 half/bfloat16_t数据类型输入,C0_size为16,fractal_num为1;float数据类型输入,C0_size为8,fractal_num为2。 |
baseM, baseN, baseK |
int |
A、B、C矩阵参与一次矩阵乘指令的shape大小,以元素为单位。
需要按分形对齐。 |
depthA1, depthB1 |
int |
A、B矩阵片全载A2/B2的份数,depthA1为baseM * baseK的整数倍,depthB1为baseN * baseK的整数倍。取值大于0。 |
stepM, stepN |
int |
stepM为左矩阵在A1中缓存的bufferM方向上baseM的倍数。 stepN为右矩阵在B1中缓存的bufferN方向上baseN的倍数。 取值大于0。 |
isBias |
int |
是否使能Bias,0代表不使能Bias,1代表使能Bias。 |
transLength |
int |
max(A1Length, B1Length, C01Length)。 |
iterateOrder |
int |
一次Iterate计算出[baseM, baseN]大小的C矩阵分片,Iterate完成后,Matmul会自动偏移下一次Iterate输出的C矩阵位置,iterOrder表示自动偏移的顺序。
|
shareMode |
int |
该参数预留,开发者无需关注。 |
shareL1Size |
int |
该参数预留,开发者无需关注。 |
shareL0CSize |
int |
该参数预留,开发者无需关注。 |
shareUbSize |
int |
该参数预留,开发者无需关注。 |
batchM |
int |
该参数预留,开发者无需关注。 |
batchN |
int |
该参数预留,开发者无需关注。 |
singleBatchM |
int |
该参数预留,开发者无需关注。 |
singleBatchN |
int |
该参数预留,开发者无需关注。 |
返回值
无。
注意事项
无
调用示例
// 推荐:初始化单个matmul对象,传入tiling参数 REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm, &tiling); // 推荐:初始化多个matmul对象,传入对应的tiling参数 REGIST_MATMUL_OBJ(&pipeIn, GetSysWorkSpacePtr(), mm1, mm1tiling, mm2, mm2tiling, mm3, mm3tiling, mm4, mm4tiling); // 初始化单个matmul对象,未传入tiling参数。注意,该场景下需要使用Init接口单独传入tiling参数。这种方式将matmul对象的初始化和tiling的设置分离,比如,Tiling可变的场景,可通过这种方式多次对Tiling进行重新设置 REGIST_MATMUL_OBJ(&pipe, GetSysWorkSpacePtr(), mm); mm.Init(&tiling);