我用过
tf.add_to_collection('Input', X)
tf.add_to_collection('TrueLabel', Y)
tf.add_to_collection('loss', loss)
tf.add_to_collection('accuracy', accuracy)
saver0 = tf.train.Saver()
saver0.save(sess, './save/model')
saver0.export_meta_graph('./save/model.meta')
将我的代码保存在一个会话范围内。然后,我从另一个会话范围恢复它。 CUrrent,我只有训练数据,我已经保存了占位符X和Y.此时我不能使用它们:
train_data, train_label = get_data()
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('./save/model.meta')
new_saver.restore(sess, './save/model')
graph = sess.graph
X = graph.get_collection('Input')
Y = graph.get_collection('TrueLabel')
loss = graph.get_collection('loss')
accuracy = graph.get_collection('accuracy')
for _ in range(5):
loss_str, accuracy_str = sess.run([loss, accuracy], {X:train_data, Y:train_label})
print('loss:{}, accuracy:{}'.format(loss_str, accuracy_str))
我该怎么做?我发现教程文档没有提供完整的示例
答案 0 :(得分:0)
这个问题已经由我自己解决了。一旦我们加载图形和变量。只是为了获得像graph.get_tensor_by_name这样的占位符('输入:0')。使用相同的方法来获得损失和准确性等等。
可以从https://github.com/sunkevin1214/TF_implementation/blob/master/test_funs/test_save_load.py
找到完整的示例