如何对冻结模型服务图创建的张量流图进行广告操作?

时间:2018-04-16 06:11:30

标签: tensorflow

我正在加载已为序列服务序列化的张量流图。我试图在一些可以追溯到训练数据的代码中使用它。该代码创建了tensorflow操作来遍历tfrecords训练数据 - 但由于某种原因,一旦我使用服务图创建会话,我就无法使用训练中的数据集迭代器?这就是发生的事情:

写出图表的代码:

output_graph_def = \
        tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=tf.get_default_graph().as_graph_def(),
            output_node_names=['predict'])

 with tf.gfile.GFile('graph.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

加载图表:

def load_graph_pb(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # import graph_def into default Graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            op_dict=None,
            producer_op_list=None
        )
    return graph

tf.reset_default_graph()
graph = load_graph_pb('graph.pb')

使用该图表创建会话:

 sess = tf.Session(graph=graph)

然后创建一个数据集迭代器:

def batchIter(sess, datafile):
    dataset = tf.data.TFRecordDataset(datafile) \
              ...
    tf_iter = dataset.make_initializable_iterator()
    get_next = tf_iter.get_next()
    sess.run(tf_iter.initializer)
    while True:
        batch = sess.run(get_next)
        yield batch

但是当我尝试使用它时,

iter = batchIter(sess, 'train.tfrec')
batch = next(iter)

我得到了

ValueError: Fetch argument <tf.Operation 'MakeIterator' type=MakeIterator> cannot be interpreted as a Tensor. (Operation name: "MakeIterator"
op: "MakeIterator"
input: "ParallelMapDataset"
input: "Iterator"
attr {
  key: "_class"
  value {
    list {
      s: "loc:@Iterator"
    }
  }
}
 is not an element of this graph.)

我想我早点打电话给sess(graph=graph)?在做这个之前我必须要做iter操作,但是我还要确保他们进入这个graph我加载了吗?我想我知道问题是什么,欢迎优雅的解决方案:)

0 个答案:

没有答案