我创建了张量流Graph
。我可以加载它,例如
with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
将protobuffer文件中定义的图形暂时放置为默认图形。如果我现在创建一个会话,该图将用作当前图形。
尝试将序列化的graph_def
对象保存到变量并启动Session
为
with tf.Session(graph=graph_def) as sess:
以预期错误结束
TypeError: graph must be a tf.Graph, but got <class 'tensorflow.core.framework.graph_pb2.GraphDef'>
我有一个用例,我必须在多个图之间进行更改。使用所提出的方法,我可以清除默认图并加载一个新图,其缺点是必须重复调用导入函数。
问题是,从我的graph.pb
开始,Graph
对象my_graph
是如何获得的,因此可以使用
with tf.Session(graph=my_graph) as sess:
并创建会话而无需从graph.pb
文件重新加载图表?
答案 0 :(得分:3)
您可以创建自己的图表并将其设置为导入操作的默认值:
import tensorflow as tf
graph1 = tf.Graph()
graph2 = tf.Graph()
with graph1.as_default():
tf.import_graph_def(graph_def1) # graph_def1 loaded somewhere
with graph2.as_default():
tf.import_graph_def(graph_def2) # graph_def2 loaded somewhere
session1 = tf.Session(graph=graph1)
session2 = tf.Session(graph=graph2)