Converting saved_model.pb to model.tflite

Date: 2020-08-29 23:26:55

Tags: python tensorflow tensorflow2.0 tensorflow-lite tf-lite

TensorFlow version: 2.2.0

Operating system: Windows 10

I am trying to convert a saved_model.pb to a tflite file.

This is the code I am running:

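import tensorflow as tf

# Convert the SavedModel directory (raw string keeps the backslashes literal)
converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir=r'C:\Data\TFOD\models\ssd_mobilenet_v2_quantized')
tflite_model = converter.convert()

# Write the converted model; the with-block closes the file
with open("model.tflite", "wb") as fo:
    fo.write(tflite_model)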

This code raises the following error during conversion:

  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\convert.py", line 196, in toco_convert_protos
    model_str = wrap_toco.wrapped_toco_convert(model_flags_str,
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\wrap_toco.py", line 32, in wrapped_toco_convert
    return _pywrap_toco_api.TocoConvert(
Exception: <unknown>:0: error: loc("Func/StatefulPartitionedCall/input/_0"): requires all operands and results to have compatible element types
<unknown>:0: note: loc("Func/StatefulPartitionedCall/input/_0"): see current operation: %1 = "tf.Identity"(%arg0) {device = ""} : (tensor<1x?x?x3x!tf.quint8>) -> tensor<1x?x?x3xui8>


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "c:/Data/TFOD/convert.py", line 13, in <module>
    tflite_model = converter.convert()
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\lite.py", line 1076, in convert
    return super(TFLiteConverterV2, self).convert()
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\lite.py", line 899, in convert
    return super(TFLiteFrozenGraphConverterV2,
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\lite.py", line 629, in convert
    result = _toco_convert_impl(
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\convert.py", line 569, in toco_convert_impl
    data = toco_convert_protos(
  File "C:\Users\Mr.Ace\AppData\Roaming\Python\Python38\site-packages\tensorflow\lite\python\convert.py", line 202, in toco_convert_protos
    raise ConverterError(str(e))
tensorflow.lite.python.convert.ConverterError: <unknown>:0: error: loc("Func/StatefulPartitionedCall/input/_0"): requires all operands and results to have compatible element types
<unknown>:0: note: loc("Func/StatefulPartitionedCall/input/_0"): see current operation: %1 = "tf.Identity"(%arg0) {device = ""} : (tensor<1x?x?x3x!tf.quint8>) -> tensor<1x?x?x3xui8>
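
The note line in the error carries the useful detail: the tf.Identity op at the model input receives one element type and produces another. As a reading aid, here is a minimal sketch that pulls the shapes and element types out of that line. The helper name tensor_types and the regular expression are illustrative assumptions, not part of TensorFlow.

```python
import re

# The note line from the converter error above.
NOTE = ('%1 = "tf.Identity"(%arg0) {device = ""} : '
        '(tensor<1x?x?x3x!tf.quint8>) -> tensor<1x?x?x3xui8>')

# Dimensions are integers or '?', each followed by 'x';
# whatever remains before '>' is the element type.
TENSOR_RE = re.compile(r"tensor<((?:(?:\d+|\?)x)*)([^>]+)>")


def tensor_types(text):
    """Return a list of (shape, element_type) for each tensor type in text."""
    found = []
    for match in TENSOR_RE.finditer(text):
        dims = match.group(1).rstrip("x")
        shape = dims.split("x") if dims else []
        found.append((shape, match.group(2)))
    return found


operand, produced = tensor_types(NOTE)
print("operand :", operand)
print("produced:", produced)
print("element types match:", operand[1] == produced[1])
```

The shapes agree, but the operand is !tf.quint8 while the result is ui8, which is the mismatch the converter is complaining about.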

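2 Answers:

Answer 0: (score: 1)

TensorFlow provides a Python file named export_tflite_ssd_graph.py in the object_detection folder of the models repository, which can be used to export an SSD model as a TFLite-compatible graph for conversion to the tflite format.

The file is available on GitHub and is included when you download the models directory.

Usage: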

python object_detection/export_tflite_ssd_graph.py \
    --pipeline_config_path path/to/ssd_mobilenet.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory

The expected output will be in the directory
path/to/exported_model_directory (created if it does not exist),
with the contents:

  • tflite_graph.pbtxt
  • tflite_graph.pb

For full usage, you can read the comments in the file.

Answer 1: (score: 0)

OK, I finally solved it!

What I did was use tf-nightly with the following Python script:
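
The exported tflite_graph.pb is still a frozen graph, so one more step is needed to get a .tflite file. Below is a minimal sketch of that step using the TF1-style converter. The tensor names are the ones commonly documented for graphs produced by export_tflite_ssd_graph.py; the 300x300 input size, the (128, 128) input statistics, and the helper names are assumptions that should be checked against your pipeline config.

```python
# Tensor names as commonly documented for export_tflite_ssd_graph.py output.
INPUT_NAME = "normalized_input_image_tensor"
OUTPUT_NAMES = [
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
]


def converter_args(graph_def_file, height=300, width=300):
    """Collect the arguments for TFLiteConverter.from_frozen_graph.

    The default 300x300 input size is an assumption; use the size from
    your pipeline config.
    """
    return {
        "graph_def_file": graph_def_file,
        "input_arrays": [INPUT_NAME],
        "output_arrays": list(OUTPUT_NAMES),
        "input_shapes": {INPUT_NAME: [1, height, width, 3]},
    }


def convert_frozen_graph(graph_def_file, output_path, quantized=True):
    """Convert tflite_graph.pb to a .tflite file (hypothetical helper)."""
    import tensorflow as tf  # imported lazily; needs TensorFlow installed

    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        **converter_args(graph_def_file))
    # The exported graph ends in a custom post-processing op.
    converter.allow_custom_ops = True
    if quantized:
        # Assumed settings for a quantization-aware-trained uint8 model.
        converter.inference_type = tf.uint8
        converter.quantized_input_stats = {INPUT_NAME: (128.0, 128.0)}
    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)


# Example call (paths are placeholders):
# convert_frozen_graph("path/to/exported_model_directory/tflite_graph.pb",
#                      "path/to/exported_model_directory/model.tflite")
```

TensorFlow is imported inside the function so that the argument helper can be inspected without it installed.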

import tensorflow as tf

saved_model_dir = "C:/Data/TFOD/models/ssd_mobilenet_v2_quantized/tflite"
converter = tf.lite.TFLiteConverter.from_saved_model(
    saved_model_dir, signature_keys=['serving_default'])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

# Write the converted model; the with-block closes the file
with open(
        "C:/Data/TFOD/models/ssd_mobilenet_v2_quantized/tflite/model.tflite",
        "wb") as fo:
    fo.write(tflite_model)

This solves the problem, and the model converts to .tflite.
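
To confirm the result, it can help to write the bytes out safely and then load the file back with the TFLite interpreter. This is a minimal sketch; save_tflite and describe_tflite are hypothetical helper names, and loading may fail if the installed interpreter lacks support for the ops selected through SELECT_TF_OPS.

```python
from pathlib import Path


def save_tflite(model_bytes, path):
    """Write the converted model to disk and return the number of bytes written."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(model_bytes)  # opens, writes, and closes the file
    return path.stat().st_size


def describe_tflite(path):
    """Return the input and output tensor details of a .tflite file."""
    import tensorflow as tf  # imported lazily; needs TensorFlow installed

    interpreter = tf.lite.Interpreter(model_path=str(path))
    return interpreter.get_input_details(), interpreter.get_output_details()


# Example usage after converter.convert() (path is a placeholder):
# size = save_tflite(tflite_model, "model.tflite")
# inputs, outputs = describe_tflite("model.tflite")
# print(size, inputs[0]["shape"], inputs[0]["dtype"])
```

Checking the reported input dtype and shape is a quick way to see whether the converted model expects uint8 or float input.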