在以下环境中,我有一个已训练好的 Keras 模型(model.h5)。
我在安装了 TensorFlow 1.14.1 版本(tf-nightly)的环境中,尝试用 TensorFlow Lite 通过以下代码对该模型进行训练后量化:
import tensorflow as tf
converter =tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()
但收到以下错误消息:
ValueError: ('Unrecognized keyword arguments:', dict_keys(['input_dtype']))
我的完整代码和回溯信息(traceback):
import tensorflow as tf
keras_file="deep_model.h5"
converter = tf.lite.TFLiteConverter.from_keras_model_file( keras_file )
converter.optimizations= [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()
file = open( 'model.tflite' , 'wb' )
file.write( model )
--------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-17-bc011fa77854> in <module>()
4
5
----> 6 converter = tf.lite.TFLiteConverter.from_keras_model_file( keras_file )
7 converter.optimizations= [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
8 tflite_model = converter.convert()
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/lite.py in from_keras_model_file(cls, model_file, input_arrays, input_shapes, output_arrays)
625 _keras.backend.clear_session()
626 _keras.backend.set_learning_phase(False)
--> 627 keras_model = _keras.models.load_model(model_file)
628 sess = _keras.backend.get_session()
629
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model(filepath, custom_objects, compile)
213 model_config = json.loads(model_config.decode('utf-8'))
214 model = model_config_lib.model_from_config(model_config,
--> 215 custom_objects=custom_objects)
216
217 # set weights
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
53 '`Sequential.from_config(config)`?')
54 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 55 return deserialize(config, custom_objects=custom_objects)
56
57
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
93 module_objects=globs,
94 custom_objects=custom_objects,
---> 95 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
190 custom_objects=dict(
191 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 192 list(custom_objects.items())))
193 with CustomObjectScope(custom_objects):
194 return cls.from_config(cls_config)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in from_config(cls, config, custom_objects)
1229 # First, we create all layers and enqueue nodes to be processed
1230 for layer_data in config['layers']:
-> 1231 process_layer(layer_data)
1232 # Then we process nodes in order of layer depth.
1233 # Nodes that cannot yet be processed (if the inbound node
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in process_layer(layer_data)
1213 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
1214
-> 1215 layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1216 created_layers[layer_name] = layer
1217
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
93 module_objects=globs,
94 custom_objects=custom_objects,
---> 95 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
192 list(custom_objects.items())))
193 with CustomObjectScope(custom_objects):
--> 194 return cls.from_config(cls_config)
195 else:
196 # Then `cls` may be a function returning a class.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
438 A layer instance.
439 """
--> 440 return cls(**config)
441
442 def compute_output_shape(self, input_shape):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/input_layer.py in __init__(self, input_shape, batch_size, dtype, input_tensor, sparse, name, **kwargs)
67 input_shape = batch_input_shape[1:]
68 if kwargs:
---> 69 raise ValueError('Unrecognized keyword arguments:', kwargs.keys())
70
71 if not name:
ValueError: ('Unrecognized keyword arguments:', dict_keys(['input_dtype']))
那么,TensorFlow Lite 不支持用旧版 TensorFlow/Keras 保存的模型吗?如何解决此问题,并从该模型获得量化后的权重?
答案 0(得分:0)
对于 Keras HDF5 模型,请使用 from_keras_model_file。在最近的 tf-nightly 版本中,该方法已添加对 custom_objects 参数的支持。
以下示例复制自官方文档:
# Save tf.keras model in HDF5 format.
keras_file = "keras_model.h5"
tf.keras.models.save_model(model, keras_file)
# Convert to TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)