在 TensorFlow 会话中处理多个文件

时间:2020-06-03 17:55:32

标签: python tensorflow flask

# NOTE(review): excerpt from the question, reproduced as posted. It is not
# valid Python as shown: the `def` keyword is missing, the indentation is
# inconsistent, and the loop body is cut off after "Extract image tensor".
func_name(loc, id , mn):    
    with detection_graph.as_default():
         with tf.compat.v1.Session(graph=detection_graph) as sess:
                # A new session is opened on every call of this function.
                #tf.initialize_all_variables().run()

                # NOTE(review): `cap` is neither a parameter nor created here; it
                # comes from a scope not shown. `loc`, `id` and `mn` are unused in
                # the visible part of the function.
                while cap.isOpened():
                    ret, image_np = cap.read()
                    print(ret)

                    if not ret:
                        break
                    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
                    image_np_expanded = np.expand_dims(image_np, axis=0)
                    # Extract image tensor
        # Presumably redundant: leaving the `with ... as sess` block above is
        # expected to close the session already.
        sess.close()

我通过 func_name(location, id, model_name) 把文件交给上面这段普通的目标检测会话代码处理,处理完成后保存并返回结果。但是,在不退出程序的情况下再发送另一个文件时,只读取到了第一帧,之后就什么也没有发生了。也就是说,第一个文件处理完之后,之后的每个文件都没有被真正处理。

如何在不退出并重启程序的情况下处理多个文件?我尝试过初始化变量(initialize variables)以及调用 sess.close(),但仍然无法正常工作。这些文件是通过 Flask 上传的。

更新 1:

dectect_func() 是从另一个脚本中调用的,该脚本负责获取所需的全部参数。

import numpy as np
import os

import six.moves.urllib as urllib
import sys
sys.path.append("..")
import tarfile
import tensorflow as tf
import zipfile
import cv2

from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
# from models.research import *
#from models.research.object_detection.utils import label_map_util
from codes.models.research.object_detection.utils import visualization_utils as vis_util
from codes.models.research.object_detection.utils import label_map_util


# Model to download. Other models are listed in the detection model zoo:
# https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
MODEL_NAME = 'ssd_inception_v2_coco_2017_11_17'
MODEL_FILE = MODEL_NAME + '.tar.gz'
# HTTPS rather than plain HTTP: the graph inside this archive is later loaded
# and executed, so it should not be fetched over an unauthenticated channel.
DOWNLOAD_BASE = 'https://download.tensorflow.org/models/object_detection/'

# Path to the frozen detection graph, i.e. the model actually used for detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# Label map file used to attach the correct label string to each box.
PATH_TO_LABELS = os.path.join('/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

# Number of classes to detect.
NUM_CLASSES = 90

# Download the model archive if it is not cached in the working directory yet,
# then extract only the frozen graph. The member keeps its in-archive path,
# presumably MODEL_NAME/frozen_inference_graph.pb (== PATH_TO_CKPT) -- assumes
# the archive's top-level directory is MODEL_NAME.
if not os.path.exists(os.path.join(os.getcwd(), MODEL_FILE)):
    print("Downloading model")
    # urlretrieve instead of URLopener, which is deprecated since Python 3.3.
    urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
    # `with` guarantees the archive is closed even if extraction fails.
    with tarfile.open(MODEL_FILE) as tar_file:
        for member in tar_file.getmembers():
            if 'frozen_inference_graph.pb' in os.path.basename(member.name):
                tar_file.extract(member, os.getcwd())


# Load the frozen TensorFlow model into its own graph: read the serialized
# GraphDef bytes, parse them, and import every node (name='' keeps the
# original node names, e.g. 'image_tensor:0').
detection_graph = tf.Graph()
with detection_graph.as_default():
    with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
    od_graph_def = tf.compat.v1.GraphDef()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')


# Build the label lookup used for visualization. The label map ties the
# network's class indices to category names (e.g. index 5 -> 'airplane').
# The project's label_map_util helpers are used here, but any dict mapping
# ints to label strings would do.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    use_display_name=True,
    max_num_classes=NUM_CLASSES,
)
category_index = label_map_util.create_category_index(categories)


# Helper code
def load_image_into_numpy_array(image):
    """Convert a PIL-style image into a (height, width, 3) uint8 NumPy array.

    ``image`` must expose ``size`` as ``(width, height)`` and a ``getdata()``
    method producing one 3-channel pixel per entry; any other channel count
    makes the reshape raise ``ValueError``.
    """
    width, height = image.size
    flat_pixels = np.array(image.getdata())
    shaped = flat_pixels.reshape((height, width, 3))
    return shaped.astype(np.uint8)



