如何为我的实时对象检测API提高输入视频的分辨率?

时间:2019-11-07 01:05:02

标签: tensorflow resolution object-detection-api

我无法提高输入视频中实时对象检测的分辨率。我尝试增加读入视频的输入分辨率,但输出给我一个错误。

我正在使用网络摄像头程序流式传输视频

唯一有效的尺寸是640 x480。任何较大的尺寸都会给我输出错误。分辨率将一直保持1920 x 1080,直到在帧窗口之前,然后切换回640 x 480,并在保存视频时出现错误。

import numpy as np
import tensorflow as tf
from object_detection.utils import visualization_utils as vis_util 
import cv2 as cv
from time import time

import serial

refPoints = []

# draw ROI
def image_crop(event, x, y, flags, param):
    global refPoints

    if event == cv.EVENT_LBUTTONDOWN:
        refPoints = [(x, y)]
    elif event == cv.EVENT_LBUTTONUP:
        refPoints.append((x, y))


# run inference on single image
def run_inference_for_single_image(image, graph, sess):
    with graph.as_default():
        boxes = tf.get_default_graph().get_tensor_by_name('detection_boxes:0')
        scores = tf.get_default_graph().get_tensor_by_name('detection_scores:0')
        classes = tf.get_default_graph().get_tensor_by_name('detection_classes:0')
        image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

        rboxes, rscores, rclasses = sess.run([boxes, scores, classes], feed_dict={image_tensor: np.expand_dims(image, 0)})

    return rboxes[0], rscores[0], rclasses[0]



def main():

    # 0 - load model
    PATH_TO_MODEL = 'C:\\frozen_inference_graph.pb'
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
    sess = tf.Session(graph=detection_graph)

    # 1 - define video streamer & serial port
    streamer = 0
    out = cv.VideoWriter('output.avi', cv.VideoWriter_fourcc(*'XVID'), 20, (1920, 1080)) #previously 25.0, (640, 480)
    port = 'COM4'
    ser = serial.Serial(port, 9600)

    # 2 - draw region of interest (ROI)
    cap = cv.VideoCapture(cv.CAP_DSHOW)
    clone = frame.copy()
    cv.namedWindow('frame')
    cv.resizeWindow('frame', (1920, 1080)) 
    cv.setMouseCallback('frame', image_crop)
    print(frame.shape)#
    print('ROI selecting...')
    while True:
        cv.imshow('frame', frame)
        if cv.waitKey(1) & 0xFF == ord('r'):
            print('ROI selection reset.')
            frame = clone.copy()
        elif cv.waitKey(1) & 0xff == ord('c'):
            print('ROI selected.')
            cap.release()
            break
    if len(refPoints) == 2:
        cv.rectangle(frame, refPoints[0], refPoints[1], (0, 255, 0), 2)
        cv.imshow('frame', frame)
        if cv.waitKey(0) & 0xFF == ord('q'):
            cap.release()
            cv.destroyAllWindows()
    else:
        print('only one ROI allowed.')
        cap.release()
        cv.destroyAllWindows()
        return -1
    # 3 - run inferences on ROI
    t1 = 0
    t2 = 0
    t3 = 0
    i = 0
    cap = cv.VideoCapture(streamer)
    cv.namedWindow('frame')
    cv.resizeWindow('frame', (1920, 1080)) 
    while(True):
        # image reading & cropping
        t0 = time()        
        ret, frame = cap.read()
        if frame is None:
            break
        cropped = frame[refPoints[0][1]:refPoints[1][1], refPoints[0][0]:refPoints[1][0]]
        dt = time()-t0
        t1 += dt
        # image inference
        t0 = time()
        image = cropped.copy()
        boxes, scores, classes = run_inference_for_single_image(image, detection_graph, sess)
        boxes = boxes[scores>0.95]
        classes = classes[scores>0.95]
        boxes1 = boxes[classes==1]
        boxes2 = boxes[classes==2]
        dt = time()-t0
        t2 += dt
        t0 = time()
        cv.rectangle(frame, refPoints[0], refPoints[1], (0, 255, 0), 2)
        if len(boxes1):
            #vis_util.draw_bounding_boxes_on_image_array(image, boxes)
            cv.rectangle(frame, refPoints[0], refPoints[1], (0, 0, 255), 4)
            ser.write(b'1')
        else:
            ser.write(b'0')
        #if len(boxes2):
        #    cv.rectangle(frame, refPoints[0], refPoints[1], (255, 0, 0), 4)
        #print(frame.shape)

        #print(frame.shape)#
        cv.imshow('frame', frame)
        out.write(frame)
        dt = time()-t0
        t3 += dt
        i += 1
        if cv.waitKey(1) & 0xFF == ord('q'):
            break
    print("image reading: average %f sec/frame"%(t1/i))
    print("image processing: average %f sec/frame"%(t2/i))
    print("image showing/saving: average %f sec/frame"%(t3/i))
    ser.close()
    sess.close()
    cap.release()
    out.release()
    cv.destroyAllWindows()

if __name__ == '__main__':
    main()

0 个答案:

没有答案