我目前正在尝试将训练有素的TensorFlow模型导出为ProtoBuf文件,以便在Android上使用TensorFlow C ++ API。因此,我正在使用freeze_graph.py
脚本。
我使用tf.train.write_graph
导出了我的模型:
tf.train.write_graph(graph_def, FLAGS.save_path, out_name, as_text=True)
我正在使用tf.train.Saver
保存的检查点。
我按照脚本顶部的描述调用freeze_graph.py
。编译后,我运行
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=<path_to_protobuf_file> \
--input_checkpoint=<model_name>.ckpt-10000 \
--output_graph=<output_protobuf_file_path> \
--output_node_names=dropout/mul_1
这给了我以下错误消息:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.
由于错误状态我在导出的模型中没有张量save/Const:0
。但是,freeze_graph.py
的代码表示可以通过标记filename_tensor_name
指定此张量名称。不幸的是,我找不到任何关于张量应该是什么以及如何为我的模型正确设置它的信息。
有人可以告诉我如何在导出的ProtoBuf模型中生成save/Const:0
张量或如何正确设置标记filename_tensor_name
?
答案 0 :(得分:6)
--filename_tensor_name
标志用于指定为模型构建tf.train.Saver
时创建的占位符张量的名称。*
在原始程序中,您可以打印出saver.saver_def.filename_tensor_name
的值,以获取此标志应传递的值。您可能还需要打印saver.saver_def.restore_op_name
的值以获取--restore_op_name
标记的值(因为我怀疑图表的默认值不正确)。
或者,tf.train.SaverDef
protocol buffer包含重建这些标志的相关信息所需的所有信息。如果您愿意,可以将saver.saver_def
写入文件,并将该文件的名称作为--input_saver
标记传递给freeze_graph.py
。
* tf.train.Saver
的默认名称范围为"save/"
,占位符为actually a tf.constant()
,其名称默认为"Const:0"
,这解释了为什么该标记默认为{{1 }}
答案 1 :(得分:2)
我注意到当我的代码排列如下时,我发生了错误:
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
sess.run(init)
在我更改了代码布局之后,它工作了:
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
sess.run(init)
我不确定为什么会这样。 @mrry你能解释一下吗?
答案 2 :(得分:1)
@ Drag0答案的一些后续操作,以及为什么新的代码布局可修复错误。
调用saver = tf.train.Saver()
时,将与tf.train.Saver()
相关的不同变量(例如'save/Const:0'
)添加到默认图形。
在第一个代码排列中,图形保存之前没有tf.train.Saver()
变量。在第二种代码排列中,它被保存下来,因此操作save/Const
将存在于图形中。
答案 3 :(得分:0)
在最新的freeze_graph.py中应该没有问题,因为我可以看到这些已被删除:
del restore_op_name, filename_tensor_name # Unused by updated loading code.
source:freeze_graph.py
在早期版本中,它使用restore_op来恢复模型
sess.run([restore_op_name], {filename_tensor_name: input_checkpoint})
因此,对于以前的版本,如果您在实例化saver op之前在.pb文件中编写图形,则会出现问题。例如:
tf.train.write_graph(sess.graph_def, "./logs", "test2.pb", False)
saver = tf.train.Saver()
saver.save(sess, "./logs/hello_ck.ckpt", meta_graph_suffix='meta', write_meta_graph=True)
这是因为图形不会有任何保存/恢复操作来恢复模型。要解决此问题,请在保存.ckpt文件后写入图形
saver = tf.train.Saver()
saver.save(sess, "./logs/hello_ck.ckpt", meta_graph_suffix='meta', write_meta_graph=True)
tf.train.write_graph(sess.graph_def, "./logs", "test2.pb", False)
@mrry,如果我解释错误,请指导。我最近才开始深入研究tensorflow代码。