从保存的检查点恢复Tensorflow不起作用

时间:2017-07-27 19:25:02

标签: python tensorflow

我正在训练一个简单的LSTM并按照官方文档中的建议保存检查点。完成图形和培训后,我想要检索这个训练有素的模型,并在将来对新数据进行测试。

我的图表基本上是这样的:

阻止-1

#All parameters
#################
x = tf.placeholder declaration
y = tf.placeholder declaration

weights = {'out': tf.Variable initialized from a random distribution}
biases = {'out': tf.Variable initialized from a random distribution}

pred = RNN(x, weights, biases) # Defines an LSTM cell and returns the predictions

cost = f(pred,y)
optimizer = myOptimizer.minimize(cost)

init = tf.global_variables_initializer()

block2 - 运行培训并保存检查点

with tf.Session() as sess:
    saver = tf.train.Saver()
    sess.run(init)
    while step < training_iters:
        sess.run(optimizer, feed_dict=train_feed)
        if step%save_step == 0:    
            saver.save(sess, chkpt_file,global_step=step)

    step += 1

    predictions = sess.run(pred, feed_dict=validation_feed)

我正在尝试从另一个脚本中恢复它,我正在执行block1,然后执行以下操作:

saver = tf.train.Saver()
with tf.Session() as sess:
     saver = tf.train.import_meta_graph('chckpt.meta')
     saver.restore('chckpt')

     sess.run(init)

     predictions = sess.run(pred, feed_dict=test_feed)

我没有得到预期的结果,因为我应该从检查站获得。我正在从保存的模型中恢复一些东西但看起来我的重量现在是随机的。我怀疑这是因为行sess.run(init)重置了我从检查点恢复的内容并进行了新的随机初始化。但是,当我跳过该行时,我收到以下错误: Attempting to use uninitialized value rnn/basic_lstm_cell/bias

我做错了什么?

我尝试手动恢复重量和偏差变量而不运行init,我似乎正在检索存储的值。但是要运行pred操作,我需要恢复我在RNN()函数中创建的所有变量,这很麻烦。有没有更好的方法呢?

我还尝试从我保存的同一个脚本中的检查点进行恢复,但我可以在不使用sess.run(init)的情况下逃脱。我可以在此运行pred操作并获得正确的结果。但是,它只能在同一个脚本中运行(可能是因为我在恢复脚本中遗漏了一些范围。)

0 个答案:

没有答案