如何提高TensorFlow object_detection模型的速度?

时间:2019-07-09 13:08:46

标签: python tensorflow tkinter object-detection cv2

我正在基于SSD_mobilenet_v1_coco模型(根据model zoo,它是最快的模型)开发一个目标检测模型,并已用Tkinter为其编写了GUI。用户选择一个视频后,程序会把视频逐帧送入模型进行标注。标注完成后,用户可以在应用程序中查看标注后的视频。但是,在处理视频期间,窗口会完全冻结,直到整个视频处理完毕;处理时间也大约是视频时长的6倍。我的代码与多个讲解如何用网络摄像头实时运行模型的教程类似,按我的理解,它应该能够实时运行并显示结果,所以我不明白为什么会这样。我唯一的猜测是模型还没有训练完——目前只完成了大约10%的训练。随着训练的进行,运行时间会减少吗?

这是我的视频标注(annotate)代码:

    def annotate(self):
        """Run the detection model over the loaded video and save an annotated copy.

        Each frame read via ``get_just_frame`` is fed to the frozen graph at
        ``self.model_graph``. Detection boxes are drawn onto the frame, and the frame
        is written to ``output/videos/annotated_<timestamp>_output.mp4``. For every
        frame whose highest score is above 50 %, a row is appended to
        ``self.csv_output``.

        Afterwards the method:
        - loads the annotated video into the player;
        - optionally exports the CSV;
        - shows any anomalies returned by ``export_CSV``;
        - opens the output file with ``os.startfile`` (Windows-only API).

        NOTE: all inference runs synchronously on the Tk main thread. That is why the
        window freezes until the whole video has been processed. To keep the UI live,
        move the frame loop to a worker thread and hand results back to Tk through
        ``self.parent.after``. Tk widgets must only be touched from the thread that
        runs the main loop.
        """
        # Check for a missing path first: the substring test below raises TypeError on None.
        if self.video_path is None:
            messagebox.showinfo("Error", "No video selected")
            return
        if "annotated" in self.video_path:
            messagebox.showinfo("Error", "You can't annotate an annotated video!")
            return
        if self.mode != "V":
            return

        exporting = self._ask_export()

        time_str = datetime.datetime.now().strftime('%Y-%m-%d %H_%M_%S')
        path = os.path.abspath('output/videos/annotated_' + time_str + '_output.mp4')

        # Load the model before opening the writer, so a bad model file leaves no
        # half-open output video behind.
        detection_graph = self._load_detection_graph()

        self.rewind()
        fps = self.video.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # Some containers report 0 FPS. Fall back to the rate the writer used
            # before, so neither the writer nor the timestamp division breaks.
            fps = 20.0
        fourcc = cv2.VideoWriter_fourcc(*'MP4V')
        # Write at the source frame rate so the annotated video plays at the original speed.
        out = cv2.VideoWriter(path, fourcc, fps, (960, 540))

        try:
            with detection_graph.as_default():
                with tf.compat.v1.Session(graph=detection_graph) as sess:
                    # Resolve the graph tensors once instead of on every frame.
                    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
                    fetches = [
                        detection_graph.get_tensor_by_name('detection_boxes:0'),
                        detection_graph.get_tensor_by_name('detection_scores:0'),
                        detection_graph.get_tensor_by_name('detection_classes:0'),
                    ]
                    while self.currentFrame is not None:
                        image_np = self.get_just_frame()
                        if image_np is None:
                            break
                        self._detect_and_record(sess, image_tensor, fetches, image_np, fps, path)
                        out.write(image_np)
        finally:
            # Release both handles even if inference fails part-way through.
            self.video.release()
            out.release()
        self.video = None

        # Load the freshly written annotated video into the player and show its first frame.
        self.set_video_path(path)
        self.video = cv2.VideoCapture(self.video_path)
        if not self.video.isOpened():
            raise ValueError("Unable to open video source", self.video_path)
        ret, frame = self.get_frame()
        if ret and frame is not None:
            self.photo = PIL.ImageTk.PhotoImage(image=PIL.Image.fromarray(frame))
            self.canvas.create_image(0, 0, image=self.photo, anchor=NW)

        anomalies_found = self.export_CSV() if exporting else []
        if anomalies_found:
            self._show_anomalies(anomalies_found)
        else:
            messagebox.showinfo("Notification", "No anomalies detected")
        os.startfile(path)

    def _ask_export(self):
        """Ask, with a confirmation step, whether detections should be exported to CSV."""
        answer = tk.messagebox.askquestion('Export to CSV', 'Do you want to export the video to CSV?', icon='warning')
        if answer == 'yes':
            return True
        # Answering "no" to "are you sure you don't want to export?" means: do export.
        confirm = tk.messagebox.askquestion('Export to CSV', "Are you sure you don't want to export the video to CSV?", icon='warning')
        return confirm == 'no'

    def _load_detection_graph(self):
        """Load the frozen graph from ``self.model_graph`` and build ``self.category_index``."""
        num_classes = self.get_num_classes(self.label_map)
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.compat.v1.GraphDef()
            with tf.io.gfile.GFile(self.model_graph, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')
        lmap = label_map_util.load_labelmap(self.label_map)
        categories = label_map_util.convert_label_map_to_categories(
            lmap, max_num_classes=num_classes, use_display_name=True)
        self.category_index = label_map_util.create_category_index(categories)
        return detection_graph

    def _detect_and_record(self, sess, image_tensor, fetches, image_np, fps, path):
        """Detect objects in one BGR frame, draw them onto it in place and log the best hit."""
        # cv2 delivers BGR, while the TF object-detection models expect RGB input.
        # Feed an RGB copy, and keep drawing on the BGR frame that gets written out.
        image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        self.boxes, self.scores, self.classes = sess.run(
            fetches, feed_dict={image_tensor: np.expand_dims(image_rgb, axis=0)})

        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(self.boxes),
            np.squeeze(self.classes).astype(np.int32),
            np.squeeze(self.scores),
            self.category_index,
            use_normalized_coordinates=True,
            line_thickness=2)

        score, hsi = self.find_highest_score(self.scores)
        score = score * 100
        if score > 50:
            # Truncate to a whole percentage ("87%"); a perfect score stays "100%".
            score_text = "{}%".format(int(score))
            box = self.boxes[0][hsi]
            # Scale the normalised box to pixels of the 960x540 frame, ordered
            # x1, y1, x2, y2 (indices 1/3 scale by width, 0/2 by height).
            x1 = format(round(box[1] * 960, 2))
            y1 = format(round(box[0] * 540, 2))
            x2 = format(round(box[3] * 960, 2))
            y2 = format(round(box[2] * 540, 2))
            class_ = self.category_index[self.classes[0][hsi]]['name']
            seconds = round(float(self.true_frame_count / fps), 2)
            timestamp = str(datetime.timedelta(seconds=seconds))
            # NOTE(review): `path` is absolute, so the "../../" prefix yields an unusual
            # path string; confirm what the CSV consumer expects.
            self.csv_output.append((class_, score_text, timestamp, self.true_frame_count,
                                    x1, y1, x2, y2, self.video_path, "../../" + path))

    def _show_anomalies(self, anomalies_found):
        """Show the anomaly messages in a popup and block until the user closes it."""
        popup = tk.Toplevel()
        popup.wm_title("Anomalies Found!")
        text = tk.Text(popup)
        text.pack()
        text.insert(tk.END, "".join(a + "\n" for a in anomalies_found))
        # wait_window returns once the popup is destroyed, like the modal messagebox
        # in the other branch. A nested mainloop() would not return until the app quits.
        popup.wait_window()

