当我恢复保存的图形和变量时。我怎样才能获得TF中的占位符

时间:2017-08-24 05:12:41

标签: tensorflow

我用过

 tf.add_to_collection('Input', X)
 tf.add_to_collection('TrueLabel', Y)
 tf.add_to_collection('loss', loss)
 tf.add_to_collection('accuracy', accuracy)

 saver0 = tf.train.Saver()
 saver0.save(sess, './save/model')
 saver0.export_meta_graph('./save/model.meta')

将我的代码保存在一个会话范围内。然后,我从另一个会话范围恢复它。 CUrrent,我只有训练数据,我已经保存了占位符X和Y.此时我不能使用它们:

train_data, train_label = get_data()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./save/model.meta')
    new_saver.restore(sess, './save/model')
    graph = sess.graph
    X = graph.get_collection('Input')
    Y = graph.get_collection('TrueLabel')
    loss = graph.get_collection('loss')
    accuracy = graph.get_collection('accuracy')
    for _ in range(5):
        loss_str, accuracy_str = sess.run([loss, accuracy], {X:train_data, Y:train_label})
        print('loss:{}, accuracy:{}'.format(loss_str, accuracy_str))

我该怎么做?我发现教程文档没有提供完整的示例

1 个答案:

答案 0 :(得分:0)

这个问题已经由我自己解决了。一旦我们加载图形和变量。只是为了获得像graph.get_tensor_by_name这样的占位符('输入:0')。使用相同的方法来获得损失和准确性等等。

可以从https://github.com/sunkevin1214/TF_implementation/blob/master/test_funs/test_save_load.py

找到完整的示例