从预训练模型进行微调后,在TensorFlow模型中丢失输出节点名称

时间:2019-04-18 02:50:07

标签: python ubuntu tensorflow machine-learning pre-trained-model

我遵循https://tensorflow-object-detection-api-tutorial.readthedocs.io上的教程,对预先训练的模型进行微调,以检测图像中的新对象。预先训练的模型是 ssd_inception_v2_coco

几千步之后,我成功地训练并评估了模型,损失从26下降到1。但是,我无法使用以下代码创建冻结的模型:

#this code runs in model dir
import tensorflow as tf

#make .pb file from model at step 1000
saver = tf.train.import_meta_graph(
        './model.ckpt-1000.meta', clear_devices=True)

graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
saver.restore(sess, "./model.ckpt-1000")

#node names
i=0
for n in tf.get_default_graph().as_graph_def().node:
  print(n.name,i);    
  i+=1
#end for
print("total:",i);

output_node_names=[
  "detection_boxes","detection_classes",
  "detection_scores","num_detections"
];
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess,input_graph_def,output_node_names);

#save to .pb file
output_graph="./model.pb"
with tf.gfile.GFile(output_graph, "wb") as f:
  f.write(output_graph_def.SerializeToString());
#end with

sess.close();

错误是:

enter image description here

看来,经过微调的模型已经失去了其输出节点名称。原始预训练模型中有以下输出节点名称(将上面代码中的检查点文件更改为原始训练模型中的检查点文件): detection_boxes,detection_classes,detection_scores和num_detections 。输出节点名称与原始节点名称完全相同,这是它们的索引(来自上面的节点名称“ for”循环):

enter image description here

我的问题是如何保留原始训练后的模型中的输出节点名称?节点名称是用代码定义的,但是这里没有代码,只有一些配置和文件“ train.py”。

PS。在total_loss之后有一个叫做summary_op的东西,但是我不知道它是否是output(?):

enter image description here

1 个答案:

答案 0 :(得分:0)

为了具有' image_tensor '(输入),以及其他输出节点名称' detection_boxes ',' detection_classes ','< strong> detection_scores ”,“ num_detections ”,使用 tensorflow / models / research / object_detection 中名为“ export_inference_graph.py ”的实用程序脚本>”。该脚本甚至优化了冻结图(冻结模型)以进行推理。根据我的测试模型进行的检查,节点数量从26,000个减少到5,000个;这对于推断速度非常有用。

以下是export_inference_graph.py的链接: https://github.com/tensorflow/models/blob/0558408514dacf2fe2860cd72ac56cbdf62a24c0/research/object_detection/export_inference_graph.py

如何运行:

#bash command
python3 export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path PATH_TO_PIPELINE.config \
--trained_checkpoint_prefix PATH_TO/model.ckpt-NUMBER \
--output_directory PATH_TO_NEW_DIR 

有问题的.pb创建代码仅适用于从头开始创建且具有手动定义的节点名称的模型,对于从TensorFlow Model Zoo https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md下载的预训练模型进行微调的模型检查点,它不会工作!