昇腾社区首页
中文
注册
开发者
下载

ImplGraph

在Auto Schedule处理HintGraph的过程中,会基于多种Schedule策略生成一到多份ScheduleResult,每份ScheduleResult都完整表达了算子的实现逻辑。一个ScheduleResult的整体计算会被拆分成一到多步子计算,每个子计算步骤被称为ScheduleGroup,其含义为独立应用Schedule策略。在一个ScheduleGroup中,每个Schedule策略均会产生一张新的AscGraph,该图被称为ImplGraph。

每个ImplGraph还包含Schedule策略、核间、核内切分方式、内存、向量化等。

轴变换

一个轴允许被切分为两个子轴,切分出的两根轴被称为内轴与外轴。比如将repeats为s0的z0轴切分后,产生z0Outer和z0Inner两个轴,两个轴的repeats相乘与s0相等,因此两轴的repeats可以表示为:ceil(s0/s0I)与s0I。

轴切分分为两类,Block切分与Tile切分:

  • Block切分表示将数据切分后,分配到多个核上并行执行。Block切分的外轴、内轴一般的命名方式是在原轴名字后加B/b,一个ImplGraph上,只能有一个Block外轴。比如将z0做Block切分,切分后外轴、内轴的名字分别为z0B和z0b。
  • Tile切分为普通切分,不涉及特殊语义。

两个或以上的连续循环轴可以被合并成一个,所谓连续循环轴,是指在sched.axis中连续的轴,比如,若sched.axis=[z0, z1, z2, z3],则[z1, z2, z3]为连续轴。如果轴发生了切分,例如变成[z0, z1T, z1t, z2T, z2t, z3],则[z1t, z2T]为连续轴。

如果为循环轴做过reorder,例如上述例子中,顺序变为[z0, z2, z1, z3],则[z0, z2]为连续轴。若将切分后的轴做了reorder,则不能对reorder部分的切分后轴做合并。例如上述例子中,对z1做了切分和reorder,顺序变为[z0, z1T, z2, z1t, z3],则[z2, z1t]、[z1t, z3]均不是连续轴,也就无法进行合并,在本例中,[z0, z1T]的顺序没有被reorder影响,仍然可以被合并。

在轴属性上,有如下字段表达轴的变换关系:

  • axis.type:表示本轴由哪种轴变换类型产生,如果产生一根轴经历了多次变换,那么本type仅保存最后一次的变换类型。
    • kAxisTypeOriginal:原始轴,未做过拆分或合并。
    • kAxisTypeBlockOuter:Block切分后的外轴。
    • kAxisTypeBlockInner:Block切分后的内轴。
    • kAxisTypeTileOuter:tile切分后的外轴。
    • kAxisTypeTileInner:tile切分后的内轴。
    • kAxisTypeMerged:多个轴合并后的轴。
  • from:如果本轴是被切分而来,那么from中保存被切分轴;如果本轴是由多个轴合并而来,那么from中保存合并前的轴。
  • split_pair:如果本轴是被切分而来,那么split_pair中保存同一次切分时产生的另外一根轴。

向量化和循环轴

在ImplGraph中,对轴的定义做了扩展,表达了三种含义,考虑一个API的调用:

1
2
3
4
5
6
// 共有四根轴`z0, z1, z2, z3`,长度分别为`s0, s1, s2, s3`,索引时使用`q + index`
for q0 in range(s0):
    buffer[s1 * s2 * s3];
    for q1 in range(s1):
        // inputs 包含四根轴:q0, q1, q2, q3,本次调用计算`q2, q3`两根轴的数据
        buffer[index] = CalcApi(inputs[q0][q1])

上述调用共包含如下几项信息:

  • 循环层数:q0, q1
  • 一次计算的数据长度:s2 * s3
  • 计算后,数据存储buffer的长度:例子中,buffer包含z1, z2, z3三个轴,因此长度为三个轴长度的乘积。

通过如下字段表达:

位置

属性key

含义

对应本例中的值

节点上

sched.axis

节点上的所有轴

[z0, z1, z2, z3]

节点上

sched.loop_axis

循环停止轴,停止轴及其前面的轴为循环部分,停止轴后面的部分为一次计算的数据量

z1

输出上

vectorized_axis

数据存储buffer的长度

[z1, z2, z3]

输出上

vectorized_stride

存储数据时的stride

[s2*s3, s3, 1]