I trained a model and am now trying to export it to TFLite, but I get a strange error:
import tensorflow as tf
graph_def_file = "./model.pb"
input_arrays = ["question1_embedding", "question2_embedding", "is_training"]
output_arrays = ["prediction"]
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)
I get the following error:
Traceback (most recent call last):
  File "export_tflite.py", line 13, in <module>
    tflite_model = converter.convert()
  File "/home/ben/.pyenv/versions/clevy/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py", line 411, in convert
    self._set_batch_size(batch_size=1)
  File "/home/ben/.pyenv/versions/clevy/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py", line 501, in _set_batch_size
    shape = tensor.get_shape().as_list()
  File "/home/ben/.pyenv/versions/clevy/lib/python3.6/site-packages/tensorflow/python/framework/tensor_shape.py", line 904, in as_list
    raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.
I'm not sure where the error comes from. The model works for inference, and I can restore it from the same model.pb file.
I am using TensorFlow v1.12.
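
Since the traceback fails inside `tensor.get_shape().as_list()`, one way to narrow it down is to check which of the input tensors has no static shape at all. The following is only a diagnostic sketch, not a confirmed fix. The helper names `load_graph_def` and `input_ranks` are hypothetical, the `tf1` alias is an assumption added so the snippet is not tied to a single release, and it assumes each input is output 0 of the op with that name.

```python
import tensorflow as tf

# Assumption: the TF 1.x graph-mode API. On releases that ship tf.compat.v1
# this picks it up; on v1.12 it falls back to the top-level tf module.
tf1 = getattr(tf.compat, "v1", tf)


def load_graph_def(path):
    """Read a frozen GraphDef (such as ./model.pb) from disk."""
    graph_def = tf1.GraphDef()
    with tf1.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def input_ranks(graph_def, input_names):
    """Map each input name to its static rank, or None if the rank is unknown."""
    ranks = {}
    with tf1.Graph().as_default() as graph:
        tf1.import_graph_def(graph_def, name="")
        for name in input_names:
            # Assumes each input is output 0 of the op with that name.
            tensor = graph.get_tensor_by_name(name + ":0")
            # ndims is None exactly when the TensorShape is fully unknown,
            # which is the case where as_list() raises ValueError.
            ranks[name] = tensor.shape.ndims
    return ranks
```

Called as `input_ranks(load_graph_def("./model.pb"), input_arrays)`, any entry that comes back as `None` is an input whose shape is fully unknown. `from_frozen_graph` also accepts an optional `input_shapes` dict (input name to list of ints), which may be worth trying for such an input; whether that resolves this particular model is untested.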