昇腾社区首页
中文
注册
开发者
下载

算子Lowering实现

算子Lowering就是使用Loop IR来表达算子的计算逻辑。以基础的Add节点为例,其计算过程主要包括:

  1. 从输入Tensor中加载数据。
  2. 对输入数据进行广播。
  3. 判定计算采用的dtype,并将输入转换为对应类型。
  4. 对加载到的数据进行Add计算。
  5. 将结果写入节点的输出Tensor。

GE图编译时,使用节点Anchor表达运行时的Tensor输入,Anchor可以理解为Tensor输入的编译时占位符。使用Loop IR表达Add计算的伪码示例如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
graphStatus LoweringAdd(const ge::Node &node) {
     auto x = loop::Load(node.GetInDataAnchor(0)); // 从输入anchor0中加载数据
     auto y = loop::Load(node.GetInDataAnchor(1)); // 从输入anchor1中加载数据
     vector<Expression> broadcasted_shape = xxx; // 计算广播后的输出
     ge::DataType compute_dtype = xxx; // 计算类型提升后的计算dtype,注意,该步骤与Codegen时的主动类型提升不同
     x = loop::Broadcast(x, broadcasted_shape);
     x = loop::Cast(x, compute_dtype);
     y = loop::Broadcast(y, broadcasted_shape);
     y = loop::Cast(y, compute_dtype);
     auto result = loop::Add(x, y); // 表达计算
     loop::Store(node.GetOutDataAnchor(0), result); // 计算结果保存到输出anchor
 }

由于加载输入、保存输出、计算广播和类型提升等操作具有高度的相似性,抽取公共实现后,一个算子的Lowering可以简化为仅实现计算部分:

1
2
3
Var LoopAdd(const vector<Var> &inputs) { // Var为loop计算输出的中间结果类型
     return loop::Add(inputs[0], inputs[1]);
 }

最终LoopAdd会在Lowering过程中调用:

1
2
3
4
5
graphStatus LoweringAdd(const ge::Node &node) {
     // x, y = 公共逻辑,处理输入的加载、广播、类型转换
     auto result = LoopAdd(x, y); // 表达计算
     // 公共逻辑,处理输出的保存
 }