我正在尝试使用tf.train.Saver()来保存和恢复张量流图。为了保存图表,我正在进行如下操作:
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
sess.run(init_op)
with sess.as_default():
for ep in range(epoch):
train_step.run(feed_dict={X: X_train,
Y: y_train.reshape(-1,1)})
saver.save(sess, 'my_test_model')
要恢复我正在做的模型:
with tf.Session() as sess:
saver.restore(sess, 'my_test_model')
test_accuracy = sess.run(x,feed_dict={X: X_valid,
Y: y_valid.reshape(-1,1)})
mse_test=loss.eval(feed_dict={X: X_valid,
Y: y_valid.reshape(-1,1)})
print(mse_test)
虽然在同一个笔记本中这样做但一切都很完美。但是当我试图在另一个笔记本中做同样的事情时,问题就开始了:在新笔记本中没有定义保护程序所以我试图再次将其定义为tf.train.Saver()并得到错误ValueError:No要保存的变量。
我也尝试了
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
但图表变量未保存在元文件中。
有人遇到过类似的问题吗?我将不胜感激。
提前致谢!