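I have a weights file in .pb format, which I converted to TFLite format for deployment on mobile devices. My model has two inputs: one for an image (dimensions: 1 * 3 * 36 * 60) and one for a vector (dimensions: 1 * 2). To validate the TFLite model, I use the following code: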
import numpy as np
import tensorflow as tf
from keras.preprocessing import image
img = image.load_img('f01_70_-0.5890_-0.3927.png', target_size=(36, 60))
head_angle = np.array([[-0.3619980211517256, -0.44335020008101705]])
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
# Load TFLite model and allocate tensors.
interpreter = tf.contrib.lite.Interpreter(model_path="pupilModel.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
#print(input_details)
interpreter.set_tensor(input_details[0]['index'], x)
interpreter.set_tensor(input_details[1]['index'], head_angle)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
The output log is as follows:
[{'name': 'inputs/input_img', 'index': 22, 'shape': array([ 1, 36, 60, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)},
{'name': 'inputs/head_angle', 'index': 21, 'shape': array([1, 2], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
  File "/Users/jackie/Downloads/pupil_model/tf_to_lite.py", line 22, in <module>
    interpreter.set_tensor(input_details[1]['index'], head_angle)
  File "/Users/jackie/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/lite/python/interpreter.py", line 156, in set_tensor
    self._interpreter.SetTensor(tensor_index, value)
  File "/Users/jackie/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/lite/python/interpreter_wrapper/tensorflow_wrap_interpreter_wrapper.py", line 133, in SetTensor
    return _tensorflow_wrap_interpreter_wrapper.InterpreterWrapper_SetTensor(self, i, value)
ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 21
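The numbers in this error are values of TFLite's TfLiteType enum: 0 is kTfLiteNoType and 1 is kTfLiteFloat32. My reading (an assumption about the TFLite Python wrapper of that era) is that it had no TFLite type for a NumPy float64 array and therefore reported 0. The sketch below decodes the message and shows that head_angle, built from plain Python floats, is float64. TFLITE_TYPE_NAMES, type_name and explain_set_tensor_error are hypothetical helpers, and the map covers only the first few enum codes.

```python
import re

import numpy as np

# Partial map of TFLite's TfLiteType enum codes; only the first few are listed.
TFLITE_TYPE_NAMES = {
    0: "kTfLiteNoType",
    1: "kTfLiteFloat32",
    2: "kTfLiteInt32",
    3: "kTfLiteUInt8",
    4: "kTfLiteInt64",
}


def type_name(code):
    """Return the enum name for a TfLiteType code, or mark it as unknown."""
    return TFLITE_TYPE_NAMES.get(code, f"unknown ({code})")


def explain_set_tensor_error(message):
    """Replace the numeric type codes in a set_tensor ValueError with enum names."""
    match = re.search(r"type (\d+) but expected type (\d+) for input (\d+)", message)
    if match is None:
        return None
    got, expected, index = (int(g) for g in match.groups())
    return f"input {index}: got {type_name(got)}, expected {type_name(expected)}"


message = "Cannot set tensor: Got tensor of type 0 but expected type 1 for input 21"
print(explain_set_tensor_error(message))
# input 21: got kTfLiteNoType, expected kTfLiteFloat32

# Input 21 is inputs/head_angle. NumPy stores Python floats as float64 by
# default, not the float32 that the model declares for this input.
head_angle = np.array([[-0.3619980211517256, -0.44335020008101705]])
print(head_angle.dtype)  # float64
```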
My question: how do I validate a TFLite model that has two inputs?
Answer:
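The head_angle np.array does not have the float32 type that TFLite expects. A NumPy array built from Python floats defaults to float64, whereas the printed input details list inputs/head_angle as numpy.float32, and type 1 in the error message is TFLite's float32 code. (The image input works because keras img_to_array returns float32 by default.)
Try the following change:
head_angle = np.array([[-0.3619980211517256, -0.44335020008101705]], dtype=np.float32)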
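More generally, instead of hard-coding a dtype for each input, you can read each input's dtype and shape from get_input_details() and cast every array to match before calling set_tensor. This is a minimal sketch that assumes inputs are supplied in a dict keyed by tensor name. prepare_inputs and example_details are hypothetical names, and example_details reproduces the details printed in the question without the quantization field.

```python
import numpy as np


def prepare_inputs(input_details, feeds):
    """Cast each fed array to the dtype TFLite reports and check its shape.

    input_details: the list returned by interpreter.get_input_details().
    feeds: a dict mapping input tensor names to array-like values.
    Returns (tensor_index, array) pairs for interpreter.set_tensor().
    """
    prepared = []
    for detail in input_details:
        name = detail["name"]
        if name not in feeds:
            raise KeyError(f"no value supplied for input {name!r}")
        # Cast explicitly: NumPy would otherwise keep Python floats as float64.
        array = np.asarray(feeds[name], dtype=detail["dtype"])
        expected_shape = tuple(int(d) for d in detail["shape"])
        if array.shape != expected_shape:
            raise ValueError(f"{name}: got shape {array.shape}, expected {expected_shape}")
        prepared.append((detail["index"], array))
    return prepared


# Input details as printed in the question (quantization field omitted).
example_details = [
    {"name": "inputs/input_img", "index": 22,
     "shape": np.array([1, 36, 60, 3], dtype=np.int32), "dtype": np.float32},
    {"name": "inputs/head_angle", "index": 21,
     "shape": np.array([1, 2], dtype=np.int32), "dtype": np.float32},
]

pairs = prepare_inputs(example_details, {
    "inputs/input_img": np.zeros((1, 36, 60, 3)),  # stand-in for the image array x
    "inputs/head_angle": [[-0.3619980211517256, -0.44335020008101705]],
})
for index, array in pairs:
    print(index, array.dtype, array.shape)
# 22 float32 (1, 36, 60, 3)
# 21 float32 (1, 2)
```

With a real interpreter, pass interpreter.get_input_details() as the first argument. Then call interpreter.set_tensor(index, array) for each returned pair, followed by interpreter.invoke() and interpreter.get_tensor(output_details[0]['index']) as in the question. Keying by name avoids depending on the order of get_input_details(), which here lists the image (index 22) before head_angle (index 21). The blind cast suits this float32 model. For a quantized uint8 input, however, a silent cast from float would hide a missing quantization step.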