我一直在尝试冻结模型的tensorflow
图。
使用tensorflow
给定的工具tensorflow.python.tools.freeze_graph()
时,我遇到以下解析错误
INFO:tensorflow:Restoring parameters from /mnt/ds3lab-scratch/rahimit/Visualization_distill_pub/Results/checkpoint/model_cutoutsize_1arcmin_UV_VIS_strech_asinh_factor20.01/model_9.ckpt
---------------------------------------------------------------------------
DecodeError Traceback (most recent call last)
<ipython-input-6-5dbcfb2c2b04> in <module>()
----> 1 freeze_graph_test(modelpath)
<ipython-input-5-5be58805bd76> in freeze_graph_test(model_folder)
19
20 initializ = tf.global_variables_initializer()
---> 21 freeze_graph(input_graph =model_folder+'train.pb' ,input_saver=saver, input_checkpoint=input_checkpoint, clear_devices=True ,input_binary=True, initializer_nodes= initializ , output_graph='testgraph.pd.modelzoo',output_node_names=output_node_names, restore_op_name= 'save/restore_all', filename_tensor_name='save/Const:0')
~/ENV/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py in freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, output_node_names, restore_op_name, filename_tensor_name, output_graph, clear_devices, initializer_nodes, variable_names_whitelist, variable_names_blacklist, input_meta_graph, input_saved_model_dir, saved_model_tags, checkpoint_version)
229 input_saved_model_dir, saved_model_tags).graph_def
230 elif input_graph:
--> 231 input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
232 input_meta_graph_def = None
233 if input_meta_graph:
~/ENV/lib/python3.5/site-packages/tensorflow/python/tools/freeze_graph.py in _parse_input_graph_proto(input_graph, input_binary)
170 with gfile.FastGFile(input_graph, mode) as f:
171 if input_binary:
--> 172 input_graph_def.ParseFromString(f.read())
173 else:
174 text_format.Merge(f.read(), input_graph_def)
DecodeError: Error parsing message
我要运行的代码如下
def freeze_graph_test(model_folder):
checkpoint = tf.train.get_checkpoint_state(model_folder)
input_checkpoint = checkpoint.model_checkpoint_path
output_node_names = 'gen/d8/weight'
clear_devices = True
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
graph = tf.get_default_graph()
with tf.Session() as sess:
restor_all = saver.restore(sess, input_checkpoint)
initializ = tf.global_variables_initializer()
freeze_graph(input_graph=model_folder+'train.pb', input_saver=saver, input_checkpoint=input_checkpoint, clear_devices=True, input_binary=True, initializer_nodes=initializ, output_graph='testgraph.pd.modelzoo', output_node_names=output_node_names, restore_op_name='save/restore_all', filename_tensor_name='save/Const:0')
有人遇到类似问题或知道如何解决吗?