从tensorflow冻结图转换为tflite以进行Android推理

时间:2020-02-24 01:13:49

标签: tensorflow machine-learning computer-vision onnx tf-lite

我正在尝试从pytorch转换为tflite,以便针对我正在使用的应用程序进行android推断,该应用程序使用篮球的实时摄像头数据来创建已拍摄和未拍摄照片的热图。它已经适用于iOS。 Here's a demo

我设法从pytorch(.pth)转换为onnx,从onnx转换为tensorflow冻结图(.pb)。该tf冻结图推断出结果。

但是,当我尝试从冻结的图形转换为tflite时,出现以下错误:

RuntimeError: Inputs and outputs not all float|uint8|int16 types.Node number 2 (ADD) failed to invoke.

解释器的输入详细信息[interpreter.get_input_details(),interpreter.get_output_details()]建议数据类型为numpy.float32,这是我感到困惑的地方。那不应该算作浮动吗?任何建议/帮助将不胜感激!

[{'name': 'image', 'index': 21904, 'shape': array([  3, width, height], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'action', 'index': 7204, 'shape': array([], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]

来自tensorflow的文档:

import numpy as np
import tensorflow as tf

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

需要帮助的人的tensorflow冻结图(1.14.0)推断文件:

import numpy as np
import tensorflow as tf
from PIL import Image

w = ...
h = ...

class CNN(object):

    def __init__(self, model_filepath):
        self.model_filepath = model_filepath
        self.load_graph(model_filepath = self.model_filepath)

    def load_graph(self, model_filepath):
        self.graph = tf.Graph()

        with tf.gfile.GFile(model_filepath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        with self.graph.as_default():
            self.input = tf.placeholder(np.float32, shape = [3, h, w], name='image')
            tf.import_graph_def(graph_def, {'image': self.input})

        self.graph.finalize()
        self.sess = tf.Session(graph = self.graph)

    def test(self, data):
        output_tensor = self.graph.get_tensor_by_name('import/action:0')
        output = self.sess.run(output_tensor, feed_dict = {self.input: data})
        return output

def main():
    nn = CNN(model_filepath='out_1.14.pb')
    img = np.asarray(Image.open('example.jpg')).astype(np.float32)
    img = img.transpose(-1, 0, 1)
    ans = nn.test(data=img)
    print(ans)


if __name__ == '__main__':
    main()

0 个答案:

没有答案