昇腾社区首页
中文
注册

Scope融合实现方案

Scope融合实现方案包括Scope融合规则注册、Scope融合规则创建和执行。

Scope融合实现方案

  • ScopePass Register:用于融合规则注册。
  • OpParser-Plugin Register:用于融合算子Parser注册,将Scope内的TensorFlow算子映射成适配昇腾AI处理器的融合算子。普通算子映射也是由该模块完成。

Scope融合规则注册

在ATC模型转换或者在TensorFlow框架内执行图时,系统会将“Ascend-cann-toolkit安装目录/ascend-toolkit/latest/opp/built-in/framework/”下的融合规则插件so(包括所有内置融合规则和自定义融合规则)自动加载到ScopePass Register。在后续Parser执行阶段可根据该融合规则的使能状态,决定是否创建并执行。

图1 Scope融合规则注册机制

Scope融合规则创建和执行

图2 Scope融合规则执行流程
  1. TensorFlow Parser在模型解析过程中, 调用ScopePassManager提供的能力, 将GraphDef里的Node信息、Scope信息用Scope融合的数据类型表示出来,从上到下的层级为ScopeGraph、ScopeTree、Scope,生成ScopeGraph。
  2. TensorFlow Parser创建并执行Scope融合,最后将最终匹配结果保存到ScopeGraph。主要包括以下几步:
    1. TensorFlow Parser根据已注册的融合规则的使能状态,创建Scope融合规则。
    2. 获取已创建的Scope融合规则,用全图的Scope,按照融合规则注册的先后顺序,逐一去匹配这些规则,如果匹配到则设置对应Scope的类型。

      如果用户自定义融合规则的名称和内置融合规则名称一样,则按照用户自定义融合规则匹配。

    3. 对上一步匹配到的Scope按Scope的连接关系等进行进一步筛选。例如上一步得到的Scope并不一定是最终目标融合的Scope,而需要筛选出有并列关系的Scope,或者筛选出有嵌套关系的Scope等。
    4. 对最终匹配到的Scope设置融合结果, 包含融合节点的名字、类型、输入、输出、描述等。
    5. 将融合结果保存到第1步生成的ScopeGraph中。
  3. TensorFlow Parser后续流程会根据融合结果信息,执行添加节点、连接节点关系、构造IR Graph等操作。

关键数据结构

ScopeGraph、ScopeTree等数据结构提供了Scope融合所需的各种能力和数据保存功能,接口定义请参见Scope融合规则开发接口,关键类的作用简介请参考表1
表1 关键类的作用简介

类名

作用简介

参考文档

Scope

Scope类型和属性定义。

Scope类

ScopeTree

存储所有Scope信息的一个树结构,可以查询Scope相关的子Scope和Node信息。

ScopeTree类

ScopeGraph

包含ScopeTree,同时定义了Scope关系匹配计算的成员函数,以及Scope融合识别的最终结果。

ScopeGraph类

ScopePattern

Scope匹配规则,目前主要包括以下三种:

  • NodeOpTypeFeature:基于scope中某一类型算子的个数或者个数的倍数匹配。
  • NodeAttrFeature:基于scope中某一类型算子的某一属性的值匹配。
  • ScopeFeature:基于scope自身或者其子scope的特征匹配。

ScopePattern类

NodeOpTypeFeature类

NodeAttrFeature类

ScopeFeature类

ScopeBaseFeature

以上三种匹配规则的基类,定义三种匹配规则的基本操作。

ScopeBaseFeature类

ScopeBasePass

用户自定义融合规则的基类,提供接口定义、通用执行流程、与通用规则匹配流程的实现。

ScopeBasePass类

ScopesResult

用于保存经过进一步匹配和筛选后最终留下来的Scope。

ScopesResult类

FusionScopesResult

保存融合结果, 包括设置融合算子名称、类型、输入、输出、描述、内部算子组合信息(多对多场景)等。

FusionScopesResult类

ScopeUtil

提供通用工具类函数。

ScopeUtil类

ScopeAttrValue

NodeAttrFeature规则定义中使用的数据结构,用于定义属性相关的规则。

ScopeAttrValue类