以下是代码:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
# print(os.getcwd())
# os.chdir(os.getcwd())
# os.chdir("/tmp")
chk_file = "hello.chk"
def save(checkpoint_file=chk_file):
with tf.Session() as session:
x = tf.Variable(initial_value=[1, 2, 3], name="x")
y = tf.Variable(initial_value=[[1.0, 2.0], [3.0, 4.0]], name="y")
session.run(tf.global_variables_initializer())
print(session.run(tf.global_variables()))
saver = tf.train.Saver()
save_path = saver.save(sess=session, save_path=checkpoint_file)
print(session.run(tf.global_variables()))
def restore(checkpoint_file=chk_file):
with tf.Session() as session:
saver = tf.train.Saver()
saver.restore(sess=session, save_path=checkpoint_file)
print(session.run(tf.global_variables()[0]))
print(tf.global_variables()[0])
# print(session.run(tf.get_variable("x", shape=(3, ))))
def reset():
tf.reset_default_graph()
path = save()
# print(path)
restore("/home/kaiyin/PycharmProjects/text-classify/hello.chk")
一些问题:
restore(path)
不起作用,与saver.restore
相对路径对restore
不起作用,即使您已经在正确的目录中
如果您注释掉path = save()
行,则会收到错误消息:
/ home / kaiyin / virtualenvs / tensorflow / bin / python /home/kaiyin/PycharmProjects/text-classify/restore.py Traceback(最近一次调用最后一次): 文件" /home/kaiyin/PycharmProjects/text-classify/restore.py" ;,第38行,在 恢复(" /home/kaiyin/PycharmProjects/text-classify/hello.chk") 文件" /home/kaiyin/PycharmProjects/text-classify/restore.py",第27行,还原 saver = tf.train.Saver() 文件" /home/kaiyin/virtualenvs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py" ;,第1040行, init self.build() 文件" /home/kaiyin/virtualenvs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py" ;,第1061行,在构建中 提出ValueError("没有要保存的变量") ValueError:没有要保存的变量
流程已完成退出代码1
我可以忍受前两个问题,但第三个问题是一个真正的阻碍者。为什么我需要保存会话每次我想恢复它?由于没有全局会话对象,save
函数如何产生这样的影响也有点神秘。
Tensorflow版本:1.0.1
Python 3.5.2
Ubuntu 16.04
答案 0 :(得分:1)
我整天都很头疼。
刚刚解决了:
正确的方法是在调用restore()
之前需要重新初始化所有变量例如,在cifar10项目中(cifar10.py - 第188行)
如果您想恢复以前保存的变量
首先需要调用inference()来初始化所有变量 然后调用restore()。
<强>重新初始化强>
def restore(checkpoint_file=chk_file):
with tf.Session() as session:
x = tf.Variable(initial_value=[1, 2, 3], name="x")
y = tf.Variable(initial_value=[[1.0, 2.0], [3.0, 4.0]], name="y")
saver = tf.train.Saver()
saver.restore(sess=session, save_path=checkpoint_file)
print(session.run(tf.global_variables()[1]))
print(tf.global_variables()[0])
# print(session.run(tf.get_variable("x", shape=(3, ))))
def reset():
tf.reset_default_graph()
restore("/home/kaiyin/PycharmProjects/text-classify/hello.chk")