我尝试运行一个简单的程序将Tensorflow会话保存为磁盘“spikes.cpkt”。虽然在交互式程序中,系统输出显示我已成功创建该文件,但我无法在文件系统中找到该文件。
我使用的Tensorflow版本是使用Python 2的0.11rc。操作系统是Ubuntu 16.04。该程序是在Jupiter笔记本中编写和运行的。
以下是保存会话的源代码:
# Import TensorFlow and enable interactive sessions
import tensorflow as tf
sess = tf.InteractiveSession()
# Let's say we have a series data like this
raw_data = [1., 2., 8., -1., 0., 5.5, 6., 13.]
# Define a boolean vector called `spikes` to locate a sudden spike in raw data
spikes = tf.Variable([False] * len(raw_data), name='spikes')
# Don't forget to initialize the variable
spikes.initializer.run()
# The saver op will enable saving and restoring variables.
# If no dictionary is passed into the constructor, then the saver operators of all variables in the current program.
saver = tf.train.Saver()
# Loop through the data and update the spike variable when there is a significant increase
for i in range(1, len(raw_data)):
if raw_data[i] - raw_data[i-1] > 5:
spikes_val = spikes.eval()
spikes_val[i] = True
# Update the value of spikes by using the `tf.assign` function
updater = tf.assign(spikes, spikes_val)
# Don't forget to actually evaluate the updater, otherwise spikes will not be updated
updater.eval()
# Save the variable to the disk
save_path = saver.save(sess, "spikes.ckpt")
# Print out where the relative file path of the saved variables
print("spikes data saved in file: %s" % save_path)
# Remember to close the session after it will no longer be used
sess.close()
磁盘中没有名为“spikes.ckpt”的文件。
答案 0 :(得分:8)
TensorFlow最近推出了一种新的检查点格式(Saver V2),它将检查点保存为一组带有公共前缀的文件。要创建使用旧格式的tf.train.Saver
,您可以create it,如下所示:
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
答案 1 :(得分:0)
您只需要将变量的名称放在tf.trai.Saver
中saver = tf.train.Saver([spikes])
答案 2 :(得分:0)
我遇到了同样的问题,我正在阅读使用Tensorflow的机器学习一书,在论坛中你也可以找到使路径相对的解决方案
save_path = saver.save(sess, "./spikes.ckpt")
答案 3 :(得分:0)
尝试使用绝对路径而不是相对路径。就我而言,它解决了这个问题。有点奇怪,因为像TensorFlow这样成熟的库应该支持相对路径。