Unable to convert a frozen graph to a TFLite model

Date: 2019-07-10 09:54:12

Tags: python python-3.x tensorflow tensorflow-lite

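I'm trying to convert a frozen graph to a TFLite model using the provided TFLite converter (tflite_convert). Below I retrace how I created the .pb file, to make sure I'm not messing something up along the way.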

1. Train the model and export a SavedModel

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.contrib import lite

# Create fake training data
xs = np.array([ -1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([ -3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)

# Create model
model = keras.models.Sequential([keras.layers.Dense(units=1, input_shape=[1])])

# Quantization aware training
sess = keras.backend.get_session()
tf.contrib.quantize.create_training_graph(sess.graph)
sess.run(tf.global_variables_initializer())

tf.summary.FileWriter('logs/', graph=sess.graph)
# Compile model
model.compile(optimizer='sgd', 
              loss='mean_squared_error')

# Train model
model.fit(xs, ys, epochs=500, batch_size=2, verbose=2)
print(model.predict([10.0]))
model.summary()

tf.saved_model.simple_save(sess, 
                            './tmp', 
                            inputs={'x': model.input},
                            outputs={t.name: t for t in model.outputs})
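The six training points lie exactly on y = 2x - 1, so a model trained on them should print a value near 19 for model.predict([10.0]) (quantization-aware training may shift it slightly). As a reference value, here is a minimal pure-Python sketch, not part of the original code, that computes the exact least-squares line for the same data:

```python
# Same training data as in step 1 (it lies exactly on y = 2x - 1).
xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0]


def fit_line(xs, ys):
    """Closed-form ordinary least squares for y = w * x + b."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    w = cov / var
    b = mean_y - w * mean_x
    return w, b


w, b = fit_line(xs, ys)
# Weight, bias, and the prediction for x = 10.0.
print(w, b, w * 10.0 + b)
```

A single Dense(units=1) layer computes exactly w * x + b, so its learned kernel and bias should approach these two numbers.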

2. Load the SavedModel and create a freezable eval graph

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.contrib import lite

export_dir = './tmp'

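with tf.Session(graph=tf.Graph()) as sess:

    # Rewrite the session's graph for quantized evaluation
    tf.contrib.quantize.create_eval_graph(sess.graph)
    sess.run(tf.global_variables_initializer())

    # Load the SavedModel exported in step 1
    tf.saved_model.loader.load(sess, ["serve"], export_dir)

    # Write the graph as a binary GraphDef for freeze_graph
    tf.io.write_graph(sess.graph, '.', 'lin-keras-eval.pb', as_text=False)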

3. Create the frozen graph with the CLI

freeze_graph \
--input_graph='lin-keras-eval.pb' \
--input_saved_model_dir='tmp/' \
--output_graph='lin-keras-frozen.pb' \
--output_node_names='dense/BiasAdd' \
--input_binary=True
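Roughly speaking, freezing replaces each variable with a Const node that holds the variable's current value, while the other nodes that the requested outputs depend on are kept as they are. The toy sketch below models only that substitution, using plain Python dicts as a stand-in for TensorFlow's GraphDef (the batch_min node is made up for illustration). It shows how an Assign node can end up reading from a Const instead of a variable reference, which is the situation the error in step 4 describes:

```python
# Toy model of what freezing does. Nodes are plain dicts, NOT real
# TensorFlow NodeDef protos; op names follow TensorFlow's conventions.
VARIABLE_OPS = {"Variable", "VariableV2"}


def freeze(nodes, values):
    """Return a copy of `nodes` with every variable replaced by a Const."""
    frozen = []
    for node in nodes:
        if node["op"] in VARIABLE_OPS:
            # The variable's current value is baked in as a constant.
            frozen.append({"name": node["name"], "op": "Const",
                           "inputs": [], "value": values[node["name"]]})
        else:
            # Other nodes, including Assign, are copied unchanged, so they
            # keep pointing at the same input names.
            frozen.append(dict(node, inputs=list(node["inputs"])))
    return frozen


# Hypothetical graph: names mirror the error message in step 4,
# "batch_min" is invented for illustration.
graph = [
    {"name": "dense/weights_quant/min", "op": "VariableV2", "inputs": []},
    {"name": "batch_min", "op": "Min", "inputs": []},
    {"name": "dense/weights_quant/AssignMinLast", "op": "Assign",
     "inputs": ["dense/weights_quant/min", "batch_min"]},
]

frozen = freeze(graph, {"dense/weights_quant/min": -1.0})
for node in frozen:
    print(node["name"], node["op"], node["inputs"])
```

After the substitution, input 0 of the Assign node names a Const, whereas a real Assign op requires a reference-typed (float_ref) input there.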

4. Problem: converting to the TFLite format

When I try to convert the frozen graph with the following command, I get an error:

tflite_convert \
--graph_def_file=lin-keras-frozen.pb \
--output_file=lin-keras-frozen.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shapes=1,1 \
--input_arrays=dense_input \
--output_arrays=dense/BiasAdd
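A multi-line shell command breaks silently if any line-continuation backslash is followed by a trailing space. When scripting the conversion, one way to sidestep shell quoting entirely is to build the argument list in Python. The sketch below is a hypothetical helper; it uses the flag names that tflite_convert documents (--input_arrays, --output_arrays, --input_shapes) and the file and tensor names from this question:

```python
def tflite_convert_args(graph_def_file, output_file, input_array,
                        output_array, input_shape):
    """Build a tflite_convert argument list (hypothetical helper)."""
    return [
        "tflite_convert",
        "--graph_def_file=" + graph_def_file,
        "--output_file=" + output_file,
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        # Shape as a comma-separated list, e.g. (1, 1) -> "1,1".
        "--input_shapes=" + ",".join(str(d) for d in input_shape),
        "--input_arrays=" + input_array,
        "--output_arrays=" + output_array,
    ]


args = tflite_convert_args("lin-keras-frozen.pb", "lin-keras-frozen.tflite",
                           "dense_input", "dense/BiasAdd", (1, 1))
print(" ".join(args))
```

Assuming tflite_convert is on PATH, the list can then be passed unchanged to subprocess.run(args, check=True); each element reaches the tool as exactly one argument, with no backslashes or quoting involved.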

This produces the following error (Python 3.6.8, TensorFlow 1.13):

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 426, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node dense/weights_quant/AssignMinLast was passed float from dense/weights_quant/min:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/tflite_convert", line 10, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 442, in main
    app.run(main=run_main, argv=sys.argv[:1])
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 438, in run_main
    _convert_model(tflite_flags)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 122, in _convert_model
    converter = _get_toco_converter(flags)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 109, in _get_toco_converter
    return converter_fn(**converter_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py", line 274, in from_frozen_graph
    _import_graph_def(graph_def, name="")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 430, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node dense/weights_quant/AssignMinLast was passed float from dense/weights_quant/min:0 incompatible with expected float_ref.

Is there a way to fix this? I found some related issues on GitHub, but none of them let me convert the model. Any help is appreciated!

0 answers