我创建了一个基于 TensorFlow 的检测器对象,用于目标检测。在主循环中,我对每一帧视频调用该检测器获取检测到的对象,并使用进程池对一个名为 compute 的函数执行 map。
问题是:无论是串行运行还是使用多进程,性能都非常慢。
我不确定这是 TensorFlow 会话在多个进程中运行导致的问题,还是有其他原因?
class CarDetector(object):
    """Load a frozen TensorFlow detection graph and keep a session open for it.

    The graph is read from ``<model_name>/frozen_inference_graph.pb``. That path is
    resolved relative to the module-level ``cwd``, because ``__init__`` first calls
    ``os.chdir(cwd)``. Note that this changes the working directory of the whole
    process. ``cwd`` is defined outside this snippet.

    The input tensor and output tensors of the graph are looked up once and stored
    as attributes (``image_tensor``, ``boxes``, ``scores``, ``classes``,
    ``num_detections``). They can then be fed/fetched through ``self.sess``.

    NOTE(review): TensorFlow sessions are generally not fork-safe. If a
    multiprocessing pool forks *after* a detector has been built in the parent,
    the children inherit a session they should not use. Prefer building one
    detector inside each worker process (e.g. in a pool initializer).
    """

    def __init__(self, model_name='ssd_mobilenet_v1_coco_11_06_2017', config=None):
        """Build the graph and open a session.

        Args:
            model_name: Directory (relative to ``cwd``) that contains
                ``frozen_inference_graph.pb``. The default is the SSD-MobileNet v1
                COCO model used originally.
            config: Optional ``tf.ConfigProto`` for the session. ``None`` means a
                default ``tf.ConfigProto()``, exactly as before. Limiting
                ``intra_op_parallelism_threads`` / ``inter_op_parallelism_threads``
                may help when several detectors share one CPU; measure first.
        """
        self.car_boxes = []
        self.ped_boxes = []
        # Side effect: changes the working directory of the entire process so the
        # relative model path below resolves.
        os.chdir(cwd)

        # TensorFlow localization/detection model: single-shot detection with a
        # MobileNet architecture trained on the COCO dataset.
        path_to_ckpt = model_name + '/frozen_inference_graph.pb'

        if config is None:
            config = tf.ConfigProto()

        # Set up the TensorFlow graph from the serialized frozen GraphDef.
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(path_to_ckpt, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')

            self.sess = tf.Session(graph=self.detection_graph, config=config)
            self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where an object was detected.
            self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score is the level of confidence for the corresponding object.
            self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')

    def close(self):
        """Release the TensorFlow session. Calling this more than once is safe."""
        if self.sess is not None:
            self.sess.close()
            self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the session; never suppress the caller's exception.
        self.close()
        return False
if __name__ == "__main__":
    # NOTE(review): f, frame_count, matrix, pool_1, pool_2, car_tracker,
    # ped_tracker, df_region, region_buffered, df_line_car, df_line_ped and
    # det_cpy are not defined in this snippet; they must come from elsewhere.
    video_name = '2016-11-18_07-30-01.h264'
    # video_name = 'TownCentreXVID.avi'
    cap = cv2.VideoCapture(video_name)
    # NOTE(review): `det` is created here but `det_cpy` is what gets passed to
    # pipeline() below. Confirm which detector is intended.
    det = detector.CarDetector()
    try:
        while f < frame_count:
            ret, frame = cap.read()
            if not ret:
                # End of stream or a decode failure: `frame` is None, so stop here
                # instead of crashing on frame.shape below.
                break
            f += 1
            if f > 0:
                # warpPerspective only reads its source image, so no copy is needed.
                dst = cv2.warpPerspective(frame, matrix, (frame.shape[1], frame.shape[0]))
                pipeline(pool_1, pool_2, frame, car_tracker, ped_tracker, df_region,
                         region_buffered, df_line_car, df_line_ped, det_cpy, dst, matrix)
    finally:
        # Release the video file/stream even if the loop raises.
        cap.release()
在 pipeline 函数中,我先运行 TensorFlow 会话进行检测,然后使用进程池对另一个名为 compute 的函数执行 map:
def pipeline(pool_1, pool_2, img, car_tracker, ped_tracker, df_region, region_buffered, df_line_car,
             df_line_ped, det, dst, H):
    """Process one video frame: detect, draw tracked car boxes, and fan out ``compute``.

    Args:
        pool_1: Process pool used to map ``compute`` over ``cars``.
        pool_2: Process pool used to map ``compute`` over ``peds``.
        img: The current frame. Boxes are drawn onto it.
        det: Detector whose ``get_localization(img)`` returns a (car, pedestrian) pair.
        car_tracker, ped_tracker, df_region, region_buffered, df_line_car,
        df_line_ped, dst, H: Not used in the visible code. They are kept for
            interface compatibility.

    Returns:
        ``(df_cars, df_peds)``: the lists produced by mapping ``compute`` over
        ``cars`` and ``peds``. Earlier versions returned ``None``; callers that
        ignore the result are unaffected.

    NOTE(review): ``car_detections``, ``cars``, ``peds``, ``compute`` and ``helpers``
    are not defined here and must be module globals. ``car_box`` / ``ped_box`` are
    not used in the visible code. Every ``map`` pickles its inputs and outputs
    across processes on every frame, which can outweigh the parallel speed-up for
    small jobs. Profile before relying on the pools.
    """
    car_box, ped_box = det.get_localization(img)

    for trk in car_detections:
        trk = trk.astype(np.int32)
        # Draw the box and pass trk[4] as its label. The return value is ignored,
        # so this presumably draws onto img in place.
        helpers.draw_box_label(img, trk, trk[4])

    # Submit both jobs before waiting on either so the two pools work at the same
    # time. Two back-to-back blocking map() calls would run them one after the
    # other. Pool.map(f, xs) is equivalent to Pool.map_async(f, xs).get().
    cars_job = pool_1.map_async(compute, cars)
    peds_job = pool_2.map_async(compute, peds)
    df_cars = cars_job.get()
    df_peds = peds_job.get()
    return df_cars, df_peds