即使退出with块,tf.Session()似乎也没有关闭

时间:2017-01-10 20:20:17

标签: python tensorflow

我正在保存两张图表;一个具有2X2零张量,另一个具有相同大小的张量。我会根据情况恢复其中一个。

import tensorflow as tf


def save_zero():
    # save a 2x2 variable filled with zeros
    with tf.Graph().as_default():
        session = tf.Session()
        with tf.name_scope('dummy_graph'):
            tf.Variable([[0.0, 0.0], [0.0, 0.0]], name='a', dtype=tf.float32)
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        saver = tf.train.Saver()
        saver.save(session, 'zero')
        session.close()


def save_one():
    # save a 2x2 variable filled with ones
    with tf.Graph().as_default():
        session = tf.Session()
        with tf.name_scope('dummy_graph'):
            tf.Variable([[1.0, 1.0], [1.0, 1.0]], name='a', dtype=tf.float32)
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        saver = tf.train.Saver()
        saver.save(session, 'one')
        session.close()


def test(boolean):
    with tf.Session() as session:
        if boolean:
            saver = tf.train.import_meta_graph('one.meta')
            saver.restore(session, './one')

            session.run(session.graph.get_operation_by_name('init'))
            tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
        else:

            saver = tf.train.import_meta_graph('zero.meta')
            saver.restore(session, './zero')

            session.run(session.graph.get_operation_by_name('init'))
            tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')

        return session.run(tensor)

save_zero()
save_one()
print(test(False))
print(test(True))

test的调用都返回零。观察会话中的操作表明test中的会话正在两个调用中被重用,当test返回时,当会话关闭时,AFAIK不会发生这两个调用:

def test(boolean):
    with tf.Session() as session:
        if boolean:
            saver = tf.train.import_meta_graph('one.meta')
            saver.restore(session, './one')
            # contains duplicate ops (suffixed with '_1')
            [print(op.name) for op in session.graph.get_operations()]
            session.run(session.graph.get_operation_by_name('init'))
            tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
        else:

            saver = tf.train.import_meta_graph('zero.meta')
            saver.restore(session, './zero')

            session.run(session.graph.get_operation_by_name('init'))
            tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')

        return session.run(tensor)

这是一个错误还是我错过了什么?

2 个答案:

答案 0 :(得分:2)

TL; DR: tf.Session在您的代码中对test()的两次调用之间关闭,但是您遇到了问题,因为这两个会话正在共享相同tf.Graph。使用新的tf.Graph创建每个会话以避免此问题。

特别是,当您在调用tf.train.import_meta_graph()时调用test(False)时,在tf.train.import_meta_graph()调用中调用test(True)时创建的节点仍保留在图表中。这意味着对session.graph.get_tensor_by_name('dummy_graph/a:0')的两次调用中的每一次都将返回相同的节点(在您第一次调用test()时创建)。

有几种方法可以避免这个问题。最简单的方法是使用自己的图形创建tf.Session

def test(boolean):
    # Session will use its own graph.
    with tf.Session(graph=tf.Graph()) as session:
        if boolean:
            # ...

答案 1 :(得分:1)

要向mrry回答添加更多详细信息,请参阅以下内容:

在检查点保存期间,您每次都在创建新图表,因此您要在两个检查点中保存张量dummy_graph/a

zero.data-00000-of-00001: dummy_graph/a - [0, 0, 
one.data-00000-of-00001: dummy_graph/a - [1, 1, 

在第一次加载调用期间,首先创建一个包含变量dummy_graph/a的图表,然后加载[0, 0,,然后调用init op,它将使用[0, 0,

覆盖此值

在第二次加载调用期间,您的import_meta_graph会附加到现有默认图表。由于名称冲突,它会将_1附加到节点,因此现在您的图表将包含节点dummy_graph/adummy_graph/a_1以及相应的init节点init和{{ 1}}

在第二次恢复期间,您的保护程序会将init_1检查点恢复为[1, 1, ...。然后,您拨打dummy_graph/a,这将使用init覆盖dummy_graph/a的值。然后返回[0, 0, ...

的值

请注意,在第二次恢复后,您的会话有两个变量,第二个未初始化。奇怪的是,dummy_graph/a没有显示它,即使tf.report_uninitialized_variables()会抛出sess.run('dummy_graph/a_1:0')错误,这似乎是一个错误。