概念介绍
在当前算子实现模型中,完整计算逻辑分为Host部分(通常称为Tiling函数)和Device部分(通常称为Kernel)。执行算子时,首先根据输入shape等信息运行Host侧的Tiling函数,生成TilingData,并将其作为Kernel的输入,继续执行Device部分。Device侧一份Kernel代码在多个硬件单元(也被称为核、block)上并发执行,因此,每个核只负责整个算子计算的一部分。逻辑上,Tiling函数负责为当前shape选择最高效的Kernel实现,并为选定Kernel计算出合适的分核等策略,结果通过TilingData传递至Kernel,指导Kernel正确运行。
自动融合的结果是GE图中的AscBackend或FusedAscBackend(FusedAscBackend也是携带一个子图对象,子图对象类型是ComputeGraph,ComputeGraph内部有一个或者多个AscBackend节点)节点,它仅作为自动融合软件栈生成的“壳”节点,真正的计算逻辑存储在节点的AscGraph属性中。AscGraph采用DAG(有向无环图)作为基础数据结构,基于符号化能力,AscGraph可以完整表达算子的计算逻辑(包括分核策略):基于AscGraph中核内计算部分,Codegen生成单核Kernel代码;基于完整AscGraph,Auto Tiling生成Tiling函数;基于AscGraph上的符号信息,产生TilingData定义、InferShape推导函数。
AscGraph具有两种形态:
- HintGraph:融合框架的输出,作为后端输入,仅描述算子的计算逻辑。
- ImplGraph:Auto Schedule处理后的结果,作为Codegen和Auto Tiling块的输入,完整描述算子的实现逻辑,包括Schedule策略、内存管理、流水并行等所有信息。
虽然HintGraph和ScheduleResult在表达层次和用途上有所不同,它们共享相同的构图基础逻辑,并复用部分字段,同时也各自定义了一些特定字段。但是,同一个字段在不同形态下的语义保持一致,不会因形态不同而改变其含义。
AscGraph的结构
AscGraph表达多层循环中的多步计算:每一步计算对应图中的一个节点,而节点则由AscIR实例化而来,节点之间的有向边表示数据的传递关系,因此AscGraph是一张DAG。每个节点还包含属性,用于指定该计算所处的循环层级。
实例化过程指将AscIR转换为AscGraph节点的步骤,可类比为“类”到“对象”的创建过程。此过程中,AscIR的输入、输出以及属性都会被具体赋值,使节点具备可执行性。例如,若AscIR类型为Cast,则在实例化时会设置dst_dtype属性,用以指明本次类型转换要将输入数据转换为哪种数据类型。
属性是由AscIR规范所定义的字段,可存在于图或节点上,既可以是简单数值,也可以是包含更多子字段的容器(类似C语言的struct)。
节点
节点是AscIR实例化后的实体,其输入、输出及属性完全遵循AscIR的定义。 在AscGraph中,每个节点代表一次计算操作,并嵌套于循环结构中进行多次执行。
- 节点的计算语义
节点的完整计算语义由如下几部分共同组成:
- 节点类型:节点的计算类型,如Add代表加法。
- 节点属性:用于补充描述计算逻辑的必要信息。例如,Cast节点的dst_type属性指明输入数据需要转换的目标数据类型。
- 输入、输出的描述:有些节点需要输入输出的上述信息共同描述计算逻辑,例如Broadcast需要借助输入输出的repeats和axis确定在哪个轴发生了广播,Cast需要借助输出的dtype确定要将输入转换为哪种数据类型。
- 节点的属性
每个AscGraph节点都包含以下在各个形态(HintGraph和ImplGraph)中均生效的通用属性:
- 输入与输出的属性
每个节点的输入、输出都具有属性,AscGraph要求同一条边的两端,输入与输出属性值必须完全相同。因此,在多数情况下,只需关注节点的输出属性,因为相连节点的输入属性必然相同。以下是各个形态中均适用的通用属性:
- axis:该输出包含的轴数量,所有轴必须引用AscGraph全局已定义的轴。
- repeats:每个轴上的数据重复次数,即该输出在各轴上的大小。
- stride:该输出在各轴上的索引步长(stride)。
- 节点的执行方式
在AscGraph中,节点通常嵌套在多层循环中,被循环多次执行。例如,假设Foo节点的调度信息如下:
1Foo.sched.axis = [z0, z1, z2, z3];
这意味着Foo节点将在s0 * s1 * s2 * s3次迭代中被执行,即每个循环变量z0、z1、z2、z3依次遍历其对应的大小s0、s1、s2、s3,等价于如下C++ 代码:
1 2 3 4 5 6 7 8 9
for (int64_t i0 = 0; i0 < s0; ++i0) { // 遍历 `z0` 轴 for (int64_t i1 = 0; i1 < s1; ++i1) { // 遍历 `z1` 轴 for (int64_t i2 = 0; i2 < s2; ++i2) { // 遍历 `z2` 轴 for (int64_t i3 = 0; i3 < s3; ++i3) { // 遍历 `z3` 轴 // 执行 Foo 计算 } } } }
轴
轴(Axis)是AscGraph中极为重要的概念,AscGraph对循环与数据的表达均依赖轴来进行定义。轴代表数据的一个维度,并可通过name或id进行标识,其核心属性size用以描述该维度的长度。轴及其大小的符号作为属性保存在AscGraph上,是图级定义,因此轴可被图中的任意元素(如节点)通过轴的id引用,以实现对同一维度的共享认知与引用。
在图上创建轴,举例来说,创建一根长度为s0的轴:
1 2 3 4 5 6 7 | AscGraph graph; // 允许轴的长度为变量,变量名为`s0` Expression s0 = graph.CreateSizeVar("s0"); // 创建一根名字为`z0`的轴,轴的长度为 `s0`,轴的`id`在创建时分配,可以通过`z0.id`获取 Axis &z0 = graph.CreateAxis("z0", s0); |
在自动融合的命名习惯中,通常使用s前缀表示大小变量,例如s0、s1分别表示第0个、第1个大小变量。s取自symbol(符号)的首字母,表示这是一个符号化表达的变量。
轴的命名通常以z前缀,例如z0、z1,并遵循以下对应关系:z0的大小一般由s0表示,z1的大小由s1表示,依此类推。这种命名方式确保了轴(z)与其大小(s)之间的映射关系清晰可读,方便理解。
通过轴表达循环,比如,当前AscGraph有四根轴:
1 2 3 4 5 6 7 8 9 10 | AscGraph graph; Expression s0 = graph.CreateSizeVar("s0"); Expression s1 = graph.CreateSizeVar("s1"); Expression s2 = graph.CreateSizeVar("s2"); Expression s3 = graph.CreateSizeVar("s3"); Axis &z0 = graph.CreateAxis("z0", s0); Axis &z1 = graph.CreateAxis("z1", s1); Axis &z2 = graph.CreateAxis("z2", s2); Axis &z3 = graph.CreateAxis("z3", s3); |
当循环为z0, z1, z2, z3时,表达依次、全量遍历这四根轴,等价的C语言表达为:
1 2 3 4 5 6 7 8 | for (int64_t i0 = 0; i0 < s0; ++i0) { // 遍历 `z0`轴,遍历长度为`z0`轴的长度`s0` for (int64_t i1 = 0; i1 < s1; ++i1) { // 遍历 `z1`轴 for (int64_t i2 = 0; i2 < s2; ++i2) { // 遍历 `z2`轴 for (int64_t i3 = 0; i3 < s3; ++i3) { // 遍历 `z3`轴 } } } } |
图的输入和输出
AscGraph通过特定类型的节点表达输入和输出,其对应的AscIR:
- Data类型的AscIR表示图的输入。
- Data AscIR无输入,只有一个输出,对应图的某个输入。
- 其int32_t index属性用于指示该输入在图中的序号。
- 每个Data节点对应一个图输入,与其相连的节点即是读取该输入的算子。
- Output类型的AscIR表示图的输出。
- Output AscIR无输出,只有一个输入,对应图的某个输出。
- 其int index属性用于指示该输出在图中的序号。
- 每个Output节点对应一个图输出,与其相连的节点即是整图的最终输出算子。
与常规节点的输入、输出不同,Data的输出、Output的输入分别代表AscGraph的输入、输出,是AscGraph对外部的承诺,在整个后端运行过程中,不允许被修改。