如何从保存的graph.pb中获取Session的Graph对象

时间:2017-04-27 05:49:39

标签: python tensorflow

我创建了张量流Graph。我可以加载它,例如

with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

将protobuffer文件中定义的图形暂时放置为默认图形。如果我现在创建一个会话,该图将用作当前图形。

尝试将序列化的graph_def对象保存到变量并启动Session

with tf.Session(graph=graph_def) as sess:

以预期错误结束

TypeError: graph must be a tf.Graph, but got <class 'tensorflow.core.framework.graph_pb2.GraphDef'>

我有一个用例,我必须在多个图之间进行更改。使用所提出的方法,我可以清除默认图并加载一个新图,其缺点是必须重复调用导入函数。

问题是,从我的graph.pb开始,Graph对象my_graph是如何获得的,因此可以使用

with tf.Session(graph=my_graph) as sess:

并创建会话而无需从graph.pb文件重新加载图表?

1 个答案:

答案 0 :(得分:3)

您可以创建自己的图表并将其设置为导入操作的默认值:

import tensorflow as tf
graph1 = tf.Graph()
graph2 = tf.Graph()
with graph1.as_default():
    tf.import_graph_def(graph_def1) # graph_def1 loaded somewhere

with graph2.as_default():
    tf.import_graph_def(graph_def2) # graph_def2 loaded somewhere

session1 = tf.Session(graph=graph1)
session2 = tf.Session(graph=graph2)