初始化CUDNN LSTM的状态

时间:2018-09-16 08:25:09

标签: tensorflow state cudnn

我认为我们可以使用以下代码段创建LSTM的堆栈并将其状态初始化为零。

 lstm_cell = tf.contrib.rnn.BasicLSTMCell(
            hidden_size, forget_bias=0.0, state_is_tuple=True)
 cell = tf.contrib.rnn.MultiRNNCell([lstm_cell] * num_layers, state_is_tuple=True)
 cell.zero_state(batch_size, tf_float32)

我想使用CUDNN而不是使用BasicLSTMCell

cudnn_cell = tf.contrib.cudnn_rnn.CudnnLSTM(
          num_layers, hidden_size, dropout=config.keep_prob)

在这种情况下,如何在cudnn_cell上执行与cell.zero_state(batch_size, tf_float32)相同的操作?

1 个答案:

答案 0 :(得分:0)

定义可以在tensorflow cudnn_rnn's code

中找到

关于initial_states:

with tf.Graph().as_default():
    lstm = CudnnLSTM(num_layers, num_units, direction, ...)
    outputs, output_states = lstm(inputs, initial_states, training=True)

因此,您只需要添加嵌入输入之外的初始状态。 在编码器/解码器系统中,它看起来像:

encoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
encoder_output, encoder_state = encoder_cell(encoder_embedding_input)
decoder_cell = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers, hidden_size)
decoder_output, decoder_state = encoder_cell(decoder_embedding_input,
                                             initial_states=encoder_state)

在这里,encoder_statetuple的{​​{1}}。两种状态的形状均为(final_c_state, final_h_state)

如果您的编码器是双向RNN,那将有点棘手,因为现在输出状态变为(1, batch, hidden_size)

因此,我使用环形交叉路来解决它。

(2, batch, hidden_size)

尽管我还没有尝试过多层RNN,但我认为也可以通过类似的方式解决。