多处理池中的Tensorflow

时间:2019-03-27 07:51:13

标签: python tensorflow multiprocessing

我创建了一个用于目标检测的 TensorFlow 对象。我用进程池运行一个名为 compute 的方法,并在主循环中对视频的每一帧进行目标检测。

问题是,无论串行运行还是使用多进程,性能都非常慢。

我不确定这是因为 TensorFlow 会话在多个进程中运行有问题,还是另有原因?

class CarDetector(object):
    """Load a frozen TensorFlow detection graph and open a session on it.

    The constructor changes the process working directory to ``cwd`` (a name
    defined outside this class) and then loads
    ``<detect_model_name>/frozen_inference_graph.pb`` relative to it.

    NOTE(review): the ``tf.Session`` is created in whichever process builds
    the detector. TensorFlow sessions are generally not safe to share with
    forked ``multiprocessing`` workers, so run inference in the parent
    process or build one detector inside each worker.

    Args:
        detect_model_name: directory of the exported model. Defaults to the
            SSD MobileNet v1 COCO model that was previously hard-coded.
    """

    def __init__(self, detect_model_name='ssd_mobilenet_v1_coco_11_06_2017'):
        self.car_boxes = []
        self.ped_boxes = []
        # `cwd` is expected to be defined at module level elsewhere; the
        # relative model path below is resolved against it.
        os.chdir(cwd)

        # Single-shot detection (SSD) with a MobileNet architecture trained
        # on the COCO dataset.
        path_to_ckpt = os.path.join(detect_model_name,
                                    'frozen_inference_graph.pb')

        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            # Read the serialized graph and close the file right away; the
            # handle is not needed while importing or creating the session.
            with tf.gfile.GFile(path_to_ckpt, 'rb') as fid:
                serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
            self.sess = tf.Session(graph=self.detection_graph,
                                   config=tf.ConfigProto())

            # Graph input tensor.
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a region of the image where an object was detected.
            self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score is the confidence for the corresponding detected object.
            self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')




if __name__ == "__main__":
    video_name = '2016-11-18_07-30-01.h264'
    # video_name = 'TownCentreXVID.avi'

    cap = cv2.VideoCapture(video_name)

    det = detector.CarDetector()

    # NOTE(review): `f`, `frame_count`, `matrix`, `pool_1`, `pool_2`, the
    # trackers, the region/line frames and `det_cpy` come from code not shown
    # here. `det` above is not passed to pipeline(); confirm whether
    # `det_cpy` is meant to be `det`.
    try:
        while f < frame_count:
            ret, frame = cap.read()
            # read() returns ret == False (and no frame) at end of stream or
            # on a decode error; stop instead of crashing on frame.shape.
            if not ret:
                break
            f += 1
            if f > 0:
                # warpPerspective does not modify its source, so no copy is needed.
                dst = cv2.warpPerspective(frame, matrix,
                                          (frame.shape[1], frame.shape[0]))
                pipeline(pool_1, pool_2, frame, car_tracker, ped_tracker,
                         df_region, region_buffered, df_line_car, df_line_ped,
                         det_cpy, dst, matrix)
    finally:
        # Always release the capture handle, even if the loop raises.
        cap.release()

在 pipeline 函数中,我运行 TensorFlow 会话,并用进程池对另一个名为 compute 的函数执行 map:

def pipeline(pool_1, pool_2, img, car_tracker, ped_tracker, df_region, region_buffered, df_line_car, df_line_ped, det, dst, H):
    """Run detection on one frame, draw car boxes, and fan work out to pools.

    Args:
        pool_1: pool used to map ``compute`` over ``cars``.
        pool_2: pool used to map ``compute`` over ``peds``.
        img: the current video frame, passed to ``det.get_localization`` and
            to the drawing helper.
        det: detector object exposing ``get_localization(img)``, which
            returns a pair (car boxes, pedestrian boxes).
        The remaining parameters are not used in the visible code.

    Returns:
        ``(df_cars, df_peds)``: the lists produced by mapping ``compute`` over
        ``cars`` and ``peds``. These were previously computed and discarded.
        The existing caller ignores the return value, so this is backward
        compatible.

    NOTE(review): `car_detections`, `cars`, `peds`, `compute` and `helpers`
    are defined outside this function (elided in the visible code).
    """
    # NOTE(review): car_box / ped_box are unused in the visible code;
    # presumably elided code derives car_detections/cars/peds from them.
    car_box, ped_box = det.get_localization(img)
    for trk in car_detections:
        box = trk.astype(np.int32)
        # Draw the bounding box and its label (box[4]) onto img
        # (presumably in place; the helper's return value is not used).
        helpers.draw_box_label(img, box, box[4])
    # Submit both batches before waiting on either, so the two pools work
    # concurrently instead of one after the other.
    cars_job = pool_1.map_async(compute, cars)
    peds_job = pool_2.map_async(compute, peds)
    df_cars = cars_job.get()
    df_peds = peds_job.get()
    return df_cars, df_peds

0 个答案:

没有答案