Converting a saved_model to a TFLite model with TF 2.0

Time: 2020-01-10 10:15:02

Tags: python tensorflow tensorflow2.0 tensorflow-lite

Currently, I am converting a custom object detection model (trained with SSD and an Inception network) into a quantized TFLite model. I can convert the model from a frozen graph into a quantized TFLite model with the following snippet (using Tensorflow 1.4):

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    args["model"],
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]},
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
                   'TFLite_Detection_PostProcess:1',
                   'TFLite_Detection_PostProcess:2',
                   'TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops = True
converter.post_training_quantize = True
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)
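As an aside, post_training_quantize is the legacy TF1 flag; in current releases the equivalent is the Optimize.DEFAULT optimization used later in this post. A minimal sketch, with a throwaway Keras model standing in for the detection graph (so the snippet runs on its own):

```python
import tensorflow as tf

# Hypothetical stand-in for the detection graph: a tiny Keras model.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(300, 300, 3)),
    tf.keras.layers.Conv2D(2, 3),
])

# Modern replacement for post_training_quantize=True:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # post-training weight quantization
tflite_model = converter.convert()  # bytes of the quantized flatbuffer
```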

However, Tensorflow 2.0 (refer this link) no longer provides the tf.lite.TFLiteConverter.from_frozen_graph class method. So I tried to convert the model with the tf.lite.TFLiteConverter.from_saved_model class method instead. The snippet is shown below:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/") # Path to saved_model directory
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

The snippet above raises the following error:

ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.

I tried passing input_shapes as an argument:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor" : [1,300,300,3]})

but it raises the following error:

TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'
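(For context: the TF 2.x saved-model converter indeed has no input_shapes argument. A commonly suggested workaround — a sketch only, not verified against this exact detection model — is to load the SavedModel, pin the shape on its serving signature's input tensor, and convert from the concrete function. The tiny Keras model with a fully dynamic input below stands in for the detection export, and the "serving_default" signature key is an assumption.)

```python
import tensorflow as tf

# Stand-in for a SavedModel whose input is [None, None, None, 3].
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, None, 3)),
    tf.keras.layers.Conv2D(2, 3, padding="same"),
])
tf.saved_model.save(model, "/tmp/dynamic_saved_model")

# Load it back and pin every dimension of the signature's input tensor.
loaded = tf.saved_model.load("/tmp/dynamic_saved_model")
concrete_func = loaded.signatures["serving_default"]  # assumed signature key
concrete_func.inputs[0].set_shape([1, 300, 300, 3])

# Convert from the concrete function with the now-static input shape.
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
```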

Am I missing something? Please feel free to correct me!

1 Answer:

Answer 0 (score: 3)

I got the solution using tf.compat.v1.lite.TFLiteConverter.from_frozen_graph. The compat.v1 module brings TF 1.x functionality into TF 2.x. The full code is below:

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "/content/tflite_graph.pb",
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]},
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess',
                   'TFLite_Detection_PostProcess:1',
                   'TFLite_Detection_PostProcess:2',
                   'TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops = True

# Convert the model to a quantized TFLite model.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Write the model to disk.
open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)
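A quick sanity check after conversion is to load the flatbuffer with the TFLite interpreter and confirm the input shape the converter baked in. A tiny Keras model is converted inline here so the snippet is self-contained; with the answer's model you would pass model_path="/content/uno_mobilenetV2.tflite" instead of model_content.

```python
import tensorflow as tf

# Hypothetical stand-in model, converted inline for a self-contained check.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(300, 300, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Load the converted model and inspect its input tensor.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_shape = interpreter.get_input_details()[0]["shape"]
print(list(input_shape))
```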