我正在加载已为序列服务序列化的张量流图。我试图在一些可以追溯到训练数据的代码中使用它。该代码创建了tensorflow操作来遍历tfrecords训练数据 - 但由于某种原因,一旦我使用服务图创建会话,我就无法使用训练中的数据集迭代器?这就是发生的事情:
写出图表的代码:
output_graph_def = \
tf.graph_util.convert_variables_to_constants(
sess=sess,
input_graph_def=tf.get_default_graph().as_graph_def(),
output_node_names=['predict'])
with tf.gfile.GFile('graph.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
加载图表:
def load_graph_pb(frozen_graph_filename):
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def into default Graph
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
op_dict=None,
producer_op_list=None
)
return graph
tf.reset_default_graph()
graph = load_graph_pb('graph.pb')
使用该图表创建会话:
sess = tf.Session(graph=graph)
然后创建一个数据集迭代器:
def batchIter(sess, datafile):
dataset = tf.data.TFRecordDataset(datafile) \
...
tf_iter = dataset.make_initializable_iterator()
get_next = tf_iter.get_next()
sess.run(tf_iter.initializer)
while True:
batch = sess.run(get_next)
yield batch
但是当我尝试使用它时,
iter = batchIter(sess, 'train.tfrec')
batch = next(iter)
我得到了
ValueError: Fetch argument <tf.Operation 'MakeIterator' type=MakeIterator> cannot be interpreted as a Tensor. (Operation name: "MakeIterator"
op: "MakeIterator"
input: "ParallelMapDataset"
input: "Iterator"
attr {
key: "_class"
value {
list {
s: "loc:@Iterator"
}
}
}
is not an element of this graph.)
我想我早点打电话给sess(graph=graph)
?在做这个之前我必须要做iter操作,但是我还要确保他们进入这个graph
我加载了吗?我想我知道问题是什么,欢迎优雅的解决方案:)