了解Tensorflow对象检测API的改进版本

时间:2018-07-02 16:49:32

标签: tensorflow object-detection

我在我的项目中使用Tensorflow对象检测API,并遇到了以下链接: https://github.com/tensorflow/models/issues/3270 该代码附加在此链接上的zip文件中。我不理解的具体部分是这部分:

input_graph = tf.Graph()
with tf.Session(graph=input_graph):
    score = tf.placeholder(tf.float32, shape=(None, 1917, 90), name="Postprocessor/convert_scores")
    expand = tf.placeholder(tf.float32, shape=(None, 1917, 1, 4), name="Postprocessor/ExpandDims_1")
    for node in input_graph.as_graph_def().node:
        if node.name == "Postprocessor/convert_scores":
            score_def = node
        if node.name == "Postprocessor/ExpandDims_1":
            expand_def = node

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        dest_nodes = ['Postprocessor/convert_scores','Postprocessor/ExpandDims_1']

        edges = {}
        name_to_node_map = {}
        node_seq = {}
        seq = 0
        for node in od_graph_def.node:
            n = _node_name(node.name)
            name_to_node_map[n] = node
            edges[n] = [_node_name(x) for x in node.input]
            node_seq[n] = seq
            seq += 1

        for d in dest_nodes:
            assert d in name_to_node_map, "%s is not in graph" % d

        nodes_to_keep = set()
        next_to_visit = dest_nodes[:]
        while next_to_visit:
            n = next_to_visit[0]
            del next_to_visit[0]
            if n in nodes_to_keep:
                continue
            nodes_to_keep.add(n)
            next_to_visit += edges[n]

        nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])

        nodes_to_remove = set()
        for n in node_seq:
            if n in nodes_to_keep_list: 
                continue
            nodes_to_remove.add(n)
        nodes_to_remove_list = sorted(list(nodes_to_remove), key=lambda n: node_seq[n])

        keep = graph_pb2.GraphDef()
        for n in nodes_to_keep_list:
            keep.node.extend([copy.deepcopy(name_to_node_map[n])])

        remove = graph_pb2.GraphDef()
        remove.node.extend([score_def])
        remove.node.extend([expand_def])
        for n in nodes_to_remove_list:
            remove.node.extend([copy.deepcopy(name_to_node_map[n])])

        with tf.device('/gpu:0'):
            tf.import_graph_def(keep, name='')
        with tf.device('/cpu:0'):
            tf.import_graph_def(remove, name='')

通过将操作正确分配给GPU和CPU,减少了处理每个图像所需的时间。我有一个基本的想法,那就是它试图在CPU和GPU上分配操作,但是对两个图形,它们的结构和工作方式的解释将非常有帮助。谢谢!

1 个答案:

答案 0 :(得分:0)

我对这段代码的理解是:

  • 它创建带有两个占位符'Postprocessor/convert_scores''Postprocessor/ExpandDims_1'的图形。
  • 将其转换为graph_def,并保留与占位符相对应的节点。

    • 这2个节点对应于模型输出的1917个盒子,第一个是类概率,第二个是盒子坐标。
  • 它将创建第二张图并加载经过训练的模型。

  • 它列出了图中的所有节点以及它们之间的连接方式。
  • 列出所有连接到'Postprocessor/convert_scores''Postprocessor/ExpandDims_1'的节点,并将它们存储在保留列表中。
  • 列出所有不在nodes_to_keep_list中的节点,并将它们存储在nodes_to_remove_list中。

  • 然后创建一个图def,并用所有nodes_to_keep_list节点的副本填充它。

  • 然后是第二个图def,其中包含所有nodes_to_remove_list节点的副本。

  • 最后,它同时加载两个图形定义,第一个使用设备'/gpu:0',第二个使用设备'/cpu:0'

正如作者所言,这样做的目的是在CNG上运行CNN,在CPU上进行后处理,因为CNN的处理速度更快。 如果您查看mobilenet + SSD,您会看到该模型输出了一堆盒子(1917年),然后对这些盒子进行了相当复杂的(至少从图形的角度来看)后处理,以提供最终输出({ {1}},detection_boxesdetection_scoresdetection_classes)。

在这段代码中不可见,但是稍后使用占位符将num_detections图的输出插入keep图中。执行过程分为2个步骤(两次调用remove

sess.run()

编辑

(score, expand) = sess.run([score_out, expand_out], feed_dict={image_tensor: image_np_expanded}) (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={score_in:score, expand_in: expand}) print 'Iteration %d: %.3f sec'%(i, time.time()-start_time) 的值来自原始图形,使用不同的模型,它将是不同的值,但是即使是不同的节点等……这就是为什么此解决方案比真正的解决方案更像是hack,因为需要针对您想要将其应用于...的每个新模型进行量身定制。

不久前,我看了这张图,我认为模型输出了一堆具有一定大小或长宽比的盒子,以及另一堆长宽比不同的盒子,等等,所有这些都融合在一起,最终并带有这个1917方框图。

1917只是操作的名称,因为它没有在图中命名。 ExpandDims之所以存在,是因为该范围内的图形中可能已经存在一个。至于为什么要特别指定这些节点,这只是作者在研究这些性能问题后做出的任意选择。基本上,最慢的部分是在这些节点之后。但是他可以选择稍有不同的节点,例如,在执行_1操作之前,它会执行相同的操作。这些特定节点的实际目的与他在这里所做的无关。在ExpandDims的情况下,这是一个非常平凡的操作,只是添加了尺寸1。