我使用BasicLSTMCell
预测了以下代码。
def predict(self, x, W, b):
x = tf.reshape(x, [-1, n_input])
x = tf.split(x,n_input,1)
rnn_cell = rnn.BasicLSTMCell(n_hidden)
outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)
return tf.matmul(outputs[-1], W) + b
我的培训阶段为我提供了W_opt
和b_opt
(除了BasicLSTMCell
中隐藏的参数,我想用它来预测新数据。
但是,测试部分处于单独的功能中,并且训练阶段的session
已经关闭。
如果我尝试使用
y_tf = self.predict(np.array(cur_X), self.W_opt, self.b_opt)
y = session.run(y_tf)
使用上述predict
方法,我收到以下错误:
ValueError: Variable rnn/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
那么如何针对这种特殊情况更好地修复此错误?我已经看到了关于重用BasicLSTMCell
的可能性的其他一些质量保证。在这种特殊情况下它是否合适,或者最好在方法之间共享session
而不关闭它,或者有更好的方法?
(我是新来的,所以,虽然我认为我能以某种方式使它工作,我想看看最合适的方法是什么。谢谢。)
更新:
我试图在我用于训练阶段的with
tf.Session() as session
下使用相同的代码进行预测,但我得到同样的错误。
如果我设置reuse=True
标志,我会收到以下错误:
ValueError: Variable rnn/basic_lstm_cell/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
如果我设置reuse=tf.get_variable()
标志,则会出现以下错误:
ValueError: Trying to share variable rnn/basic_lstm_cell/kernel, but specified dtype float64 and found dtype float32_ref.
我不知道此错误的含义,以及为什么我无法在同一会话中重复使用BasicLSTMCell
。