ScopeDynamicLSTMPass
功能说明
将tf.nn.dynamic_rnn或tf.nn.bidirectional_dynamic_rnn生成的Scope内的小算子组合融合为DynamicLSTM算子。
Scope详情
dynamic_rnn对应的Scope结构:
或者bidirectional_dynamic_rnn中对应的两个dynamic_rnn,分别为FW和BW:
融合对应关系
当time_major为False时:
- rnn/transpose节点的第1个输入作为融合后的第1个输入x。
- rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
- rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
- rnn/transpose_1节点的输出作为融合后的输出output_h。
当time_major为True时:
- rnn/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3节点的第3个输入作为融合后的第1个输入x。
- rnn/while/basic_lstm_cell/MatMul/Enter节点的输入作为融合后的第2个输入w。
- rnn/while/basic_lstm_cell/BiasAdd/Enter节点的输入作为融合后的第3个输入b。
- rnn/TensorArrayStack/TensorArrayGatherV3节点的输出作为融合后的输出output_h。

上图的Scope是以time_major为True举例的。
适用网络
使用dynamic_rnn且单cell为BasicLSTMCell的推理网络。
使用约束
当前仅支持Cell结果为BasicLSTMCell的循环场景,且仅支持个别shape。
支持的shape有如下约束:
- 输入x的shape限制:第三维的长度与输出output_h的第三维长度一致。
- 输入w的shape限制:
- 第一维的长度等于输入x的第二维长度与输出output_h的第二维长度相加。
- 第二维的长度等于输出output_h的第二维长度的4倍。
- 输入b的shape限制:第一维的长度除以16,再向上取整的结果应等于输入w的第二维的长度。
融合规则类型
定制化融合规则
父主题: 融合规则说明