如何在Java中从Tensorflow模型读取输出

时间:2019-02-15 19:49:48

标签: tensorflow tensorflow-lite

我尝试将TensorflowLite与ssdlite_mobilenet_v2_coco模型(从https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md转换为tflite文件)一起使用,以从我的Android应用程序(java)中的摄像头流中检测对象。我执行

    interpreter.run(input, output);

其中输入是转换为ByteBuffer的图像,输出是浮点数组-大小[1] [10] [4]以匹配张量。

如何将此float数组转换为一些可读的输出? -例如获取边界框的坐标,对象名称,概率。

1 个答案:

答案 0 :(得分:0)

好吧,我知道了。 首先,我在python中运行以下命令:

>>> import tensorflow as tf
>>> interpreter = tf.contrib.lite.Interpreter("detect.tflite")

然后加载Tflite模型:

>>> interpreter.allocate_tensors()
>>> input_details = interpreter.get_input_details()
>>> output_details = interpreter.get_output_details()

现在,我详细了解了输入和输出的外观如何

>>> input_details
[{'name': 'normalized_input_image_tensor', 'index': 308, 'shape': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

因此输入是转换后的图像-形状为300 x 300

>>> output_details
[{'name': 'TFLite_Detection_PostProcess', 'index': 300, 'shape': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 301, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 302, 'shape': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 303, 'shape': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

现在我已经有了该模型中多个输出的说明。 我需要改变

interpreter.run(input, output) 

interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);

其中“输入”为:

private Object[1] inputs;
inputs[0] = imgData; //imgData - image converted to bytebuffer 

map_of_indices_to_outputs为:

private Map<Integer, Object> output_map = new TreeMap<>();
private float[1][10][4] boxes;
private float[1][10] scores;
private float[1][10] classes;
output_map.put(0, boxes);
output_map.put(1, classes);
output_map.put(2, scores);

现在运行后,我在框中有10个对象的坐标,在类中对象的索引(在coco标签文件中),必须加1才能获得正确的键!和分数的概率。

希望这对以后的人有帮助。