Using the TensorFlow Object Detection API to run predictions on a recorded video

Date: 2018-10-13 17:32:49

Tags: tensorflow object-detection object-detection-api

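I am trying to read a video file with OpenCV, run every frame through TensorFlow's Object Detection API to get predictions and bounding boxes, and then write the predicted frames (with the boxes drawn on them) to a new video file. I modified object_detection_tutorial.ipynb to capture the video frames and process them with a faster-rcnn-inception-resnet-v2 model that I trained and loaded from a frozen graph.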

I am using a Tesla P100 GPU on a cloud machine running Windows 10 with 56 GB of RAM, together with tensorflow-gpu.

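When I run the code, each frame takes 0.5 seconds. Is that a normal speed for a Tesla P100, or am I doing something wrong in my code that makes it slow?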

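This code is only a test, because later I will have to use it for a real-time video prediction task. If 0.5 seconds per frame is the expected speed with the TensorFlow API, I suppose I won't be able to use it for that task :(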
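To put the real-time requirement in numbers: at a given frame rate each frame has a fixed time budget, and comparing it with the measured time per frame shows how much faster the pipeline would need to be. A minimal sketch; the 30 fps target is an assumption (the frame rate of the input video is not stated), and `realtime_gap` is a hypothetical helper name.

```python
def realtime_gap(seconds_per_frame: float, target_fps: float) -> dict:
    """Compare a measured per-frame processing time with a real-time target.

    Returns the achievable frame rate, the per-frame budget (ms) at the
    target frame rate, and how many times faster processing would need to be.
    """
    if seconds_per_frame <= 0 or target_fps <= 0:
        raise ValueError("both arguments must be positive")
    budget_s = 1.0 / target_fps                 # time available for each frame
    achievable_fps = 1.0 / seconds_per_frame    # frames processed per second now
    speedup_needed = seconds_per_frame / budget_s
    return {
        "achievable_fps": achievable_fps,
        "budget_ms": budget_s * 1000.0,
        "speedup_needed": speedup_needed,
    }


# 0.5 s per frame (from the question) against an assumed 30 fps source.
print(realtime_gap(0.5, 30.0))
```

At 0.5 s per frame the loop handles about 2 frames per second, so keeping up with a 30 fps stream would require roughly a 15x speedup.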

After running it, I got the following timings:

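processing frame number: 1.0

time to capture video frame = 0.0

time to predict = 0.49225664138793945

time to generate boxes in a frame = 0.14833950996398926

time to write a frame in video file = 0.04687023162841797

total time in the loop = 0.6874663829803467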
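The timings above can be turned into per-stage shares to see where the loop spends its time. A minimal sketch; the numbers are copied from the log, and `stage_shares` is a hypothetical helper name. The four stage times add up to the reported loop total of about 0.687 s.

```python
def stage_shares(timings: dict) -> dict:
    """Return each stage's fraction of the summed stage times."""
    total = sum(timings.values())
    if total <= 0:
        raise ValueError("timings must sum to a positive value")
    return {stage: t / total for stage, t in timings.items()}


# Stage timings in seconds, copied from the log above.
log = {
    "capture": 0.0,
    "predict": 0.49225664138793945,
    "draw_boxes": 0.14833950996398926,
    "write_frame": 0.04687023162841797,
}

# Print stages from most to least expensive.
for stage, share in sorted(stage_shares(log).items(), key=lambda kv: -kv[1]):
    print(f"{stage:12s} {share:6.1%}")
```

With these numbers, prediction accounts for roughly 72% of the loop and drawing the boxes for about 22%, so the model is the main cost but not the only one.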

As you can see, the parts of the code that run on the CPU (OpenCV) are fast, but the prediction step on the GPU (the sess.run call) alone takes almost 0.5 seconds.
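Before concluding that the GPU itself is slow, it helps to measure steady-state latency: the first few `sess.run` calls on a GPU typically include one-time setup work (such as memory allocation), so the timing of a single early frame can overstate the per-frame cost. A generic sketch that discards warm-up calls and reports mean and median latency; `benchmark` is a hypothetical helper, and in the question's code `fn` would be a function wrapping the `sess.run(...)` call on a fixed input. It uses `time.perf_counter`, which Python provides for measuring short durations.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 5, repeats: int = 20) -> dict:
    """Time fn() after discarding `warmup` calls.

    Returns the mean and median wall-clock latency (seconds) of the timed calls.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    for _ in range(warmup):
        fn()                                  # warm-up calls: not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()           # high-resolution clock
        fn()
        samples.append(time.perf_counter() - start)
    return {"mean_s": statistics.mean(samples), "median_s": statistics.median(samples)}


if __name__ == "__main__":
    # Stand-in workload; replace with something like
    # lambda: sess.run(fetches, feed_dict=feed)
    print(benchmark(lambda: sum(range(100_000))))
```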

Any suggestions? Thanks in advance. My code follows below:

from distutils.version import StrictVersion
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import time

from collections import defaultdict
from io import StringIO
#from matplotlib import pyplot as plt
from PIL import Image

import cv2
from imutils import paths

import re

# This is needed since the code is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops

if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
  raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')


from utils import label_map_util

from utils import visualization_utils as vis_util

#Detection using tensorflow inside write_video function

def write_video():

    filename = 'output/teste_v2.avi'
    codec = cv2.VideoWriter_fourcc('W', 'M', 'V', '2')
    cap = cv2.VideoCapture('pneu_trim2.mp4')
    framerate = round(cap.get(cv2.CAP_PROP_FPS), 2)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    resolution = (w, h)

    VideoFileOutput = cv2.VideoWriter(filename, codec, framerate, resolution)    

    ################################
    # Model preparation
    #
    # Any model exported using the `export_inference_graph.py` tool can be loaded here
    # simply by changing `PATH_TO_FROZEN_GRAPH` to point to a new .pb file.

    # Directory containing the trained, exported model to load.
    MODEL_NAME = 'training/pneu_incep_step_24887'
    print("loading model from " + MODEL_NAME)

    # Path to frozen detection graph. This is the actual model that is used for the object detection.
    PATH_TO_FROZEN_GRAPH = MODEL_NAME + '/frozen_inference_graph.pb'

    # Label map: the strings used to add the correct label to each box.
    PATH_TO_LABELS = os.path.join('data', 'object-detection.pbtxt')

    NUM_CLASSES = 5


    # Load the frozen TensorFlow model into memory.

    time_graph = time.time()
    print('loading graphs')
    detection_graph = tf.Graph()
    with detection_graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    print("time to build graph = " + str(time.time() - time_graph))

    # Load the label map.

    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    ################################

    with tf.Session(graph=detection_graph) as sess:
        with detection_graph.as_default():
            # Look up the input and output tensors once, before the frame loop.
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object was detected.
            boxes_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
            # Each score represents the confidence level for the corresponding object.
            # The score is shown on the result image, together with the class label.
            scores_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
            classes_tensor = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections_tensor = detection_graph.get_tensor_by_name('num_detections:0')
            while cap.isOpened():
              time_loop = time.time()
              print('processing frame number: ' + str(cap.get(cv2.CAP_PROP_POS_FRAMES)))
              time_captureframe = time.time()
              ret, image_np = cap.read()
              print("time to capture video frame = " + str(time.time() - time_captureframe))
              if not ret:
                  break
              # OpenCV returns frames in BGR order, while models trained with the Object
              # Detection API typically expect RGB, so convert a copy for inference.
              # image_np itself stays BGR, which is what VideoWriter expects.
              image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
              # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
              image_np_expanded = np.expand_dims(image_rgb, axis=0)
              # Actual detection.
              time_prediction = time.time()
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes_tensor, scores_tensor, classes_tensor, num_detections_tensor],
                  feed_dict={image_tensor: image_np_expanded})
              print("time to predict = " + str(time.time() - time_prediction))
              # Visualization of the results of a detection.
              time_visualizeboxes = time.time()
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np,
                  np.squeeze(boxes),
                  np.squeeze(classes).astype(np.int32),
                  np.squeeze(scores),
                  category_index,
                  use_normalized_coordinates=True,
                  line_thickness=8)
              print("time to generate boxes in a frame = " + str(time.time() - time_visualizeboxes))


              time_writeframe = time.time()
              VideoFileOutput.write(image_np)
              print("time to write a frame in video file = " + str(time.time() - time_writeframe))

              print("total time in the loop = " + str(time.time() - time_loop))

    cap.release()
    VideoFileOutput.release()
    print('done')
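The `detection_boxes` output in the code above holds normalized `[ymin, xmin, ymax, xmax]` coordinates (hence `use_normalized_coordinates=True`), and `detection_scores` holds one confidence value per box. If the detections themselves are needed rather than an annotated frame, for example in a real-time task, a sketch like the following keeps the boxes above a score threshold and converts them to pixel coordinates. The 0.5 threshold and the name `to_pixel_boxes` are assumptions; it accepts plain sequences, so the squeezed arrays returned by `sess.run` can be passed in directly.

```python
from typing import List, Sequence, Tuple

Box = Tuple[int, int, int, int]  # (xmin, ymin, xmax, ymax) in pixels


def to_pixel_boxes(
    boxes: Sequence[Sequence[float]],
    scores: Sequence[float],
    classes: Sequence[float],
    width: int,
    height: int,
    min_score: float = 0.5,
) -> List[Tuple[int, float, Box]]:
    """Keep detections with score >= min_score and convert their boxes.

    Each input box is normalized [ymin, xmin, ymax, xmax] in the range [0, 1].
    Returns (class_id, score, (xmin, ymin, xmax, ymax)) tuples in pixels.
    """
    results = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue                          # drop low-confidence detections
        ymin, xmin, ymax, xmax = box
        pixel_box = (
            int(round(xmin * width)),
            int(round(ymin * height)),
            int(round(xmax * width)),
            int(round(ymax * height)),
        )
        results.append((int(cls), float(score), pixel_box))
    return results


# Example with made-up values for a 1280x720 frame.
dets = to_pixel_boxes(
    boxes=[[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]],
    scores=[0.9, 0.3],
    classes=[1.0, 2.0],
    width=1280,
    height=720,
)
print(dets)
```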

1 answer:

Answer (score: 0)

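Actually, the problem is the model you are using. See https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md: Faster R-CNN Inception-ResNet-v2 is among the slowest architectures listed there, so it will simply take more time per frame. The link lists the speed of each model, so you can compare them. Note that the zoo models are trained on public datasets such as COCO, so to use a faster architecture for your own classes you would need to train it on your dataset, as you did with the current model.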