Converting a PB to tflite and getting a ValueError

Time: 2020-04-19 14:27:28

Tags: python tensor tf-lite

I am trying to convert my pb to tflite using this code. I got the code from GitHub (ImageCaptioning). The author used this code to convert their model. I was able to create the pb model, but I ran into some problems when trying to convert the pb model to tflite.

import tensorflow as tf
from tensorflow.python.platform import gfile
import cv2
import numpy as np


def main():
    sess = tf.Session()
    GRAPH_LOCATION = 'C:/Users/User/Documents/models-master/research/im2txt/im2txt/data/output_graph.pb'
    VOCAB_FILE = 'C:/Users/User/Documents/models-master/research/im2txt/Pretrained-Show-and-Tell-model-master/word_counts.txt'
    IMAGE_FILE = 'C:/Users/User/Documents/models-master/research/im2txt/g3doc/COCO_val2014_000000224477.jpg'

    # Read model
    with gfile.FastGFile(GRAPH_LOCATION, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def)

    with tf.gfile.GFile(IMAGE_FILE, "rb") as f:
        encoded_image = f.read()

    input_names = ['import/image_feed:0', 'import/input_feed:0', 'import/lstm/state_feed:0']
    output_names = ['import/softmax:0', 'import/lstm/state:0', 'import/lstm/initial_state:0']

    g = tf.get_default_graph()
    input_tensors = [g.get_tensor_by_name(x) for x in input_names]  
    output_tensors = [g.get_tensor_by_name(x) for x in output_names]


    converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
    model = converter.convert()
    fid = open("C:/Users/User/Documents/models-master/research/im2txt/im2txt/data/converted_model.tflite", "wb")
    fid.write(model)
    fid.close()


if __name__ == '__main__':
    main()

But I get this error:

   "'{0}'.".format(_get_tensor_name(tensor)))
ValueError: Provide an input shape for input array 'import/image_feed'.

I am new to TFLite and I cannot find what is wrong with the code.

1 Answer:

Answer 0 (score: 0)

The root cause of the error is the missing input_shape for the input array. You need to provide the input shapes to the converter. You can inspect the *.pb file with TensorBoard or Netron to find the input_shapes. See the example below.

import tensorflow as tf
graph_def_file = "./Mymodel.pb"
tflite_file = 'mytflite.tflite'

input_arrays = ["input"]
output_arrays = ["output"]

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=graph_def_file,
    input_arrays=input_arrays,
    output_arrays=output_arrays,
    input_shapes={'input': [1, 32, 32]})  # keys must match the names in input_arrays

tflite_model = converter.convert()

open(tflite_file,'wb').write(tflite_model)

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
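
For the question's graph, a quick way to find the shapes to pass as input_shapes is to list the Placeholder nodes of the frozen graph directly. Below is a minimal sketch using the path from the question; it only prints the declared shapes, and any dimension reported as unknown (or -1) is one you have to supply yourself:

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_LOCATION = 'C:/Users/User/Documents/models-master/research/im2txt/im2txt/data/output_graph.pb'

# Load the frozen GraphDef and print every Placeholder with its declared shape.
graph_def = tf.GraphDef()
with gfile.FastGFile(GRAPH_LOCATION, 'rb') as f:
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    if node.op == 'Placeholder':
        # A missing or partially unknown shape here is exactly what the
        # converter complains about; those inputs need an explicit entry
        # in input_shapes (or a set_shape() call before conversion).
        shape = node.attr['shape'].shape if 'shape' in node.attr else 'unknown'
        print(node.name, shape)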