我正在尝试使用下面的代码将我的 pb 模型转换为 tflite。代码来自 GitHub(ImageCaptioning),作者用它来转换他们自己的模型。我已经成功生成了 pb 模型,但在尝试将 pb 模型转换为 tflite 时遇到了问题。
import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2
import numpy as np
def main():
    """Convert the frozen graph at GRAPH_LOCATION into a TFLite model file.

    Loads the serialized GraphDef, imports it into the session graph, looks
    up the input/output tensors by name and writes the converted model to
    TFLITE_LOCATION.

    Raises:
        ValueError: raised by ``converter.convert()`` when an input tensor
            has no shape ("Provide an input shape for input array ...").
    """
    GRAPH_LOCATION = 'C:/Users/User/Documents/models-master/research/im2txt/im2txt/data/output_graph.pb'
    TFLITE_LOCATION = "C:/Users/User/Documents/models-master/research/im2txt/im2txt/data/converted_model.tflite"

    # Tensor names carry the 'import/' prefix because the graph is imported
    # with tf.import_graph_def() without an explicit name.
    input_names = ['import/image_feed:0', 'import/input_feed:0', 'import/lstm/state_feed:0']
    output_names = ['import/softmax:0', 'import/lstm/state:0', 'import/lstm/initial_state:0']

    # Read model
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(GRAPH_LOCATION, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # The session is a context manager so it is closed even if conversion fails.
    with tf.Session() as sess:
        # as_default() only takes effect inside a `with` block; calling it
        # bare (as the original did) is a no-op.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def)

        g = sess.graph
        input_tensors = [g.get_tensor_by_name(x) for x in input_names]
        output_tensors = [g.get_tensor_by_name(x) for x in output_names]

        # NOTE(review): convert() rejects input tensors without a shape. If
        # 'import/image_feed' has none in this graph, the shapes have to be
        # supplied explicitly, e.g. through the `input_shapes` argument of
        # tf.lite.TFLiteConverter.from_frozen_graph - confirm the real
        # shapes with TensorBoard or Netron before hard-coding them.
        converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
        model = converter.convert()

    with open(TFLITE_LOCATION, "wb") as fid:
        fid.write(model)


if __name__ == '__main__':
    main()
但我收到此错误:
"'{0}'.".format(_get_tensor_name(tensor)))
ValueError: Provide an input shape for input array 'import/image_feed'.
我是 TFLite 的新手,找不到代码中的问题出在哪里。
答案 0 :(得分:0)
错误的根本原因是没有为输入数组提供 `input_shape`。您需要向转换器提供输入形状。您可以使用 TensorBoard 或 Netron 检查 *.pb 文件,以找到所需的 `input_shapes`。请参考下面的示例。
import tensorflow as tf
# Convert a frozen GraphDef (*.pb) to TFLite, then load the result into an
# interpreter as a smoke test.
graph_def_file = "./Mymodel.pb"
tflite_file = 'mytflite.tflite'
input_arrays = ["input"]
output_arrays = ["output"]

# input_shapes is keyed by input array name, so the key has to be one of the
# names in input_arrays; the original used 'input_mel', which matches none of
# them. The shape itself is model-specific - replace it with the real one.
input_shapes = {input_arrays[0]: [1, 32, 32]}

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=graph_def_file,
    input_arrays=input_arrays,
    output_arrays=output_arrays,
    input_shapes=input_shapes,
)
tflite_model = converter.convert()

# Use a context manager so the output file is flushed and closed.
with open(tflite_file, 'wb') as f:
    f.write(tflite_model)

# Loading the converted bytes checks that the model can be parsed and its
# tensors allocated.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()