昇腾社区首页
中文
注册

ScopeDynamicLSTMPass

功能说明

将tf.nn.dynamic_rnn或tf.nn.bidirectional_dynamic_rnn生成的Scope内的小算子组合融合为DynamicLSTM算子。

Scope详情

dynamic_rnn对应的Scope结构:

或者bidirectional_dynamic_rnn中对应的两个dynamic_rnn,分别为FW和BW:

融合后的算子原型

DynamicLSTM,具体请参见AOL算子加速库接口参考

融合对应关系

当time_major为False时:

  • rnn/transpose节点的第1个输入作为融合后的第1个输入x。
  • rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
  • rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
  • rnn/transpose_1节点的输出作为融合后的输出output_h。

当time_major为True时:

  • rnn/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3节点的第3个输入作为融合后的第1个输入x。
  • rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
  • rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
  • rnn/TensorArrayStack/TensorArrayGatherV3节点的输出作为融合后的输出output_h。

上图的Scope是以time_major为True举例的。

适用网络

使用dynamic_rnn且单cell为BasicLSTMCell的推理网络。

使用约束

当前仅支持Cell结果为BasicLSTMCell的循环场景,且仅支持个别shape。

支持的shape有如下约束:

  • 输入x的shape限制:第三维的长度与输出output_h的第三维长度一致。
  • 输入w的shape限制:
    • 第一维的长度等于输入x的第二维长度与输出output_h的第二维长度相加。
    • 第二维的长度等于输出output_h的第二维长度的4倍。
  • 输入b的shape限制:第一维的长度除以16,再向上取整的结果应等于输入w的第二维的长度。

融合规则类型

定制化融合规则