算子Lowering实现
算子Lowering就是使用Loop IR来表达算子的计算逻辑。以基础的Add节点为例,其计算过程主要包括:
- 从输入Tensor中加载数据。
- 对输入数据进行广播。
- 判定计算采用的dtype,并将输入转换为对应类型。
- 对加载到的数据进行Add计算。
- 将结果写入节点的输出Tensor。
GE图编译时,使用节点Anchor表达运行时的Tensor输入,Anchor可以理解为Tensor输入的编译时占位符。使用Loop IR表达Add计算的伪码示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 | graphStatus LoweringAdd(const ge::Node &node) { auto x = loop::Load(node.GetInDataAnchor(0)); // 从输入anchor0中加载数据 auto y = loop::Load(node.GetInDataAnchor(1)); // 从输入anchor1中加载数据 vector<Expression> broadcasted_shape = xxx; // 计算广播后的输出 ge::DataType compute_dtype = xxx; // 计算类型提升后的计算dtype,注意,该步骤与Codegen时的主动类型提升不同 x = loop::Broadcast(x, broadcasted_shape); x = loop::Cast(x, compute_dtype); y = loop::Broadcast(y, broadcasted_shape); y = loop::Cast(y, compute_dtype); auto result = loop::Add(x, y); // 表达计算 loop::Store(node.GetOutDataAnchor(0), result); // 计算结果保存到输出anchor } |
由于加载输入、保存输出、计算广播和类型提升等操作具有高度的相似性,抽取公共实现后,一个算子的Lowering可以简化为仅实现计算部分:
1 2 3 | Var LoopAdd(const vector<Var> &inputs) { // Var为loop计算输出的中间结果类型 return loop::Add(inputs[0], inputs[1]); } |
最终LoopAdd会在Lowering过程中调用:
1 2 3 4 5 | graphStatus LoweringAdd(const ge::Node &node) { // x, y = 公共逻辑,处理输入的加载、广播、类型转换 auto result = LoopAdd(x, y); // 表达计算 // 公共逻辑,处理输出的保存 } |
父主题: Lowering