Tensorflow freeze_graph脚本在使用Keras定义的模型上失败

时间:2016-06-08 13:46:08

标签: tensorflow

我正在尝试将使用Keras构建和训练的模型导出到我可以在C ++脚本中加载的protobuffer(如本例所示)。我生成了一个包含模型定义的.pb文件和一个包含检查点数据的.ckpt文件。但是,当我尝试使用freeze_graph脚本将它们合并到一个文件中时,我收到错误:

ValueError: Fetch argument 'save/restore_all' of 'save/restore_all' cannot be interpreted as a Tensor. ("The name 'save/restore_all' refers to an Operation not in the graph.")

我正在保存这样的模型:

with tf.Session() as sess:
    model = nndetector.architecture.models.vgg19((3, 50, 50))
    model.load_weights('/srv/nn/weights/scratch-vgg19.h5')
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(graph_def=graph_def, logdir='.',   name='model.pb', as_text=False)
    saver = tf.train.Saver()
    saver.save(sess, 'model.ckpt')

nndetector.architecture.models.vgg19((3,50,50))只是在Keras中定义的类似vgg19的模型。

我正在调用freeze_graph脚本:

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=[path-to-model.pb] --input_checkpoint=[path-to-model.ckpt] --output_graph=[output-path] --output_node_names=sigmoid --input_binary=True

如果我运行freeze_graph_test脚本,一切正常。

有谁知道我做错了什么?

感谢。

祝你好运

菲利普

修改

我已尝试打印tf.train.Saver().as_saver_def().restore_op_name,返回save/restore_all

此外,我尝试了一个简单的纯张量流示例,但仍然得到同样的错误:

a = tf.Variable(tf.constant(1), name='a')
b = tf.Variable(tf.constant(2), name='b')
add = tf.add(a, b, 'sum')

with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)
tf.train.Saver().save(sess, 'simple.ckpt')

而且我实际上也无法在python中恢复图形。如果我单独执行保存图形,则使用以下代码抛出ValueError: No variables to save(也就是说,如果我在同一个脚本中保存并恢复模型,一切正常)。

with gfile.FastGFile('simple_as_binary.pb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    tf.import_graph_def(graph_def)
    saver = tf.train.Saver()
    saver.restore(sess, 'simple.ckpt')

我不确定这两个问题是否相关,或者我是不是在python中没有正确恢复模型。

2 个答案:

答案 0 :(得分:8)

问题在于原始程序中这两行的顺序:

tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)
tf.train.Saver().save(sess, 'simple.ckpt')

调用tf.train.Saver() 一组节点添加到图表中,其中包括一个名为"save/restore_all"的节点。但是,该程序在写出图形后将其称为,因此传递给freeze_graph.py的文件不包含这些重写所必需的节点。

反转这两行应该使脚本按预期工作:

tf.train.Saver().save(sess, 'simple.ckpt')
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)

答案 1 :(得分:3)

所以,我得到了它的工作。排序。

直接使用tensorflow.python.client.graph_util.convert_variables_to_constants代替首先保存GraphDef并将检查点保存到磁盘然后使用freeze_graph工具/脚本,我可以保存GraphDef包含图形定义和转换为常量的变量。

修改

mrry更新了他的答案,这解决了我的freeze_graph无法解决的问题,但我也会留下这个答案,以防其他人发现它有用。