我正在使用TensorFlow估计器来训练和保存模型,然后将其转换为.tflite。我将模型保存如下:
feat_cols = [tf.feature_column.numeric_column('feature1'),
tf.feature_column.numeric_column('feature2'),
tf.feature_column.numeric_column('feature3'),
tf.feature_column.numeric_column('feature4')]
def serving_input_receiver_fn():
"""An input receiver that expects a serialized tf.Example."""
feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
default_batch_size = 1
serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
receiver_tensors = {'examples': serialized_tf_example}
features = tf.parse_example(serialized_tf_example, feature_spec)
return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
dnn_regressor.export_saved_model(export_dir_base='model',
serving_input_receiver_fn=serving_input_receiver_fn)
当我尝试使用以下方法转换生成的.pb文件时:
tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model
我得到一个例外,说TensorFlow Lite不支持ParseExample操作。
标准TensorFlow Lite运行时不支持模型中的某些运算符。如果这些是本机TensorFlow运算符,则可以通过传递--enable_select_tf_ops或通过在调用tf.lite.TFLiteConverter()时设置target_ops = TFLITE_BUILTINS,SELECT_TF_OPS来使用扩展的运行时。否则,如果您有针对它们的自定义实现,则可以使用--allow_custom_ops或通过在调用tf.lite.TFLiteConverter()时设置allow_custom_ops = True来禁用此错误。这是您正在使用的内置运算符的列表:CONCATENATION,FULLY_CONNECTED,RESHAPE。这是需要自定义实现的运算符的列表:ParseExample。
如果我尝试导出而不进行序列化的模型,则当我尝试对生成的.pb文件进行预测时,该函数期望并清空set(),而不是我传递的输入的字典
ValueError:在input_dict中获得了意外的键:{'feature1','feature2','feature3','feature4'} 预期:set()
我在做什么错?这是无需进行任何序列化就可以保存模型的代码
features = {
'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}
def serving_input_receiver_fn():
return tf.estimator.export.ServingInputReceiver(features, features)
dnn_regressor.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_receiver_fn, as_text=True)
答案 0 :(得分:0)
已解决
使用build_raw_serving_input_receiver_fn我设法导出保存的模型而没有任何序列化:
serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
features,
default_batch_size=None
)
dnn_regressor.export_savedmodel(
export_dir_base="model",
serving_input_receiver_fn=serve_input_fun,
as_text=True
)
注意:在进行预测时,预测器不知道默认的signature_def,因此我需要指定它:
predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")
也从.pb转换为.tflite我使用了Python API,因为我还需要在这里指定signature_def:
converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')