# One session for the whole module, created at import time on the frozen
# detection graph and reused by every dectect_func() call, so each new file
# does not need a fresh session. It is not closed anywhere in the code shown.
sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(location, id, model_name, display=True):
    """Run object detection on every frame of a video and save the annotated result.

    Args:
        location: Video path/URL, passed to ``cv2.VideoCapture``.
        id: Identifier used to name the output file ``<VID_SAVE_PATH><id>.avi``.
            (Shadows the builtin ``id``; kept for caller compatibility.)
        model_name: Currently unused; detection always runs on the module-level
            ``detection_graph``. Kept so existing callers keep working.
        display: When True (the default, matching earlier behaviour) each frame
            is shown with ``cv2.imshow`` and pressing ``q`` stops early. Pass
            False inside a server (e.g. Flask) or with a headless OpenCV build,
            where HighGUI windows can block or raise.

    Returns:
        The path of the written ``.avi`` file.
    """
    VID_SAVE_PATH = '/tensorflow/downloads/'
    # The writer is opened at this size; OpenCV does not save frames whose size
    # differs, so every frame is resized to it before writing.
    frame_size = (640, 480)
    out_path = os.path.join(VID_SAVE_PATH, str(id) + '.avi')

    # Graph tensors do not change between frames, so look them up once.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    fetches = [
        detection_graph.get_tensor_by_name('detection_boxes:0'),
        detection_graph.get_tensor_by_name('detection_scores:0'),
        detection_graph.get_tensor_by_name('detection_classes:0'),
        detection_graph.get_tensor_by_name('num_detections:0'),
    ]

    cap = cv2.VideoCapture(location)
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(out_path, fourcc, 20.0, frame_size)
    try:
        while True:
            ret, image_np = cap.read()
            if not ret:
                break
            # The model expects a batch: shape [1, None, None, 3].
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, num_detections = sess.run(
                fetches, feed_dict={image_tensor: image_np_expanded})
            # Draw the detections onto image_np so the saved video shows them.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            print(num_detections)

            out.write(cv2.resize(image_np, frame_size))

            if display:
                cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
                if cv2.waitKey(25) & 0xFF == ord('q'):
                    print("pressed q on window")
                    break
    finally:
        # Release on every exit path so the .avi container is finalised and the
        # next call starts from a clean state.
        cap.release()
        out.release()
        if display:
            cv2.destroyAllWindows()
    return out_path




# Detection

更新2:

def process_video(db_path='db/abc.sqlite', process=None):
    """Process the newest unprocessed upload and mark it as processed.

    Args:
        db_path: SQLite database containing the ``uploads`` table.
        process: Callable invoked as ``process(location, id, model_name)``;
            defaults to ``func_name``.

    Returns:
        The id of the processed upload, or None when no upload is pending.
    """
    import sqlite3
    from contextlib import closing

    if process is None:
        process = func_name

    # closing() releases the connection even if processing raises.
    with closing(sqlite3.connect(db_path)) as conn:
        cur = conn.cursor()
        cur.execute(
            "SELECT id, location, model_name FROM uploads WHERE isProcessed=0 order by datetime DESC")
        row = cur.fetchone()
        if row is None:
            # Nothing pending (unpacking None used to raise TypeError here).
            return None
        upload_id, location, model_name = row
        print(upload_id, location, model_name)

        process(location, upload_id, model_name)

        # Bound parameter instead of string concatenation: an id containing a
        # quote can no longer rewrite the statement (SQL injection).
        cur.execute("UPDATE uploads SET isProcessed=1  WHERE id=?", (upload_id,))
        conn.commit()
    print('yes')
    return upload_id

更新 3:

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
False
yes
File saved successfully
9da51fde-5deb-4f78-8f58-13661723daf8 uploads/output.mp4 ssd_inception_v2_coco_2017_11_17
/tensorflow/ssd_inception_v2_coco_2017_11_17/frozen_inference_graph.pb
True

这里我打印的是每次读取帧是否成功(即 ret 的值)。最后一个 True 对应我传入的第二个文件,从上面的输出中可以看到它的 id、路径和模型名。它只读取了第一帧,之后就什么也没有发生了。

2 个答案:

答案 0 :(得分:0)

