昇腾社区首页
中文
注册
开发者
下载

背景及原理介绍

背景介绍

当前框架中算子输入输出shape中的动态维度通常使用-1表示,但这种方法能表达的信息有限。因此,使用符号(如s0, s1等)来表示未知维度,这些符号相较于-1可以表达更多的信息。例如下图所示,当对Data0,Data1,Data2的输入shape进行符号化之后,Add的输出shape也可以用符号表达,从而其下一个节点Mul的输出shape可以被计算出来。与shape为-1相比,此时Mul的输出shape能表达的信息更多。

但是当前的符号化表达在以下两种场景可能存在问题:

  • 需要对符号shape做Assert校验的场景

    部分算子在进行符号推导时,存在需要对某些维度进行强校验的场景,比如下图中的Matmul算子,该场景下,Matmul算子的输入符号s1必须与s2相等,否则Matmul算子的执行将会出现未定义行为(如越界或者报错)。

  • Broadcast场景

    部分算子的计算需要做Broadcast(例如Add、Mul等),比如:

    该场景下,由于s1与s3均是符号,在符号推导过程中无法确定s1与s3之间的大小关系,且仅以下三种情况满足Broadcast。另外,在编译过程中无法确定具体应执行哪一条分支,如果执行时不符合下述的分支条件,可能会出现未定义行为(如越界或者报错)。

为应对上述两个问题,引入了“Guard”这一概念。通过Guard,我们可以明确表达符号间的直接关系,比如上面的Broadcast场景,可以使用符号关系表达式s1 == s3来表示当前Broadcast路径为s1 == s3,后续的符号推导均基于此路径进行。在执行过程时,如果符号代入的值不满足s1==s3,则当前的模型编译结果将无法直接使用,需要重新编译模型。这便是Guard的使用场景之一,Guard的具体实现与原理请参见Guard功能介绍

Hint值

在了解Guard原理之前,先了解下什么是hint值:hint值是在线执行时用户传入的一个提示值,通常是当前迭代执行时用户传入的输入shape(此shape肯定是已知的),例如以下模型:

模型中Data0、Data1和Data2的部分维度未知,但在迭代执行时,用户会传入具有静态shape的输入Tensor,这些shape可在编译时作为符号Guard的参考。例如,在编译上图时,Data0的shape为[2, 3, 4, 2],Data1的shape为[2, 3, 1, 2],Data2的shape为[2, 3, 1, 2]。因此,在编译模型时,s0的hint值为2,s1的hint值为3,s2的hint值为4,s3的hint值为2,s4的hint值为1。通常,每个符号都有一个hint值,而Guard机制正是基于这些hint值实现的。Guard根据功能的不同主要分为expect_guard和assert_guard:

  • expect_guard:主要表达的是符号需要满足何种关系当前模型才能执行,主要用于分支校验,当此Guard条件不满足时,模型会触发重编译。
  • assert_guard:主要是用于符号强校验,如果符号不满足此Guard,即使重新编译模型,也完全无法执行模型,流程中断报错。

Guard基础原理

Guard机制是基于符号关系表达式实现的(例如,s0 == s1),主要用于表达两个符号运算表达式之间的大小关系。Guard是图级别的,一旦在图编译过程中生成了这样的Guard,那么在执行图编译结果时,传入的符号值必须满足该表达式,否则将导致执行失败。

  1. Guard如何生成?

    这里就需要用到前面提到的hint值了,由于符号本身的值是未知的,无法直接比较大小,但每个符号都存在一个hint值。在编译时,当需要判断一个符号关系表达式是否成立时,我们可以通过hint值来进行计算。例如:

    需要验证s0 + s1是否等于s2。直接比较无法确定s0 + s1是否等于s2,但如果将hint值代入表达式中,则可以进行计算。例如,如果s0的hint值是2,s1的hint值是3,s2的hint值是5,那么s0 + s1等于s2,可以生成一个s0 + s1 = = s2的Guard。相反,如果s0的hint值变为3,那么s0 + s1不等于s2,需要生成一个s0 + s1 != s2的Guard。这些Guard将被存储在图编译结果中,其生命周期与图一致。

  2. Guard生成之后,在执行阶段如何使用?

    假如在编译阶段生成了一个s0 + s1 == s2的Guard,需要注意的是,此Guard是根据编译时的迭代输入值生成的。也就是说,当模型是动态shape时,后续迭代的符号值会发生变化。例如,执行一个模型,此模型的输入为Data0:-1,-1; Data1:-1,-1。将此模型的输入shape符号化后,变为Data0:s0,s1; Data1:s2,s3,当在线执行此模型时:

    1. step0:需要对齐做图编译,此时的执行输入shape为Data0:2,3; Data1:5,2。因此,在图编译时,s0的hint值为2,s1的hint值为3,s2的hint值为5,s3的hint值为2。此时,在编译状态下,基于当前step的符号hint值生成一个s0+s1 == s2的Guard,并将其保存在图编译结果中。
    2. step1:如果执行输入shape不变,那么step0生成的Guard肯定能通过校验,step0的图编译结果可以直接执行。
    3. step2:假如输入shape变为Data0:3,4; Data1:7,2,此时符号的值s0变为3,s1变为4,s2变为7,尽管符号的值发生了变化,但s0+s1 == s2的Guard仍然成立,因此step0的图编译结果依然可以直接执行。
    4. step3:假如输入shape变为Data0:5,4; Data1:7,2,此时符号的值s0变为5,s0+s1 == s2的Guard为false,不满足此Guard,因此step0的图编译结果就无法再使用,此时需要基于当前step的输入值对模型进行重编译。