这是基本视频播放的代码:

    def update(self):
        """Show the next video frame on the canvas and reschedule this method.

        Runs from the Tk event loop via ``self.parent.after`` with ``self.delay``
        milliseconds between ticks. When ``get_frame`` yields no frame, nothing is
        drawn, but the method is still rescheduled.
        """
        ret, frame = self.get_frame()
        if ret and frame is not None:
            # Keep the PhotoImage on self: Tk does not hold a Python reference, and an
            # image referenced only by a local would be garbage-collected (blank canvas).
            self.photo = PIL.ImageTk.PhotoImage(image=PIL.Image.fromarray(frame))
            # Replace the previous frame item instead of stacking a new canvas item on
            # every tick, which would grow the canvas for the whole playback.
            self.canvas.delete("video_frame")
            self.canvas.create_image(0, 0, image=self.photo, anchor=NW, tags="video_frame")

        self.parent.after(self.delay, self.update)

    def get_frame(self):
        """Read the next frame from ``self.video`` for on-screen playback.

        Every frame read is stored in ``self.currentFrame`` (raw, or ``None`` at end of
        stream). A successfully read frame is:
        - scaled to 960x540 unless it already has exactly that size;
        - counted in ``self.true_frame_count``;
        - appended to ``self.last_60_frames``, which keeps only the newest 60 frames
          (BGR, oldest first);
        - returned converted from BGR to RGB for display.

        Returns:
            ``(True, rgb_frame)`` when a frame was read. ``(True, None)`` when the
            capture is not open or the read fails; in that case the capture is
            released and ``self.delay`` is set to 1000000.
        """
        if not self.video.isOpened():
            self.delay = 1000000
            self.video.release()
            return (True, None)

        ret, frame = self.video.read()
        self.currentFrame = frame
        if not ret:
            self.delay = 1000000
            self.video.release()
            return (True, None)

        # Resize whenever either dimension differs from 960x540.
        if frame.shape[:2] != (540, 960):
            frame = cv2.resize(frame, (960, 540))
        self.true_frame_count += 1

        # Evict from the front (oldest) so the buffer really holds the latest 60 frames.
        self.last_60_frames.append(frame)
        while len(self.last_60_frames) > 60:
            del self.last_60_frames[0]
        return (True, cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    def get_just_frame(self):
        """Fetch the next frame for model inference, without display conversion.

        The frame read is stored in ``self.currentFrame``. On success, returns the frame
        resized to 960x540 (still in OpenCV channel order) and increments
        ``self.true_frame_count``.

        Returns ``None`` in three cases: the capture is closed, ``self.currentFrame``
        is already ``None``, or the read fails. In each of these cases the capture is
        released and ``self.delay`` is set to 1000000.
        """
        can_read = self.video.isOpened() and self.currentFrame is not None
        if can_read:
            ok, raw = self.video.read()
            self.currentFrame = raw
            if ok:
                resized = cv2.resize(raw, (960, 540))
                self.true_frame_count += 1
                return resized

        # Shared exit for "closed / already exhausted" and "read failed".
        self.delay = 1000000
        self.video.release()
        return None

如何提高模型速度并保持窗口不冻结?

0 个答案:

没有答案