Tensorflow进给初始RNN状态

时间:2016-12-02 12:18:42

标签: tensorflow

我遇到了以下示例,但我不知道可以按如下方式提供RNN状态。

self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)
self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
self.initial_state = cell.zero_state(args.batch_size, tf.float32)

在这段代码中,初始状态被声明为归零状态。据我所知,这不是占位符。它只是一个零张量的小册子。

然后在使用RNN模型生成初始状态的函数中,在session.run中输入。

def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
    state = sess.run(self.cell.zero_state(1, tf.float32))
    for char in prime[:-1]:
        x = np.zeros((1, 1))
        x[0, 0] = vocab[char]
        feed = {self.input_data: x, self.initial_state:state}
        [state] = sess.run([self.final_state], feed)

由于self.initial_state不是占位符,如何才能获得win session.run?

以下是我正在查看的代码的link

2 个答案:

答案 0 :(得分:2)

请注意,您可以输入任何变量,而不仅仅是占位符。因此,在这种情况下,您可以手动输入元组的每个组件:

feed = {
    self.input_data: x, 
    self.initial_state[0]: state[0], 
    self.initial_state[1]: state[1]
}

答案 1 :(得分:0)

在阅读类似的RNN代码时,我遇到了与您相同的问题。

据我了解,rnn_cell.zero_state实际上会返回给您一个张量的元组,它们是可进给的。您的占位符也是张量。

因此,如果您这样做:

print(init_state[0])

# You will get something like 
<tf.Tensor 'LSTM_cell/initial_state/BasicLSTMCellZeroState/zeros:0' shape=(50, 10) dtype=float32>

feed dict允许您只进给张量或张量数组。