我正在保存两张图表;一个具有2X2零张量,另一个具有相同大小的张量。我会根据情况恢复其中一个。
import tensorflow as tf
def save_zero():
# save a 2x2 variable filled with zeros
with tf.Graph().as_default():
session = tf.Session()
with tf.name_scope('dummy_graph'):
tf.Variable([[0.0, 0.0], [0.0, 0.0]], name='a', dtype=tf.float32)
init_op = tf.global_variables_initializer()
session.run(init_op)
saver = tf.train.Saver()
saver.save(session, 'zero')
session.close()
def save_one():
# save a 2x2 variable filled with ones
with tf.Graph().as_default():
session = tf.Session()
with tf.name_scope('dummy_graph'):
tf.Variable([[1.0, 1.0], [1.0, 1.0]], name='a', dtype=tf.float32)
init_op = tf.global_variables_initializer()
session.run(init_op)
saver = tf.train.Saver()
saver.save(session, 'one')
session.close()
def test(boolean):
with tf.Session() as session:
if boolean:
saver = tf.train.import_meta_graph('one.meta')
saver.restore(session, './one')
session.run(session.graph.get_operation_by_name('init'))
tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
else:
saver = tf.train.import_meta_graph('zero.meta')
saver.restore(session, './zero')
session.run(session.graph.get_operation_by_name('init'))
tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
return session.run(tensor)
save_zero()
save_one()
print(test(False))
print(test(True))
对test
的调用都返回零。观察会话中的操作表明test
中的会话正在两个调用中被重用,当test
返回时,当会话关闭时,AFAIK不会发生这两个调用:
def test(boolean):
with tf.Session() as session:
if boolean:
saver = tf.train.import_meta_graph('one.meta')
saver.restore(session, './one')
# contains duplicate ops (suffixed with '_1')
[print(op.name) for op in session.graph.get_operations()]
session.run(session.graph.get_operation_by_name('init'))
tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
else:
saver = tf.train.import_meta_graph('zero.meta')
saver.restore(session, './zero')
session.run(session.graph.get_operation_by_name('init'))
tensor = session.graph.get_tensor_by_name('dummy_graph/a:0')
return session.run(tensor)
这是一个错误还是我错过了什么?
答案 0 :(得分:2)
TL; DR: tf.Session
在您的代码中对test()
的两次调用之间关闭,但是您遇到了问题,因为这两个会话正在共享相同tf.Graph
。使用新的tf.Graph
创建每个会话以避免此问题。
特别是,当您在调用tf.train.import_meta_graph()
时调用test(False)
时,在tf.train.import_meta_graph()
调用中调用test(True)
时创建的节点仍保留在图表中。这意味着对session.graph.get_tensor_by_name('dummy_graph/a:0')
的两次调用中的每一次都将返回相同的节点(在您第一次调用test()
时创建)。
有几种方法可以避免这个问题。最简单的方法是使用自己的图形创建tf.Session
:
def test(boolean):
# Session will use its own graph.
with tf.Session(graph=tf.Graph()) as session:
if boolean:
# ...
答案 1 :(得分:1)
要向mrry回答添加更多详细信息,请参阅以下内容:
在检查点保存期间,您每次都在创建新图表,因此您要在两个检查点中保存张量dummy_graph/a
:
zero.data-00000-of-00001: dummy_graph/a - [0, 0,
one.data-00000-of-00001: dummy_graph/a - [1, 1,
在第一次加载调用期间,首先创建一个包含变量dummy_graph/a
的图表,然后加载[0, 0,
,然后调用init op,它将使用[0, 0,
在第二次加载调用期间,您的import_meta_graph
会附加到现有默认图表。由于名称冲突,它会将_1
附加到节点,因此现在您的图表将包含节点dummy_graph/a
和dummy_graph/a_1
以及相应的init节点init
和{{ 1}}
在第二次恢复期间,您的保护程序会将init_1
检查点恢复为[1, 1, ...
。然后,您拨打dummy_graph/a
,这将使用init
覆盖dummy_graph/a
的值。然后返回[0, 0, ...
请注意,在第二次恢复后,您的会话有两个变量,第二个未初始化。奇怪的是,dummy_graph/a
没有显示它,即使tf.report_uninitialized_variables()
会抛出sess.run('dummy_graph/a_1:0')
错误,这似乎是一个错误。