Tensorflow:会话结束后如何重用LSTM单元格?

时间:2017-12-13 21:16:02

标签: tensorflow neural-network lstm

我使用BasicLSTMCell预测了以下代码。

def predict(self, x, W, b):
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(x,n_input,1)
    rnn_cell = rnn.BasicLSTMCell(n_hidden)
    outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)
    return tf.matmul(outputs[-1], W) + b

我的培训阶段为我提供了W_optb_opt(除了BasicLSTMCell中隐藏的参数,我想用它来预测新数据。

但是,测试部分处于单独的功能中,并且训练阶段的session已经关闭。

如果我尝试使用

y_tf = self.predict(np.array(cur_X), self.W_opt, self.b_opt)
y = session.run(y_tf)

使用上述predict方法,我收到以下错误:

ValueError: Variable rnn/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? 

那么如何针对这种特殊情况更好地修复此错误?我已经看到了关于重用BasicLSTMCell的可能性的其他一些质量保证。在这种特殊情况下它是否合适,或者最好在方法之间共享session而不关闭它,或者有更好的方法?

(我是新来的,所以,虽然我认为我能以某种方式使它工作,我想看看最合适的方法是什么。谢谢。)

更新: 我试图在我用于训练阶段的with tf.Session() as session下使用相同的代码进行预测,但我得到同样的错误。 如果我设置reuse=True标志,我会收到以下错误:

ValueError: Variable rnn/basic_lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

如果我设置reuse=tf.get_variable()标志,则会出现以下错误:

ValueError: Trying to share variable rnn/basic_lstm_cell/kernel, but specified dtype float64 and found dtype float32_ref.

我不知道此错误的含义,以及为什么我无法在同一会话中重复使用BasicLSTMCell

0 个答案:

没有答案