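I'm trying to convert a frozen graph to a TFLite model using the provided tflite_convert tool. Below I reproduce exactly how I create the .pb file, to make sure I'm not messing something up along the way.
1. Train and create the SavedModel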
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.contrib import lite
# Create fake training data
xs = np.array([ -1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([ -3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)
# Create model
model = keras.models.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
# Quantization aware training
sess = keras.backend.get_session()
tf.contrib.quantize.create_training_graph(sess.graph)
sess.run(tf.global_variables_initializer())
tf.summary.FileWriter('logs/', graph=sess.graph)
# Compile model
model.compile(optimizer='sgd',
              loss='mean_squared_error')
# Train model
model.fit(xs, ys, epochs=500, batch_size=2, verbose=2)
print(model.predict([10.0]))
model.summary()
tf.saved_model.simple_save(sess,
                           './tmp',
                           inputs={'x': model.input},
                           outputs={t.name: t for t in model.outputs})
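As a sanity check on what the trained model should predict: the fake data in step 1 lies exactly on the line y = 2x - 1, so `model.predict([10.0])` should come out close to 19. The sketch below computes the closed-form ordinary least-squares fit in plain Python (no Keras or TensorFlow involved); `fit_line` is a hypothetical helper name used only for illustration.

```python
# Same fake training data as in step 1.
XS = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
YS = [-3.0, -1.0, 1.0, 3.0, 5.0, 7.0]


def fit_line(xs, ys):
    """Closed-form ordinary least-squares fit of y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Sum of squared x deviations and sum of x/y cross-deviations.
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


slope, intercept = fit_line(XS, YS)
print(slope, intercept)          # 2.0 -1.0
print(slope * 10.0 + intercept)  # 19.0, the value model.predict([10.0]) should approach
```

SGD over 500 epochs only approximates this exact solution, so a prediction near (not exactly) 19 is the expected result.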
2. Load the SavedModel and create a freezable eval graph
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.contrib import lite
export_dir = './tmp'
with tf.Session(graph=tf.Graph()) as sess:
    tf.contrib.quantize.create_eval_graph(sess.graph)
    sess.run(tf.global_variables_initializer())
    tf.saved_model.loader.load(sess, ["serve"], export_dir)
    tf.io.write_graph(sess.graph, '.', 'lin-keras-eval.pb', as_text=False)
3. Create the frozen graph using the CLI
freeze_graph \
--input_graph='lin-keras-eval.pb' \
--input_saved_model_dir='tmp/' \
--output_graph='lin-keras-frozen.pb' \
--output_node_name='dense/BiasAdd' \
--input_binary=True
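To make the failure mode concrete, here is a toy model of freezing (plain Python dicts, not the real `GraphDef` API). Freezing turns every variable node into a constant, but an `Assign` node, such as the `AssignMinLast` op that quantization-aware training adds to keep its min/max variables updated, still expects its first input to be a variable reference (`float_ref`). The node list is a hypothetical miniature: only `dense/weights_quant/min` and `dense/weights_quant/AssignMinLast` come from the error message, and `REF_INPUT_OPS` is an illustrative subset, not TensorFlow's full list.

```python
# Toy subset of ops whose first input must be a variable *reference*
# (float_ref) rather than a plain value.
REF_INPUT_OPS = {"Assign", "AssignAdd", "AssignSub"}


def freeze(graph):
    """Toy freeze: copy the graph, replacing every variable node with a constant."""
    return [dict(node, op="Const") if node["op"] == "VariableV2" else dict(node)
            for node in graph]


def ref_input_errors(graph):
    """Names of ref-consuming nodes whose first input is no longer a variable."""
    op_by_name = {node["name"]: node["op"] for node in graph}
    return [node["name"] for node in graph
            if node["op"] in REF_INPUT_OPS
            and op_by_name.get(node["inputs"][0]) != "VariableV2"]


# Hypothetical miniature of the quantization-aware graph from step 1.
graph = [
    {"name": "dense/kernel", "op": "VariableV2", "inputs": []},
    {"name": "dense/weights_quant/min", "op": "VariableV2", "inputs": []},
    {"name": "dense/weights_quant/BatchMin", "op": "Min", "inputs": ["dense/kernel"]},
    {"name": "dense/weights_quant/AssignMinLast", "op": "Assign",
     "inputs": ["dense/weights_quant/min", "dense/weights_quant/BatchMin"]},
]

print(ref_input_errors(graph))          # []
print(ref_input_errors(freeze(graph)))  # ['dense/weights_quant/AssignMinLast']
```

In this toy, the only node flagged after freezing is the same one named in the real import error, which suggests that the graph being frozen still contains training-time quantization ops.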
4. Problem: converting to TFLite format
When I try to convert the frozen graph with the following command, I get an error.
tflite_convert \
--graph_def_file=lin-keras-frozen.pb \
--output_file=lin-keras-frozen.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shape=1,1 \
--input_array=dense_input \
--output_array=dense/BiasAdd
This produces the following error (Python 3.6.8, TensorFlow 1.13):
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 426, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 0 of node dense/weights_quant/AssignMinLast was passed float from dense/weights_quant/min:0 incompatible with expected float_ref.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/bin/tflite_convert", line 10, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 442, in main
    app.run(main=run_main, argv=sys.argv[:1])
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/platform/app.py", line 125, in run
    _sys.exit(main(argv))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 438, in run_main
    _convert_model(tflite_flags)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 122, in _convert_model
    converter = _get_toco_converter(flags)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/tflite_convert.py", line 109, in _get_toco_converter
    return converter_fn(**converter_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py", line 274, in from_frozen_graph
    _import_graph_def(graph_def, name="")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/importer.py", line 430, in import_graph_def
    raise ValueError(str(e))
ValueError: Input 0 of node dense/weights_quant/AssignMinLast was passed float from dense/weights_quant/min:0 incompatible with expected float_ref.
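For context on the node names in the error: `dense/weights_quant/min` and `dense/weights_quant/AssignMinLast` appear to belong to the fake-quantization ops that `create_training_graph` inserted for the Dense layer's weights. As far as the names suggest, during training an Assign op updates the min variable (and a max counterpart) with the latest weight range, and the weights are then rounded to an 8-bit grid over that range. The sketch below shows that arithmetic in simplified pure Python. It is not TensorFlow's kernel, which among other things adjusts the range so that 0.0 is exactly representable. The helper names and weight values are hypothetical.

```python
def last_value_range(values):
    """Simplified range tracking: min/max of the most recent values."""
    return min(values), max(values)


def fake_quant(x, lo, hi, num_bits=8):
    """Clamp x to [lo, hi], snap it to one of 2**num_bits evenly spaced
    levels, and return the result as a float again.
    Simplified: no range nudging; assumes hi > lo."""
    steps = 2 ** num_bits - 1               # 255 intervals for 8 bits
    scale = (hi - lo) / steps               # width of one quantization step
    clamped = min(max(x, lo), hi)
    level = round((clamped - lo) / scale)   # integer level in [0, steps]
    return lo + level * scale


weights = [-0.73, 0.12, 1.98]  # hypothetical weight values
lo, hi = last_value_range(weights)
for w in weights:
    print(w, fake_quant(w, lo, hi))
```

Each value moves by at most half a quantization step, and anything outside the tracked range is clamped to its nearest end.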
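Is there any way to work around this? I found a few related issues on GitHub, but none of them helped me convert the model. Any help is appreciated!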