加载并冻结重新训练的检测rcnn模型

时间:2019-07-18 18:59:50

标签: python-3.x tensorflow

使用tensorflow-gpu版本1.12,python3.6

我尝试通过简单地加载frozen_inference_graph.pb和相应的label_map.pbtxt来从model-zoo运行多个检测模型,而没有任何问题。

按如下方式加载冻结的模型

def LoadFrozenGraph():
    detectionGraph = tf.Graph()
    with detectionGraph.as_default():
        odGraphDef = tf.GraphDef()
        with tf.gfile.GFile(pathToFrozenGraph, 'rb') as fid:
            serializedGraph = fid.read()
            odGraphDef.ParseFromString(serializedGraph)
            tf.import_graph_def(odGraphDef, name='')

    return detectionGraph

这次,我尝试使用CloudAnnotations-Custom_Training重新训练模型faster_rcnn_resnet101_coco,在这里我只提供模型路径和配置。输出是检查点文件,例如model.ckpt-5000', and a label_map.pbtxt`。再培训分为4个班级。

现在,我尝试使用以下方式加载图形:

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(pathToCheckpoint + 'model.ckpt-5000.meta')
        saver.restore(sess, tf.train.latest_checkpoint(pathToCheckpoint))

        graph = tf.get_default_graph()

,然后加载节点(“ image_tensor:0”,“ detection_boxes:0”,..)。但是,不存在这样的节点。

所以我尝试使用以下方法冻结模型:

python export_inference_graph.py --input_type=image_tensor --pipeline_config_path='faster_rcnn_resnet101_coco.config' --trained_checkpoint_prefix='model.ckpt-5000' --output_directory='./EliaTemp'

但出现以下不匹配错误:

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. 
This is most likely due to a mismatch between the current graph and the graph from the checkpoint. 
Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Assign requires shapes of both tensors to match. lhs shape= [2048,5] rhs shape= [2048,3]

那是在我将.config中的类数从90更改为4之后。

编辑:如果将其更改为2类,则冻结模型将成功创建。

有人可以提供一些有关如何加载检查点(例如加载冻结模型)的信息吗?还是鉴于上述错误,如何创建冻结模型?

0 个答案:

没有答案