为什么TensorFlow会尝试从我没有要求的检查点恢复密钥?

时间:2018-02-11 16:46:29

标签: tensorflow

当我尝试从检查点恢复变量时,TensorFlow会查找我未指定的密钥并报告错误。

我可以使用

将变量保存在预期的键下
import tensorflow as tf
sess = tf.InteractiveSession()

raw_data = [1., 2., 8., -1., 0., 5.5, 6., 13]
spikes = tf.Variable([False] * len(raw_data), name='spikes')
spikes.initializer.run()

# After variables, listing them in a dict if not all are to be saved
saver = tf.train.Saver()

for i in range(1, len(raw_data)):
    spikes_val = spikes.eval() # Get the current values
    spikes_val[i] = True # Update new value
    updater = tf.assign(spikes, spikes_val).eval() # Assign updated values to Variable

save_path = saver.save(sess, os.path.join(os.getcwd(), '_save_eg.ckpt'))
print("spikes data saved in file: %s" % save_path)

sess.close()

并且可以通过

确认这是成功的
tf.contrib.framework.list_variables(save_path)

给出了

[('spikes', [8])]

正如所料。

但是当我尝试用

读取这个变量时
sess_in = tf.InteractiveSession()

spikes_read = tf.Variable([False] * len(raw_data), name='spikes')
tf.train.Saver().restore(sess_in, save_path)
print(spikes_read)

sess_in.close()

我得到一个NotFoundError的密钥,'spikes_1',我没有要求:

NotFoundError: Key spikes_1 not found in checkpoint [[Node: save_1/RestoreV2_1 = RestoreV2[dtypes=[DT_BOOL], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_1/tensor_names, save_1/RestoreV2_1/shape_and_slices)]]

为什么TensorFlow会尝试从我没有要求的检查点恢复密钥?

这基本上是来自Machine Learning With TensorFlow的第44页的示例,它无法正常工作,正如本书中的大部分代码一样。

1 个答案:

答案 0 :(得分:0)

你的阅读阶段是错误的。您之前已经声明了spikes变量,这意味着当前图表中存在名为spikes 的变量

当您尝试恢复模型时,您正在执行此操作:

spikes_read = tf.Variable([False] * len(raw_data), name='spikes')

这是一个名为spikes的变量的新声明:此变量已存在于当前图形中,因此Tensorflow会为您添加_1后缀以避免冲突。

在下一行:

tf.train.Saver().restore(sess_in, save_path)

您要求Saver使用当前图表来恢复save_path中的变量。 显然,这意味着保护程序不仅要查找先前声明的spikes变量,还要查找新的spikes_1变量。

您可以通过两种不同的方式解决问题:

第一种方式

如果查看tf.Saver的文档,可以看到构造函数接受要恢复的变量列表。 因此,您可以使用先前声明的变量spikes并将其作为构造函数参数传递。

所以你的阅读阶段变成了:

sess_in = tf.InteractiveSession()

# comment the `spikes_1` variable definition and just use the
# `spikes` varialble previously declared
#spikes_read = tf.Variable([False] * len(raw_data), name='spikes')
tf.train.Saver().restore(sess_in, save_path)
# or you can explicit the variable into the saver in this way, that's
# the same exact thing
# tf.train.Saver([spikes]).restore(sess_in, save_path)
print(spikes_read)

sess_in.close()

第二种方式

您可以将读取阶段包装到新的空图表中。因此,您现在可以声明一个名称为spikes的变量,该变量将由保护程序填充:

new_graph = tf.Graph()
with new_graph.as_default():
    sess_in = tf.InteractiveSession()

    spikes_read = tf.Variable([False] * len(raw_data), name='spikes')
    tf.train.Saver().restore(sess_in, save_path)
    print(spikes_read)