ScopeDynamicLSTMPass

Description

Fuses small operators within the scope generated by tf.nn.dynamic_rnn or tf.nn.bidirectional_dynamic_rnn into a DynamicLSTM operator.
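The loop being fused can be made concrete with a small NumPy replay of what dynamic_rnn computes around a BasicLSTMCell. At each time step, one MatMul over the concatenated [x_t, h] and one BiasAdd produce four gates, which then update the cell state c and the output h. The weight and bias fed into these nodes are what the fusion mapping below turns into inputs w and b.

This sketch follows TensorFlow's BasicLSTMCell semantics:

  • gate order i, j, f, o
  • forget_bias of 1.0
  • zero initial state

It is an illustration only, not the DynamicLSTM kernel. The function name basic_lstm_loop is hypothetical, and DynamicLSTM's internal gate order or weight layout may differ.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def basic_lstm_loop(x, w, b, forget_bias=1.0):
    """Hypothetical reference loop over a time-major x of shape [T, B, I].

    w: [I + H, 4 * H] kernel and b: [4 * H] bias, laid out as in
    TensorFlow's BasicLSTMCell. Returns output_h of shape [T, B, H].
    """
    T, B, _ = x.shape
    H = w.shape[1] // 4
    h = np.zeros((B, H), dtype=x.dtype)  # zero initial state
    c = np.zeros((B, H), dtype=x.dtype)
    outputs = []
    for t in range(T):
        # One cell step: MatMul over [x_t, h], then BiasAdd
        gates = np.concatenate([x[t], h], axis=1) @ w + b
        i, j, f, o = np.split(gates, 4, axis=1)  # BasicLSTMCell gate order
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
        outputs.append(h)
    return np.stack(outputs)  # [T, B, H]

# Usage: T=5 steps, batch 2, input depth 3, hidden size 4
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 2, 3)).astype(np.float32)
w = rng.standard_normal((3 + 4, 4 * 4)).astype(np.float32)
b = np.zeros(4 * 4, dtype=np.float32)
print(basic_lstm_loop(x, w, b).shape)  # (5, 2, 4)
```

The fusion pass replaces the whole per-step loop, including its TensorArray and control-flow nodes, with a single DynamicLSTM operator.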

Scope Details

The scope structure corresponding to dynamic_rnn is as follows.

For bidirectional_dynamic_rnn, the scope contains two dynamic_rnn structures of this form: one for the forward (FW) direction and one for the backward (BW) direction.

Result Operator Prototype

DynamicLSTM. For details, see Operator Library API Reference.

Fusion Mapping

When time_major is set to False:

  • Input 1 of the rnn/transpose node is used as input 1 x after fusion.
  • The input of the rnn/while/basic_lstm_cell/MatMul/Enter node is used as input 2 w after fusion.
  • The input of the rnn/while/basic_lstm_cell/BiasAdd/Enter node is used as input 3 b after fusion.
  • The output of the rnn/transpose_1 node is used as the output output_h after fusion.

When time_major is set to True:

  • Input 3 of the rnn/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3 node is used as input 1 x after fusion.
  • The input of the rnn/while/basic_lstm_cell/MatMul/Enter node is used as input 2 w after fusion.
  • The input of the rnn/while/basic_lstm_cell/BiasAdd/Enter node is used as input 3 b after fusion.
  • The output of the rnn/TensorArrayStack/TensorArrayGatherV3 node is used as the output output_h after fusion.

In the preceding scope example, time_major is set to True.
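The two mappings differ because of data layout.

  • With time_major set to False, tf.nn.dynamic_rnn first transposes the batch-major [batch, time, depth] input into time-major order with the rnn/transpose node. After the loop, it transposes the stacked outputs back with the rnn/transpose_1 node. The mapping therefore takes x from the input of rnn/transpose and output_h from the output of rnn/transpose_1.
  • With time_major set to True, no transpose is created. x is taken from the value input (input 3) of the TensorArrayScatterV3 node that unstacks the sequence.

A minimal NumPy sketch of the two layouts follows. The helper names to_time_major and from_time_major are hypothetical.

```python
import numpy as np

def to_time_major(x, time_major):
    """Return the [time, batch, depth] tensor the loop (and fused x) consumes."""
    if time_major:
        # Already [time, batch, depth]: no rnn/transpose node exists
        return x
    # [batch, time, depth] -> [time, batch, depth], as rnn/transpose does
    return np.transpose(x, (1, 0, 2))

def from_time_major(y, time_major):
    """Map the time-major loop output back to the caller's layout."""
    # Swapping the first two axes again undoes the permutation (rnn/transpose_1)
    return y if time_major else np.transpose(y, (1, 0, 2))

# Usage: batch=2, time=5, depth=3 in batch-major layout
x = np.zeros((2, 5, 3))
print(to_time_major(x, False).shape)  # (5, 2, 3)
```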

Applicable Network

Inference network that uses dynamic_rnn and a single BasicLSTMCell

Restrictions

Currently, only loops whose cell is a BasicLSTMCell are supported, and only certain shapes are supported.

Supported shapes must satisfy the following restrictions:

  • Restrictions on the shape of input x: The length of the third dimension is the same as that of the third dimension of output output_h.
  • Restrictions on the shape of input w:
    • The length of the first dimension is the sum of the length of the second dimension of input x and the length of the second dimension of output output_h.
    • The length of the second dimension is four times that of the second dimension of output output_h.
  • Restrictions on the shape of input b: The length of the first dimension divided by 16 and rounded up should be equal to the length of the second dimension of input w.
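The restrictions above can be read as a checklist. The sketch below transcribes them literally and numbers dimensions from 1 as the text does, so "the second dimension" is index 1 in the code. The function name check_dynamic_lstm_shapes and the convention of returning a list of violated rules are illustrative assumptions. The function checks only the rules listed here.

```python
def check_dynamic_lstm_shapes(x_shape, h_shape, w_shape, b_shape):
    """Hypothetical literal transcription of the listed shape restrictions.

    x_shape and h_shape (output_h) are 3-D, w_shape is 2-D, b_shape is 1-D.
    Text dimension k maps to index k - 1. Returns the violated rules.
    """
    problems = []
    # x: third dimension equals the third dimension of output_h
    if x_shape[2] != h_shape[2]:
        problems.append("x dim 3 != output_h dim 3")
    # w: first dimension is x dim 2 + output_h dim 2
    if w_shape[0] != x_shape[1] + h_shape[1]:
        problems.append("w dim 1 != x dim 2 + output_h dim 2")
    # w: second dimension is four times output_h dim 2
    if w_shape[1] != 4 * h_shape[1]:
        problems.append("w dim 2 != 4 * output_h dim 2")
    # b: first dimension / 16, rounded up, equals w dim 2 (integer ceiling)
    if (b_shape[0] + 15) // 16 != w_shape[1]:
        problems.append("ceil(b dim 1 / 16) != w dim 2")
    return problems

# Usage: only the bias rule fails here, because ceil(100 / 16) = 7, not 8
print(check_dynamic_lstm_shapes((1, 2, 2), (1, 2, 2), (4, 8), (100,)))
# ['ceil(b dim 1 / 16) != w dim 2']
```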

Fusion Pattern Type

Non-general fusion pattern