我正在使用tensorflow关注文本分类的wildml博客。我已经更改了代码以保存图形def,如下所示:
tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False)
稍后在另一个文件中,我将恢复图表如下:
with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
t = sess.graph.get_tensor_by_name('embedding/W:0')
sess.run(t)
当我尝试运行张量并获得其值时,我收到以下错误:
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W
这可能是导致此错误的原因。张量应该已初始化,因为我正在从保存的图形中恢复它。
答案 0 :(得分:0)
感谢Alexandre! 是的,我需要加载图形(来自.pb文件)和权重(来自检查点文件)。使用以下示例代码(取自博客),它对我有用。
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
try:
saver = tf.train.Saver(tf.all_variables())
except:pass
print("load data")
saver.restore(persisted_sess, "checkpoint.data") # now OK
print(persisted_result.eval())
print("DONE")