Using tfrecord format when testing the TensorFlow Object Detection API

Asked: 2018-09-03 09:25:04

Tags: tensorflow

I am using the TensorFlow Object Detection API for our project.

After a lot of trial and error I managed to get it working, but when I test the model, I think its analysis time can be improved.

My machine is:

Ubuntu 16.04,

i7-7700K,

single GPU (GTX 1080 Ti),

CUDA 9.1, cuDNN 7.0.5

The model's average analysis time is 1.1 seconds per image (1920 * 1080, jpg format).

Since my test code follows the TensorFlow Object Detection API tutorial, it "reads an image file and converts it to a numpy array, rather than using a tfrecord".

So... if I feed the test images (test_imgs) in tfrecord format, can the model's analysis time be improved?

If so, how should I do it?

Please help me, or give me some hints.
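
To illustrate what I have in mind, here is a rough, untested sketch (TF 1.x style) of packing the test images into a tfrecord and reading them back through tf.data. The file name 'test_images.record' and the feature key 'image/encoded' are just placeholders I made up; TEST_IMAGE_PATHS and my_batch_size are the same variables used in my code below.

    import tensorflow as tf

    # Step 1: pack the test JPEGs into one tfrecord file.
    writer = tf.python_io.TFRecordWriter('test_images.record')
    for image_path in TEST_IMAGE_PATHS:
        with tf.gfile.GFile(image_path, 'rb') as f:
            encoded_jpg = f.read()
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/encoded': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[encoded_jpg])),
        }))
        writer.write(example.SerializeToString())
    writer.close()

    # Step 2: read the tfrecord back and decode the JPEGs on the TF side.
    def _parse(serialized):
        feats = tf.parse_single_example(serialized, {
            'image/encoded': tf.FixedLenFeature([], tf.string),
        })
        return tf.image.decode_jpeg(feats['image/encoded'], channels=3)

    dataset = (tf.data.TFRecordDataset('test_images.record')
               .map(_parse, num_parallel_calls=4)
               .batch(my_batch_size)
               .prefetch(1))
    next_batch = dataset.make_one_shot_iterator().get_next()
    # Evaluating next_batch in a session yields a uint8 batch that could be
    # fed to image_tensor instead of the Image.open / numpy loop below.

I am not sure whether this would actually reduce the per-image analysis time, but this is the kind of input pipeline I mean.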

Here is my code:

    # Module-level imports used by this snippet (not shown in my paste):
    #   import os
    #   import numpy as np
    #   import tensorflow as tf
    #   from PIL import Image
    #   from object_detection.utils import visualization_utils as vis_util
    # PATH_TO_TEST_IMAGES_DIR, start_idx, end_idx, my_batch_size, res_list (a dict
    # mapping image path -> bbox strings), self.detect_threshold and
    # self.category_index are all set elsewhere in my class.
    TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(start_idx, end_idx) ]
    with self.detection_graph.as_default():
        with tf.Session(graph=self.detection_graph) as sess:
            # Define the input and output tensors for detection_graph
            image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object was detected.
            detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score represents the level of confidence for each of the objects.
            # Score is shown on the result image, together with the class label.
            detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
            detection_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
            total_size = len(TEST_IMAGE_PATHS)
            cur_idx = 0
            while True:
                if total_size <= 0:
                    break

                cur_lst = []
                cur_paths = []  # paths of the images that are actually loaded in this batch
                for i in range(min(my_batch_size, total_size)):
                    image_path = TEST_IMAGE_PATHS[cur_idx + i]
                    try:
                        image = Image.open(image_path)
                    except FileNotFoundError:
                        print('[{}] does not exist'.format(image_path))
                        continue

                    image_np = self.load_image_into_numpy_array(image)
                    cur_lst.append(image_np)
                    cur_paths.append(image_path)

                cur_feed_lst = np.asarray(cur_lst)

                (boxes, scores, classes, _) = sess.run([detection_boxes, detection_scores, detection_classes, num_detections],\
                                                       feed_dict={image_tensor: cur_feed_lst})

                myboxes = np.squeeze(boxes)
                myclasses = np.squeeze(classes).astype(np.int32)
                myscores = np.squeeze(scores)
                im_width, im_height = (1920, 1080)


                if cur_feed_lst.shape[0] == 1:
                    tmp_bbox_list = []

                    for i in range(min(20, myboxes.shape[0])):
                        if myscores[i] >= self.detect_threshold:
                            ymin, xmin, ymax, xmax = tuple(myboxes[i].tolist())
                            (left, right, top, bottom) = (int(xmin * im_width), int(xmax * im_width), 
                                                          int(ymin * im_height), int(ymax * im_height))
                            # output order: score, class name, left, top, right, bottom
                            curclassn = self.category_index[myclasses[i]]['name']
                            cur_bbox_data = '{0:.3f}\t{1}\t{2}\t{3}\t{4}\t{5}'.format(myscores[i], curclassn, left, top, right, bottom)

                            tmp_bbox_list.append(cur_bbox_data)

                    if tmp_bbox_list:
                        image_path = cur_paths[0]
                        res_list[image_path] = tmp_bbox_list

                        vis_util.visualize_boxes_and_labels_on_image_array(
                              cur_feed_lst[0],
                              np.squeeze(boxes),
                              np.squeeze(classes).astype(np.int32),
                              np.squeeze(scores),
                              self.category_index,
                              use_normalized_coordinates=True,
                              line_thickness=8)

                else:
                    for j in range(len(cur_lst)):  # only the images that were actually loaded
                        myboxes2 = myboxes[j]
                        myclasses2 = myclasses[j]
                        myscores2 = myscores[j]

                        tmp_bbox_list = []

                        for i in range(min(20, myboxes2.shape[0])):
                            if myscores2[i] >= self.detect_threshold:
                                ymin, xmin, ymax, xmax = tuple(myboxes2[i].tolist())
                                (left, right, top, bottom) = (int(xmin * im_width), int(xmax * im_width), 
                                                              int(ymin * im_height), int(ymax * im_height))
                                # output order: score, class name, left, top, right, bottom
                                curclassn = self.category_index[myclasses2[i]]['name']
                                cur_bbox_data = '{0:.3f}\t{1}\t{2}\t{3}\t{4}\t{5}'.format(myscores2[i], curclassn, left, top, right, bottom)

                                tmp_bbox_list.append(cur_bbox_data)

                        if tmp_bbox_list:
                            image_path = cur_paths[j]
                            res_list[image_path] = tmp_bbox_list

                # update
                cur_idx += my_batch_size
                total_size -= my_batch_size

    # Helper method of the same class: converts a PIL image into an (H, W, 3) uint8 array.
    def load_image_into_numpy_array(self, image):
        (im_width, im_height) = image.size

        return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8)
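
As a side note, I also wonder whether load_image_into_numpy_array itself costs time. A variant I am thinking about (untested, not from the tutorial) would be:

    def load_image_into_numpy_array(self, image):
        # Untested idea: let numpy read the PIL buffer directly instead of
        # going through image.getdata(), which builds a Python sequence first.
        return np.asarray(image.convert('RGB'), dtype=np.uint8)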

0 Answers:

There are no answers.