我有两个单独的张量流程,一个用于训练模型并用tensorflow.python.client.graph_util.convert_variables_to_constants
写出graph_def,另一个用tensorflow.import_graph_def
读取graph_def。我希望第二个进程在第一个进程更新时定期重新加载graph_def。不幸的是,即使我关闭当前会话并创建一个新会话,似乎每次我读取graph_def仍会使用旧的。我也尝试用import_graph_def call
包裹sess.graph.as_default()
,但无济于事。这是我当前的graph_def加载代码:
if self.sess is not None:
self.sess.close()
self.sess = tf.Session()
graph_def = tf.GraphDef()
with open(self.graph_path, 'rb') as f:
graph_def.ParseFromString(f.read())
with self.sess.graph.as_default():
tf.import_graph_def(graph_def, name='')
答案 0 :(得分:5)
这里的问题是,当您创建一个没有参数的tf.Session
时,它会使用当前默认图。假设您没有在代码中的任何其他地方创建tf.Graph
,您将获得在流程启动时创建的全局默认图,并在所有会话之间共享。因此,with self.sess.graph.as_default():
无效。
很难从您在问题中显示的代码段推荐一个新结构 - 特别是,我不知道您是如何创建上一个图表,或者类结构是什么 - 但是一个可能性是用以下内容替换self.sess = tf.Session()
:
self.sess = tf.Session(graph=tf.Graph()) # Creates a new graph for the session.
现在with self.sess.graph.as_default():
将使用为会话创建的图表,您的程序应具有预期的效果。
稍微优先(对我而言)替代方案是明确构建图形:
with tf.Graph().as_default() as imported_graph:
tf.import_graph_def(graph_def, ...)
sess = tf.Session(graph=imported_graph)