昇腾社区首页
中文
注册

ScopeDynamicLSTMPass

功能说明

将tf.nn.dynamic_rnn或tf.nn.bidirectional_dynamic_rnn生成的Scope内的小算子组合融合为DynamicLSTM算子。当前仅支持Cell结果为BasicLSTMCell的循环场景,且仅支持个别shape。

Scope详情

dynamic_rnn对应的Scope结构:

或者bidirectional_dynamic_rnn中对应的两个dynamic_rnn,分别为FW和BW:

融合后的算子原型

DynamicLSTM,具体请参见算子加速库接口参考

融合对应关系

当time_major为False时:

  • rnn/transpose节点的第1个输入作为融合后的第1个输入x。
  • rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
  • rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
  • rnn/transpose_1节点的输出作为融合后的输出output_h。

当time_major为True时:

  • rnn/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3节点的第3个输入作为融合后的第1个输入x。
  • rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
  • rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
  • rnn/TensorArrayStack/TensorArrayGatherV3节点的输出作为融合后的输出output_h。

上图的Scope是以time_major为True举例的。

适用网络

使用dynamic_rnn且单cell为BasicLSTMCell的推理网络。

融合规则类型

定制化融合规则