从.pb文件恢复图形def时出现Tensorflow错误

时间:2016-11-22 12:13:18

标签: graph initialization tensorflow restore

我正在使用tensorflow关注文本分类的wildml博客。我已经更改了代码以保存图形def,如下所示:

tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False)

稍后在另一个文件中,我将恢复图表如下:

with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    t = sess.graph.get_tensor_by_name('embedding/W:0')
    sess.run(t)

当我尝试运行张量并获得其值时,我收到以下错误:

tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W

这可能是导致此错误的原因。张量应该已初始化,因为我正在从保存的图形中恢复它。

1 个答案:

答案 0 :(得分:0)

感谢Alexandre! 是的,我需要加载图形(来自.pb文件)和权重(来自检查点文件)。使用以下示例代码(取自博客),它对我有用。

with tf.Session() as persisted_sess:
    print("load graph")
    with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        persisted_sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
    tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
    try:
        saver = tf.train.Saver(tf.all_variables())
    except:pass
        print("load data")
    saver.restore(persisted_sess, "checkpoint.data")  # now OK
    print(persisted_result.eval())
    print("DONE")