根据我的理解,tensorflow的freeze_graph.py应该支持新的检查点格式,我应该能够使用像
这样的东西freeze_graph.py --input_saver ./checkpoints/model-49-295 --output_graph ./graph.pb --output_node_names "predictions:0"
要清楚,
ls ./checkpoints
checkpoint
model-49-295.data-00000-of-00001
model-49-295.index
model-49-295.meta
然而,当我这样做时,我收到以下错误:
Traceback (most recent call last):
File "~/.local/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 255, in <module>
app.run(main=main, argv=[sys.argv[0]] + unparsed)
File "~/.local/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "~/.local/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 187, in main
FLAGS.variable_names_blacklist)
File "~/.local/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 165, in freeze_graph
input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
File "~/.local/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py", line 134, in _parse_input_graph_proto
text_format.Merge(f.read(), input_graph_def)
File "~/.local/lib/python3.5/site-packages/tensorflow/python/lib/io/file_io.py", line 125, in read
pywrap_tensorflow.ReadFromStream(self._read_buf, length, status))
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "~/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.FailedPreconditionError: .
我对此感到很困惑,因为.
似乎不是一个非常有用的错误代码,而且我可以找到的所有对FailedPreconditionError的引用都有FailedPreconditionError: Attempting to use uninitialized value ...
任何人都知道这里发生了什么?
答案 0 :(得分:0)
查看来自freeze_graph.py的代码我不太确定它是否支持新格式,或者至少我无法弄清楚它是如何实现的,即使我已经看到很多地方声称它确实。无论如何,我现在的解决方法是编写一个基本相同的简单脚本,但实际上正确地加载了检查点:
import tensorflow as tf
from tensorflow.python.framework import graph_util
from google.protobuf import text_format
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./checkpoints/model-49-295.meta', clear_devices=True)
saver.restore(sess, './checkpoints/model-49-295')
graph_def = sess.graph.as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['predictions'])
with tf.gfile.GFile('./graph.pb', "wb") as f:
f.write(output_graph_def.SerializeToString())
答案 1 :(得分:0)
从调用堆栈,看起来像GraphDef .pb文件的解析失败。不幸的是,错误信息并不是非常有用或信息丰富!
我的猜测是你需要传入--input_binary=true
作为参数,因为默认情况下它假设输入图存储为文本protobuf。