在开发算子代码之前需要分析算子的数学表达式、输入、输出以及计算逻辑的实现,明确需要调用的Ascend C接口。
- 明确算子的数学表达式及计算逻辑。
Matmul算子完成矩阵乘操作,其数学表达式如下,形状为[m, k]的矩阵a和形状为[k, n]的矩阵b相乘,得到形状为[m, n]的矩阵c。为了方便,令m=k=n=32。
c = a * b
注意需要处理的数据过大时,需要对数据进行切分并分块搬运到A2、B2,分别计算后再进行汇聚。下文的计算逻辑为了展示Split和Aggregate阶段的样例,请您根据实际需要处理的数据大小决定是否需要切分和汇聚。
计算逻辑如下:
- 分别搬运输入数据矩阵a、b至Local Memory A1、B1。
- 将a矩阵从A1搬运至A2。将b矩阵切分为part1和part2,形状均为[k, n / 2],切分后再分块搬运至B2。
- a矩阵和b矩阵part1、part2分别做矩阵乘运算,获得矩阵c的part1和part2,形状均为[m, n / 2]。计算结果在CO1存储。
- 将矩阵c的part1和part2分别拷贝到CO2进行合并。
- 将合并后的输出数据从CO2搬出。
- 明确输入和输出。
- Matmul算子有两个输入:a与b,输出为c。
- 本样例中算子输入支持的数据类型为half(float16),算子输出的数据类型为float32。
- 矩阵a、b、c的形状均为[32, 32]。
- 算子输入输出支持的数据格式为:ND。
- 确定核函数名称和参数。
- 您可以自定义核函数名称,本样例中核函数命名为matmul_custom。
- 根据对算子输入输出的分析,确定核函数有3个参数a,b,c;a,b为输入在Global Memory上的内存地址,c为输出在Global Memory上的内存地址。
- 约束分析。
由于硬件架构对矩阵乘计算的输入输出有格式约束,需要在算子实现中增加格式转换的流程。
- 搬运矩阵a、b至A1、B1时,将ND格式的矩阵a、b转换为NZ格式。
- 从A1搬运矩阵a至A2时,将NZ格式的a矩阵转换为ZZ格式;从B1搬运矩阵b到B2时将NZ格式的b矩阵转换为ZN格式。
- 将计算结果从CO2搬出时,将NZ格式的c矩阵转换为ND格式。
- 确定算子实现所需接口。
- 实现外部存储和内部存储间的数据搬运,查看Ascend C API参考中的数据搬移接口,具体参考DataCopy。
- 实现矩阵数据格式转换,查看Ascend C API参考中的数据转换接口,具体参考LoadData。
- 矩阵计算过程涉及矩阵乘法,查看Ascend C API参考中的矩阵计算接口,具体参考Mmad。
- 计算中使用到的Tensor数据结构,使用Queue队列进行管理,会使用到EnQue、DeQue等接口。
通过以上分析,得到Ascend C Matmul算子的计算流程图和设计规格如下:
图1 Matmul算子的计算流程图
表1 Ascend C Matmul算子设计规格算子类型(OpType)
|
Matmul
|
算子输入
|
name
|
shape
|
data type
|
format
|
a
|
(m, k) = (32, 32)
|
half
|
ND
|
b
|
(k, n) = (32, 32)
|
half
|
ND
|
算子输出
|
c
|
(m, n) = (32, 32)
|
float32
|
ND
|
核函数名称
|
matmul_custom
|
使用的主要接口
|
DataCopy:数据搬移接口
|
LoadData:矩阵数据格式转换接口
|
Mmad:矩阵乘计算接口
|
EnQue、DeQue等接口:Queue队列管理接口
|
算子实现文件名称
|
matmul_custom.cpp
|