在预定义的图形对象

时间:2018-01-09 13:50:35

标签: python tensorflow graph

TL; DR:为什么我们不能使用类似tf.saver.Save(graph=graph_obj)的东西来定义一个保护对象?

标题大部分都说明了...... AFAIK,为了将保护对象链接到你的图表,你需要像这样定义

with tf.Graph().as_default() as g_def:
    x_input_fun = tf.placeholder(dtype=tf.float32, name='input')
    y_output_fun = tf.placeholder(dtype=tf.float32, name='output')
    w_weights_fun = tf.get_variable('weight_set', dtype=tf.float32, shape=(5,5))
    output = tf.matmul(x_input_fun, w_weights_fun, name='pred')
    loss = tf.subtract(output, y_output_fun, name='loss')
    self.opti = tf.train.AdamOptimizer(loss, name='opti')
    g_def.add_to_collection(tf.GraphKeys.TRAIN_OP, self.opti)

    # Now the saver is linked to this graph when we do saver.save(...)
    saver = tf.train.Saver()

如果你想将它链接到默认图表,你只需要说tf.train.Saver()(如果你当中有可训练/可保存的变量)。

但为什么我们不能这样做:tf.train.Saver(graph=g_def)

这对我来说会更自然。当我们从检查点恢复模型时,类似的情况(对我来说)...即使我们执行以下代码

with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph('some_meta_file.meta')
    saver.restore(sess, './some_meta_file')

然后tf.default_graph()仍然从导入的元文件中获取了节点。我能想出它是如何运作的原因......但现在为什么呢?

编辑:

我在检查导入图的节点时犯的错误如下。我运行了这段代码

with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph('some_meta_file.meta')
    saver.restore(sess, './some_meta_file')
    print(sess.graph == tf.get_default_graph())

因为我想确保默认图表不包含我刚刚导入会话图表的节点。但是,这个tf.get_default_graph()当然是.. relative ..因此在这个会话中,默认图形实际上是导入的图形。

所以这也使得saver-object的工作更合乎逻辑。由于此对象将始终保存/获取tf.get_default_graph()的内容。

1 个答案:

答案 0 :(得分:1)

为了保存或恢复任何内容,tf.train.Saver需要一个会话,会话绑定到特定的图形实例(如示例中所示)。这意味着没有会话,保护程序实际上毫无意义。我想这是没有在保护程序中使用显式图形绑定的主要动机。

我认为您可能感兴趣的是tf.train.Saver中的defer_build属性:

  

defer_build:如果True,请将保存和恢复操作添加到build()调用。在这种情况下,应在最终确定图表或使用保护程序之前调用build()

通过这种方式,您可以创建未绑定到任何图表的tf.train.Saver,并稍后针对特定build()实例调用tf.Graph