Guard使用场景

了解了Guard是什么之后,那么在哪些情况下需要使用Guard呢?目前,Guard主要有以下几种使用场景:

  • Broadcast场景

    Guard最常见的使用场景是Broadcast场景。在进行Broadcast时,仅凭符号本身无法判断是否需要执行Broadcast,因此需要利用Guard的功能来做出判断,例如在以下Add算子的计算场景中:

    当前Add的输出可能出现三种情况:

    • 当s0==s1时,Add算子的输出shape为[s0, 2]
    • 当s0==1时,Add算子的输出shape为[s1, 2](1的维度需要做Broadcast)
    • 当s1==1时,Add算子的输出shape为[s0, 2]

    因此,可以基于s0跟s1的hint值来生成Guard,并通过生成的Guard来决定Add的输出shape是多少,比如:

    • 当s0的hint值为2,s1的hint值也为2(即s0的hint值与s1的hint值相等)时,Add的符号推导函数需要生成一个s0==s1的expect_guard,且Add的输出shape为[s0, 2]。
    • 当s0的hint值为1,s1的hint值为2时,Add的符号推导函数需要生成一个s0==1的expect_guard,且Add的输出shape为[s1, 2]。
    • 当s0的hint值为2,s1的hint值为1时,Add的符号推导函数需要生成一个s1==1的expect_guard,且Add的输出shape为[s0, 2]。

    并且在后续迭代执行时,如果输入的shape发生了变化,导致当前的Guard不满足时,模型便需要重新编译。

  • shape强校验场景

    有一些算子对输入的shape有一定的约束要求,此时就需要使用assert_guard,例如下图的Matmul算子:该算子计算过程中,第0个输入的第1维需要与第1个输入的第0维相等,否则该算子将无法执行计算。因此,输入符号s1必须等于s2。在编译阶段,将基于s1和s2的hint值生成一个s1==s2的assert_guard。如果s1和s2的hint值不满足此Guard条件,则程序会直接报错并中断退出。

  • 整除判断

    除了上述两种使用场景之外,部分算子,比如Reshape算子还需要对shape是否支持整除进行校验,例如下图所示场景:Reshape的一个输入shape为[s0, s1],另一输入为shape为[2],值为[2, -1]的Const,此时如果Reshape能计算需要满足Mod(s0*s1,2) == 0的assert_guard,否则此Reshape无法计算。

  • 边界判断

    当某些算子需要使符号值处于某一范围内时,也可以使用Guard来控制,比如下图中的Slice算子:此Slice算子的输入x的shape为[s0, s1],offset的值为[2],size的值为2,由于offset的值需要处于[0, s0]区间内,故需要生成2<= s0的assert_guard。

Guard衍生能力

在符号推导生成Guard后,基于Guard还可以提供一些额外的功能:

  • 符号化简

    在具备上述Guard能力后,可以基于Guard能力对符号进行化简,例如,当存在一个s2==s0+s1的Guard时,可以使用s0+s1代替s2。此外,还支持符号等价并查集的功能,例如,存在s2==s0+s1和s0==2两个Guard,可以将s2化简为2 + s1。

  • StaticCheck功能

    在编译阶段,需要基于现有的符号Guard关系进行一些符号的静态比较。因此,符号化提供了StaticCheck功能。当调用static_check接口时,会执行两步操作:

    1. 首先,判断该关系表达式是否为常量表达式,如果是,则直接返回判断结果。
    2. 其次,如果表达式不是常量表达式,则会检查当前是否存在该Guard,若存在则返回Guard的结果,若不存在则返回false。