我是tensorflow的新手,当我尝试从已保存的变量恢复保存的参数时,我得到了“NotFoundError:检查点中找不到密钥b_1”,完整的代码如下所示。谢谢你的帮助!
import tensorflow as tf
import numpy as np
## save to a file
## need to use the same shape and dtype when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='W')
b = tf.Variable([[1,2,3]],dtype=tf.float32, name='b')
# initialization
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess, 'C:\Temp\TensorFlow\save\save.ckpt')
以下是恢复部分: ## restore variables
W = tf.Variable(np.arange(6).reshape((2,3)), dtype=tf.float32, name='W')
b = tf.Variable(np.arange(3).reshape((1,3)), dtype=tf.float32, name='b')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, 'C:\Temp\TensorFlow\save\save.ckpt')
print('weights', sess.run(W))
print('biases', sess.run(b))
答案 0 :(得分:1)
tf.train.Saver
的默认行为是使用tf.all_variables()
保存(或恢复)name
中的每个变量(除了任何其他"可保存对象")财产作为关键。
我怀疑如果您在程序的还原部分打印出W.name
和b.name
,则会获得"W_1"
和"b_1"
。为什么在指定name='W'
和name='b'
时会获得这些名称?当前TensorFlow图中必须已存在具有这些名称的变量,因此TensorFlow假定您有意创建新变量,并附加后缀("_1"
,"_2"
等)以使名字独特。例如,如果您在同一个脚本(或Jupyter笔记本)中依次运行问题中的两个代码片段,您会看到此问题。
有几种方法可以避免这个问题:
一种简单但粗暴的方法是在程序的恢复部分之前调用tf.reset_default_graph()
。这会将当前图表重置为空,因此变量最终会按照您的意图设置名称"W"
和"b"
。
您可以通过将恢复部分包装在with tf.Graph().as_default():
块中来实现类似的效果,从而确保在空图中创建变量。
您可以通过将tf.Variable
参数传递给var_list
构造函数来覆盖检查点和tf.train.Saver
对象中变量之间的映射,如下所示:
saver = tf.train.Saver(var_list={"W": W, "b", b})