How to get useful data from TFLite object detection in Python

Date: 2019-12-02 17:08:51

Tags: python tensorflow raspberry-pi object-detection-api tf-lite

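I have a Raspberry Pi 4 and I want to run object detection at a good frame rate. I tried TensorFlow and YOLO, but both run at about 1 fps, so I am trying TensorFlow Lite instead. I have downloaded the tflite file and the labelmap.txt file, and I used this link to try to run inference. This is where I ran into a problem: I don't understand how to get the results (class, bounding box coordinates, and confidence) from the output.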

Here is my code:

import tensorflow as tf 
import numpy as np
import cv2

interpreter = tf.lite.Interpreter(model_path="/content/drive/My Drive/detect.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)
print(output_details)
print()

input_shape = input_details[0]['shape']
im = cv2.imread("/content/drive/My Drive/doggy.jpg")
im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
# cv2.resize expects (width, height); the input shape is [1, height, width, 3]
im_rgb = cv2.resize(im_rgb, (input_shape[2], input_shape[1]))
input_data = np.expand_dims(im_rgb, axis=0)
print(input_data.shape)
print()

interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
print()
print(output_data)

Here is my output:

[{'name': 'normalized_input_image_tensor', 'index': 175, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128)}]
[{'name': 'TFLite_Detection_PostProcess', 'index': 167, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 168, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 169, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 170, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

(1, 300, 300, 3)

(1, 10, 4)

[[[ 1.66415479e-02  5.48024022e-04  8.67791831e-01  3.35325867e-01]
  [ 7.41335377e-02  3.22245747e-01  9.64617252e-01  9.71388936e-01]
  [-2.11861148e-03  5.41743517e-01  2.60241032e-01  7.02846169e-01]
  [-5.67546487e-03  3.26282382e-01  8.59034657e-01  6.30770981e-01]
  [ 7.27111334e-03  7.90268779e-01  2.86753297e-01  9.56545353e-01]
  [ 2.07318692e-03  7.96441555e-01  5.48386931e-01  9.96111989e-01]
  [-1.04907183e-02  2.38761827e-01  6.75976276e-01  7.01156497e-01]
  [ 3.12007014e-02  1.34294275e-02  5.82291842e-01  3.10949832e-01]
  [-1.95578858e-03  7.05318868e-01  9.18281525e-02  7.96184599e-01]
  [-5.43205580e-03  3.23292404e-01  6.34427786e-01  5.68508685e-01]]]

The output (the last list) appears to be an array of very small numbers. How do I get the results from it?

Thanks

1 Answer:

Answer 0: (score: 0)

I solved this with help from @daverim on GitHub, where I opened an issue: https://github.com/tensorflow/tensorflow/issues/34761. Here is the code that gets the useful data:

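# For this model the four output tensors are: boxes, class ids, scores, detection count.
detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
print(num_boxes)
for i in range(int(num_boxes[0])):
    # Keep only detections with confidence above 0.5.
    if detection_scores[0, i] > .5:
        # Class ids come back as floats, so cast before using them as an index.
        class_id = int(detection_classes[0, i])
        print(class_id)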
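
The small numbers in the question's output are the bounding boxes themselves. For the SSD post-processing op used here, each row is, as far as I know, [ymin, xmin, ymax, xmax] normalized to the 0 to 1 range, so they need to be scaled by the image size to get pixel coordinates. Below is a minimal sketch of that conversion. The helper name `boxes_to_pixels`, the `min_score` parameter, and the example score are my own assumptions for illustration, not part of TFLite.

```python
import numpy as np

def boxes_to_pixels(boxes, scores, image_height, image_width, min_score=0.5):
    """Convert normalized [ymin, xmin, ymax, xmax] boxes to pixel (x1, y1, x2, y2, score).

    Hypothetical helper: assumes `boxes` has shape (N, 4) and `scores` has shape (N,).
    """
    results = []
    for box, score in zip(boxes, scores):
        if score < min_score:
            continue
        # Clip to [0, 1] because the model can emit slightly negative values.
        ymin, xmin, ymax, xmax = np.clip(box, 0.0, 1.0)
        results.append((
            int(xmin * image_width), int(ymin * image_height),
            int(xmax * image_width), int(ymax * image_height),
            float(score),
        ))
    return results

# First box from the question's output, paired with a made-up score of 0.9.
example_boxes = np.array([[0.0166, 0.0005, 0.8678, 0.3353]], dtype=np.float32)
example_scores = np.array([0.9], dtype=np.float32)
print(boxes_to_pixels(example_boxes, example_scores, image_height=480, image_width=640))
```

With the tensors from the code above, the call would look like `boxes_to_pixels(detection_boxes[0], detection_scores[0], im.shape[0], im.shape[1])`, using the original image size rather than the 300x300 model input, so the boxes can be drawn directly on `im`.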

Using the labelmap.txt file, we can then map each class id to its class name.
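
As a rough sketch of that lookup, the snippet below assumes labelmap.txt holds one class name per line and that the line index matches the class id returned by the model. Both are assumptions about the file, and some label maps start with a placeholder entry, so the hypothetical `offset` parameter is there in case ids and lines are shifted by one. The function names and the three example labels are mine.

```python
def load_labels(path):
    """Read one class name per line; the line index is used as the class id."""
    with open(path, "r", encoding="utf-8") as f:
        # Blank lines are kept (as empty strings) so indices stay aligned.
        return [line.strip() for line in f]

def class_name(labels, class_id, offset=0):
    """Map a (possibly float) class id to a name, or 'unknown' if out of range."""
    index = int(class_id) + offset
    return labels[index] if 0 <= index < len(labels) else "unknown"

# In-memory example with made-up labels; in practice use load_labels("labelmap.txt").
example_labels = ["person", "bicycle", "car"]
print(class_name(example_labels, 2.0))
print(class_name(example_labels, 7.0))
```

Inside the detection loop this would be `class_name(labels, detection_classes[0, i])`. If the printed names look off by one against a known image, try `offset=1` or `offset=-1` and check the first line of the label file.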