我正在研究一个涉及检测人体的机器人项目,我正在使用张量流和预定义数据集来创建训练模型。由于我是机器学习的新手,我无法正确获取分类器的输出。我只需要人物检测,并希望避免检测球,笔记本电脑或其他物体。 现在我的摄像头检测到所有物体,如球,蝙蝠,笔记本电脑,电视等。我需要的输出只是得分为80%及以上的人。
我用于创建模型的代码是
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
if not os.path.exists(MODEL_NAME + '/frozen_inference_graph.pb'):
print ('Downloading the model')
opener = urllib.request.URLopener()
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
file_name = os.path.basename(file.name)
if 'frozen_inference_graph.pb' in file_name:
tar_file.extract(file, os.getcwd())
print ('Download complete')
else:
print ('Model already exists')
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
import cv2
cap = cv2.VideoCapture(1)
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
ret = True
while (ret):
ret,image_np = cap.read()
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
vis_util.visualize_boxes_and_labels_on_image_array(image_np,np.squeeze(boxes),np.squeeze(classes).astype(np.int32),np.squeeze(scores),category_index,use_normalized_coordinates=True,line_thickness=8)
cv2.imshow('image',cv2.resize(image_np,(1280,960)))
if cv2.waitKey(27) & 0xFF == ord('q'):
cv2.destroyAllWindows()
cap.release()
break
有人可以解释我如何只能检测到准确度分数大于80%的人。
答案 0 :(得分:6)
正如我从文档here中看到的那样,您只需要检查人员类。现在if
检查所有类。您必须仅为人员类添加
item {
name: "/m/01g317"
id: 1
display_name: "person"
}
条件。下面给出了适当的标识符(取自文档)。
$customerCollection = Mage::getModel('customer/customer')->getCollection()
->addAttributeToSelect('*')
->getSelect()->join( array('varchar'=> 'customer_entity_varchar'), 'varchar.entity_id = main_table.entity_id', array('varchar.attribute_id'))
->addAttributeToFilter('entity_id', array('eq' => '5'))
->addAttributeToFilter('attribute_id', array('eq' => '139'));
答案 1 :(得分:0)
标识符可以在数据文件夹中找到,此任务有90种不同的标识符。创建一个新的文本文件,说' new.txt'现在只需复制您需要显示的标识符,表示您需要显示人员, 复制
item {
name : "/m/01g317"
id : 1
display_name : "Person"
}
然后在最终的代码中将类的数量从90改为1
NUM_CLASSES = 1