一个Tensorflow会话中的多个图

时间:2018-04-14 13:40:40

标签: tensorflow object-detection drone parrot

我目前正在尝试实施一个代码,允许我的无人机使用tensorflow在室内导航。我需要在一个会话中运行两个模型。

一个用于主导航 - 这是一个重新训练的Inception V3模型,负责对走廊图像进行分类并执行向前,向左或向右决策 - 第二个是对象跟踪模型,它将跟踪对象并计算相对到相机的距离。

我不知道如何在一个会话中使用多个图形,所以我尝试在循环中创建一个单独的会话,这会产生很大的开销并导致我的脚本以0 FPS运行。

def inception_model():
# Graph for the InceptionV3 Model
graph = load_graph('inception_v3_frozen/inception_v3_2016_08_28_frozen.pb')

with tf.Session(graph = graph) as sess:
    while camera.isOpened():
        ok, img = camera.read()
        cv.imwrite("frame_temp.jpeg", img)
        t = read_tensor_from_image('frame_temp.jpeg')

        input_layer = "input"
        output_layer = "InceptionV3/Predictions/Reshape_1"

        input_name = "import/" + input_layer
        output_name = "import/" + output_layer

        input_operation = graph.get_operation_by_name(input_name)
        output_operation = graph.get_operation_by_name(output_name)

        results = sess.run(output_operation.outputs[0], {
            input_operation.outputs[0] : t
        })
        results = np.squeeze(results)

        top_k = results.argsort()[-5:][::-1]
        for i in top_k:
            print(labels[i], results[i])

# inception_model()
with tf.Session(graph = object_detection_graph) as sess:
    while camera.isOpened():
        ok, img = camera.read()
        cv.imwrite("frame_temp.jpeg", img)
        img = np.array(img)
        rows = img.shape[0]
        cols = img.shape[1]

        inp = cv.resize(img, (299, 299))

        # inception_model()
        # # Graph for the InceptionV3 Model
        # graph = load_graph('inception_v3_frozen/inception_v3_2016_08_28_frozen.pb')

        # t = read_tensor_from_image('frame_temp.jpeg')

        # input_layer = "input"
        # output_layer = "InceptionV3/Predictions/Reshape_1"

        # input_name = "import/" + input_layer
        # output_name = "import/" + output_layer

        # input_operation = graph.get_operation_by_name(input_name)
        # output_operation = graph.get_operation_by_name(output_name)

        # with tf.Session(graph = graph) as sess:
        #     results = sess.run(output_operation.outputs[0], {
        #         input_operation.outputs[0] : t
        #     })
        # results = np.squeeze(results)

        # top_k = results.argsort()[-5:][::-1]
        # for i in top_k:
        #     print(labels[i], results[i])


        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB


        # Run the model
        out = sess.run([object_detection_graph.get_tensor_by_name('num_detections:0'),
                        object_detection_graph.get_tensor_by_name('detection_scores:0'),
                        object_detection_graph.get_tensor_by_name('detection_boxes:0'),
                        object_detection_graph.get_tensor_by_name('detection_classes:0')],
                    feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

1 个答案:

答案 0 :(得分:3)

您不必在每次迭代时创建新会话。创建它们一次并继续调用它们的run方法。 Tensorflow支持多个活动会话。

另一种选择是拥有一个Graph对象和一个Session。该图表可以将两个模型都包含为断开连接的子图。当您在Session.run()中要求张量时,Tensorflow将仅运行计算您要求的张量所需的量。因此,另一个子图将不会运行(尽管需要一些,可能是非常小的时间来修剪它)