以下修改对我有用,并允许重新使用检测循环:


# Create the session once, outside the detection function, so every call to
# dectect_func() reuses it instead of opening a new one.
sess = tf.compat.v1.Session(graph=detection_graph)


def dectect_func(cap):
    """Run the detector on frames read from ``cap`` and show them in a window.

    Loops until ``cap`` stops returning frames or ``q`` is pressed in the
    display window. ``cap`` is not released, so the function can be called
    again on the same capture.
    """
    # Graph tensors do not change between frames, so look them up once.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes_t = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores_t = detection_graph.get_tensor_by_name('detection_scores:0')
    classes_t = detection_graph.get_tensor_by_name('detection_classes:0')
    num_t = detection_graph.get_tensor_by_name('num_detections:0')

    while True:
        ret, image_np = cap.read()
        # At end of file or on a camera disconnect, image_np is None; without
        # this check it would be fed to the model and to cv2.resize.
        if not ret:
            break
        # The model expects a batch: shape [1, None, None, 3].
        image_np_expanded = np.expand_dims(image_np, axis=0)
        boxes, scores, classes, num_detections = sess.run(
            [boxes_t, scores_t, classes_t, num_t],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization (vis_util.visualize_boxes_and_labels_on_image_array on
        # boxes/classes/scores) is omitted in this snippet.
        print(num_detections)

        cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))

        if cv2.waitKey(25) & 0xFF == ord('q'):
            print("pressed q on window")
            cv2.destroyAllWindows()
            break


# Run the detection loop twice on the same capture. The second call resumes
# reading from `cap` after the first returns (e.g. after 'q' is pressed),
# which shows the shared session can be reused.
# NOTE(review): `cap` is not defined in this snippet; presumably a
# cv2.VideoCapture created elsewhere -- confirm.
dectect_func(cap)
dectect_func(cap)

我没有克隆 TensorFlow 的 object_detection 仓库,所以这里没有做可视化。不过转动摄像头时,可以看到 num_detections 的值随之变化。

编辑:我认为问题出在 OpenCV 保存文件上。请尝试以下代码:

def dectect_func(location, id):
    """Run detection on the video at ``location`` and save an annotated copy.

    The output is written to ``VID_SAVE_PATH + id + '.avi'`` (e.g. ``out0.avi``)
    as 640x480 MJPG at 20 fps.

    Returns:
        The path of the written ``.avi`` file.
    """
    print('processing: ', location, id)
    VID_SAVE_PATH = 'out'
    # The writer is opened at this size; frames of any other size are not
    # saved, so every frame is resized to it before writing.
    frame_size = (640, 480)
    out_path = VID_SAVE_PATH + str(id) + '.avi'

    # Graph tensors do not change between frames, so look them up once.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    fetches = [
        detection_graph.get_tensor_by_name('detection_boxes:0'),
        detection_graph.get_tensor_by_name('detection_scores:0'),
        detection_graph.get_tensor_by_name('detection_classes:0'),
        detection_graph.get_tensor_by_name('num_detections:0'),
    ]

    cap = cv2.VideoCapture(location)
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(out_path, fourcc, 20.0, frame_size)
    try:
        while True:
            ret, image_np = cap.read()
            if not ret:
                break
            # The model expects a batch: shape [1, None, None, 3].
            image_np_expanded = np.expand_dims(image_np, axis=0)
            boxes, scores, classes, num_detections = sess.run(
                fetches, feed_dict={image_tensor: image_np_expanded})
            # Draw the detections onto image_np in place.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            print(num_detections)

            # `interpolation` must be a keyword: the third positional parameter
            # of cv2.resize is `dst`, so a positional cv2.INTER_CUBIC was never
            # used as the interpolation flag.
            frame = cv2.resize(image_np, frame_size, interpolation=cv2.INTER_CUBIC)
            out.write(frame)
    finally:
        # Release on every exit path so the .avi is finalised. No
        # cv2.destroyAllWindows(): this function opens no windows, and on
        # headless OpenCV builds that call can raise.
        cap.release()
        out.release()
    return out_path



# Detection
# Process the same input twice under different ids (written to out0.avi and
# out1.avi) to show that consecutive files work with the shared session.
dectect_func('small.mp4','0')
dectect_func('small.mp4','1')

答案 1 :(得分:0)

仍有一些部分是调试时无法单步进入的。补充更新:一种变通办法是使用 subprocess.run(),在单独的进程中运行处理代码。

相关问题