如何从save_model创建tflite文件(SSD MobileNet)

时间:2019-05-15 13:24:37

标签: python tensorflow object-detection tensorflow-lite

我想基于经过重新训练的ssd_mobilenet模型(类似于youtube上的那个家伙)创建对象检测应用。

我从Tensorflow Model Zoo中选择了模型ssd_mobilenet_v2_coco。在重新训练过程之后,我得到了具有以下结构的模型:

- saved_model
    - variables (empty folder)
    - saved_model.pb
- checkpoint
- frozen_inverence_graph.pb
- model.ckpt.data-00000-of-00001
- model.ckpt.index
- model.ckpt.meta
- pipeline.config

在同一文件夹中,我有带有以下代码的python脚本:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

运行此代码后,出现以下错误:

ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.

似乎模型中缺少图像宽度和高度。当我使用youtube视频中的模型时,它就可以正常工作。

经过大量研究和尝试,我尝试了其他方法,例如运行bazel / toco,但没有任何帮助我创建tflite文件。

1 个答案:

答案 0 :(得分:0)

documentation中所述,您可以在tf.lite.TFLiteConverter.from_saved_model中传递不同的参数。

  

对于更复杂的SavedModels,可以传递到TFLiteConverter.from_saved_model()的可选参数是input_arrays, input_shapes, output_arrays, tag_set and signature_key。通过运行help(tf.lite.TFLiteConverter),可以获得每个参数的详细信息。

您可以按照here的说明传递此信息。您需要提供输入张量名称及其形状,还需要输出张量名称及其形状。对于ssd_mobilenet_v2_coco,您需要定义使用网络所需的输入形状,

tf.lite.TFLiteConverter.from_saved_model("saved_model", input_shapes={("image_tensor" : [1,300,300,3])})