无法从已保存的变量TensorFlow 1.0,python 3.5.1中恢复变量

时间:2017-03-29 12:39:33

标签: python tensorflow

我是tensorflow的新手,当我尝试从已保存的变量恢复保存的参数时,我得到了“NotFoundError:检查点中找不到密钥b_1”,完整的代码如下所示。谢谢你的帮助!

import tensorflow as tf
import numpy as np
## save to a file
## need to use the same shape and dtype when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='W')
b = tf.Variable([[1,2,3]],dtype=tf.float32, name='b') 


# initialization
init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, 'C:\Temp\TensorFlow\save\save.ckpt')

以下是恢复部分:     ## restore variables

W = tf.Variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='W')
b = tf.Variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='b')    


saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'C:\Temp\TensorFlow\save\save.ckpt')
    print('weights', sess.run(W))
    print('biases', sess.run(b))

1 个答案:

答案 0 :(得分:1)

tf.train.Saver的默认行为是使用tf.all_variables()保存(或恢复)name中的每个变量(除了任何其他"可保存对象")财产作为关键。

我怀疑如果您在程序的还原部分打印出W.nameb.name,则会获得"W_1""b_1"。为什么在指定name='W'name='b'时会获得这些名称?当前TensorFlow图中必须已存在具有这些名称的变量,因此TensorFlow假定您有意创建新变量,并附加后缀("_1""_2"等)以使名字独特。例如,如果您在同一个脚本(或Jupyter笔记本)中依次运行问题中的两个代码片段,您会看到此问题。

有几种方法可以避免这个问题:

  1. 一种简单但粗暴的方法是在程序的恢复部分之前调用tf.reset_default_graph()。这会将当前图表重置为空,因此变量最终会按照您的意图设置名称"W""b"

  2. 您可以通过将恢复部分包装在with tf.Graph().as_default():块中来实现类似的效果,从而确保在空图中创建变量。

  3. 您可以通过将tf.Variable参数传递给var_list构造函数来覆盖检查点和tf.train.Saver对象中变量之间的映射,如下所示:

    saver = tf.train.Saver(var_list={"W": W, "b", b})