tensorflow.train.import_meta_graph不起作用?

时间:2016-08-08 12:47:09

标签: python tensorflow

我尝试简单地保存和恢复图形,但最简单的示例不能按预期工作(这是使用版本0.9.0或0.10.0在Linux 64上使用python 2.7或3.5.2而不使用CUDA)

首先我像这样保存图表:

import tensorflow as tf
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])

这会创建一个文件" file"这是非空的,并且还将g1设置为看起来像正确的图形定义的东西。

然后我尝试恢复此图表:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

这没有错误,但根本不返回任何内容。

任何人都可以提供必要的代码,只需保存" v4"并完全恢复它,以便在新会话中运行它将产生相同的结果?

1 个答案:

答案 0 :(得分:29)

要重复使用MetaGraphDef,您需要在原始图表中记录有趣张量的名称。例如,在第一个程序中,在namev1v2的定义中设置明确的v4参数:

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")

然后,您可以在调用sess.run()时使用原始图表中张量的字符串名称。例如,以下代码段应该有效:

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")

sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})

或者,您可以使用tf.get_default_graph().get_tensor_by_name()获取感兴趣的张量的tf.Tensor个对象,然后将其传递给sess.run()

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()

v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})

更新:根据评论中的讨论,这里是保存和加载的完整示例,包括保存变量内容。这说明了通过在单独的操作中将变量vx的值加倍来保存变量。

保存:

import tensorflow as tf
v1 = tf.placeholder(tf.float32, name="v1") 
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))
result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
print(result)
saver.save(sess, "./model_ex1")

恢复:

import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)

最重要的是,为了使用已保存的模型,您必须记住至少一些节点的名称(例如,训练操作,输入占位符,评估张量等)。 MetaGraphDef存储模型中包含的变量列表,并有助于从检查点恢复这些变量,但您需要重建自己培训/评估模型中使用的张量/操作。