如何将基于元组的tf.nn.MultiRNNCell的先前状态传递给TensorFlow中的下一个sess.run()调用?

时间:2016-09-07 17:10:35

标签: tensorflow

我正在使用使用tf.nn.MultiRNNCell构建的一堆RNN,我希望将final_state传递给下一个图调用。由于feed字典中不支持元组,所以堆叠单元格状态并切换输入以在图形的开头产生元组是实现它的唯一方法,还是TensorFlow中有一些允许这样做的功能? / p>

2 个答案:

答案 0 :(得分:4)

假设您的MultiRNNCell中有3个RNNCell,并且每个都是具有LSTMStateTuple状态的LSTMCell。您必须使用占位符复制此结构:

lstm0_c = tf.placeholder(...)
lstm0_h = tf.placeholder(...)
lstm1_c = tf.placeholder(...)
lstm1_h = tf.placeholder(...)
lstm2_c = tf.placeholder(...)
lstm2_h = tf.placeholder(...)

initial_state = tuple(
  tf.nn.rnn_cell.LSTMStateTuple(lstm0_c, lstm0_h),
  tf.nn.rnn_cell.LSTMStateTuple(lstm1_c, lstm1_h),
  tf.nn.rnn_cell.LSTMStateTuple(lstm2_c, lstm2_h))

...

sess.run(..., feed_dict={
  lstm0_c: final_state[0].c,
  lstm0_h: final_state[0].h,
  lstm1_c: final_state[1].c,
  lstm1_h: final_state[1].h,
  ...
})

如果您有N个堆叠的LSTM图层,则可以通过for循环以编程方式创建占位符和feed_dict。

答案 1 :(得分:0)

我会尝试将整个状态存储在具有以下形状的张量中:

init_state = np.zeros((num_layers, 2, batch_size, state_size))

然后喂它并将其解压缩到你的图表中

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
      [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
      for idx in range(num_layers)]
)