简介
什么是符号化
符号化是指通过符号表达算子的shape,并提供化简、推导以及Guard功能。自动融合的前提是先进行符号化,符号化之后可以保留部分算子语义,为后续循环轴合并和内存优化阶段提供关键的信息。
符号化的基础
当前自动融合为支持符号表达计算和化简功能,封装了开源C++ SymEngine库,封装这一层的优势在于保持使用的稳定性,避免与开源模块的绑定,未来可以灵活地替换为其他第三方库或自主研发的库。
符号化产生
编译期,在进入自动融合之前,首先对图上所有的输入进行符号化,产生符号。经过此阶段,图上的Data、RefData等shape输入会带有泛化后的symbol shape符号输入。具体的符号化逻辑为:如果大于0则符号化为常量符号,-1符号化为动态符号,第一个-1符号化成s0,第二个-1符号化成s1,依此类推按顺序递增。
例如,图1有三个输入,其符号化前后shape分别为:
- data0:[3, -1] -> [3, s0];
 - data1:[-1, 5] -> [s1, 5];
 - data2:[-1, -1] -> [s2, s3]。
 
符号传播
符号传播即符号推导,每个节点根据输入的符号化shape推导出输出的符号化shape。符号由Data节点产生,然后沿着图上的节点逐步推导出每个节点的符号化输出shape,直到遍历完整张图。符号传播存在如下几种情况:
- InferSymbolShape推导(一类算子处理):基于算子的输入符号shape推导输出符号化shape。
     
例如对于下图,Concat有两个输入,它们的符号化shape分别为data0[s0, s1],data1[s0, s2],concat_dim为1,根据Concat算子的功能,可以推导出其输出shape为[s0, (s1 + s2)]。

 - InferValue推导(二类算子处理):在图上做符号折叠。
     
二类算子是指编译时可以推导输出符号,但输出的推导是值依赖型,仅仅通过输入的shape符号无法推导输出的符号,除了输入符号外,还需要根据输入的值来推导输出。例如图2所示,如果要推导出BroadcastTo算子的符号化shape,不仅要知道Pack节点的输出shape,还需要知道其输出值。
由于符号是在编译期间生成的,因此对于二类算子(如ReShape等),可以考虑扩展常量折叠,通过注册相关符号折叠的kernel,在编译时实现基于符号的折叠运算,以完成基于符号的Kernel计算和InferValue逻辑。常量折叠的简要逻辑如下:
- 注册支持符号计算的算子kernel function,常见的例如Add、Sub、Mul、Shape等。
 - 参考常量折叠处理,遍历整图,如果某个节点所有的输入节点都可以获取到具体的值,调用当前node_type对应的符号kernel function进行一轮折叠计算。
 - 当前节点折叠完成后,将符号值设置到该node的对端输出Tensor上,符号折叠完成后,由于这些表达都是Host对shape的计算,后续可以在编译期优化掉这些节点,从而在执行时减少调度下发。
 - 继续做折叠计算。
 
 - 三四类算子处理:
     
三四类算子是指在编译时无法根据输入的shape或符号值推导出输出的符号,只有在真正执行完成才能确定输出的具体shape。在自动融合场景中,对于三四类算子可以打开切图编译(参考“AutoFuse使能方式中的--experimental_enable_jit_executor_v2”控制项)。
 - 自定义算子处理:
     
1. 如果未补充自定义算子符号的推导流程,则按照上述三四类算子做相同处理。
2. 如果补充了自定义算子的符号推导,一、二类算子按照正常的一、二类算子推导,三、四类同样做Fallback处理。
编译后会生成符号表达式,在执行阶段,GE需要依据符号对应的实际Dim值完成shape推导运算和内存分配。因此,在执行阶段如何根据符号进行实际shape计算是一个关键点。
 
符号化在功能实现层面分为符号推导和符号计算,优先级是先进行符号计算,再进行符号推导,如果两者都不存在则进行静态推导,顺序如下:

