TL; DR:为什么我们不能使用类似tf.saver.Save(graph=graph_obj)
的东西来定义一个保护对象?
标题大部分都说明了...... AFAIK,为了将保护对象链接到你的图表,你需要像这样定义
with tf.Graph().as_default() as g_def:
x_input_fun = tf.placeholder(dtype=tf.float32, name='input')
y_output_fun = tf.placeholder(dtype=tf.float32, name='output')
w_weights_fun = tf.get_variable('weight_set', dtype=tf.float32, shape=(5,5))
output = tf.matmul(x_input_fun, w_weights_fun, name='pred')
loss = tf.subtract(output, y_output_fun, name='loss')
self.opti = tf.train.AdamOptimizer(loss, name='opti')
g_def.add_to_collection(tf.GraphKeys.TRAIN_OP, self.opti)
# Now the saver is linked to this graph when we do saver.save(...)
saver = tf.train.Saver()
如果你想将它链接到默认图表,你只需要说tf.train.Saver()
(如果你当中有可训练/可保存的变量)。
但为什么我们不能这样做:tf.train.Saver(graph=g_def)
?
这对我来说会更自然。当我们从检查点恢复模型时,类似的情况(对我来说)...即使我们执行以下代码
with tf.Session(graph=tf.Graph()) as sess:
saver = tf.train.import_meta_graph('some_meta_file.meta')
saver.restore(sess, './some_meta_file')
然后tf.default_graph()
仍然从导入的元文件中获取了节点。我能想出它是如何运作的原因......但现在为什么呢?
编辑:
我在检查导入图的节点时犯的错误如下。我运行了这段代码
with tf.Session(graph=tf.Graph()) as sess:
saver = tf.train.import_meta_graph('some_meta_file.meta')
saver.restore(sess, './some_meta_file')
print(sess.graph == tf.get_default_graph())
因为我想确保默认图表不包含我刚刚导入会话图表的节点。但是,这个tf.get_default_graph()当然是.. relative ..因此在这个会话中,默认图形实际上是导入的图形。
所以这也使得saver-object的工作更合乎逻辑。由于此对象将始终保存/获取tf.get_default_graph()的内容。
答案 0 :(得分:1)
为了保存或恢复任何内容,tf.train.Saver
需要一个会话,会话绑定到特定的图形实例(如示例中所示)。这意味着没有会话,保护程序实际上毫无意义。我想这是没有在保护程序中使用显式图形绑定的主要动机。
我认为您可能感兴趣的是tf.train.Saver
中的defer_build
属性:
defer_build
:如果True
,请将保存和恢复操作添加到build()
调用。在这种情况下,应在最终确定图表或使用保护程序之前调用build()
。
通过这种方式,您可以创建未绑定到任何图表的tf.train.Saver
,并稍后针对特定build()
实例调用tf.Graph
。