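I'm trying to do transfer learning on MobileNet using Keras and the Flower DB available in tf. Training works well and I can make predictions, but I can't save the model for later use, so I have to retrain it every time.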
I'm working on macOS Mojave with Python 3.7 and the latest Keras release; I have already run pip install --upgrade keras.
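import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers
# image_generator, data_root, steps_per_epoch and batch_stats are defined earlier in the script (not shown).
feature_extractor_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2" #@param {type:"string"}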
# Create the module, and check the expected image size:
def feature_extractor(x):
    feature_extractor_module = hub.Module(feature_extractor_url)
    return feature_extractor_module(x)
IMAGE_SIZE = hub.get_expected_image_size(hub.Module(feature_extractor_url))
# Ensure the data generator is generating images of the expected size:
image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SIZE, subset='training')
test_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SIZE, subset='validation')
# Wrap the module in a keras layer.
features_extractor_layer = layers.Lambda(feature_extractor, input_shape=IMAGE_SIZE+[3])
# Freeze the variables in the feature extractor layer, so that the training only modifies the new classifier layer.
features_extractor_layer.trainable = False
# Attach a classification head
# Now wrap the hub layer in a tf.keras.Sequential model, and add a new classification layer.
model = tf.keras.Sequential([
    features_extractor_layer,
    layers.Dense(image_data.num_classes, activation='softmax')
])
# Initialize the TFHub module.
import tensorflow.keras.backend as K
sess = K.get_session()
init = tf.global_variables_initializer()
sess.run(init)
# Train the model
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='categorical_crossentropy',
    metrics=['accuracy'])
model.fit((item for item in image_data), epochs=1,
          steps_per_epoch=steps_per_epoch,
          callbacks=[batch_stats])
model.save("./saved_models/flower_model.h5")
del model
from keras.models import load_model
model = load_model('./saved_models/flower_model.h5')
model.summary()
model.save() seems to work fine: the file flower_model.h5 is created correctly in that directory. But when I run load_model(), I get the following error:
Using TensorFlow backend.
Traceback (most recent call last):
  File "/Users/david/Library/Preferences/PyCharmCE2018.2/scratches/retrain_flowers.py", line 114, in <module>
    model = load_model('./saved_models/flower_model.h5')
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/engine/saving.py", line 419, in load_model
    model = _deserialize_model(f, custom_objects, compile)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/engine/saving.py", line 225, in _deserialize_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/engine/saving.py", line 458, in model_from_config
    return deserialize(config, custom_objects=custom_objects)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/engine/sequential.py", line 300, in from_config
    custom_objects=custom_objects)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/utils/generic_utils.py", line 145, in deserialize_keras_object
    list(custom_objects.items())))
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/layers/core.py", line 764, in from_config
    return cls(**config)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/layers/core.py", line 626, in __init__
    super(Lambda, self).__init__(**kwargs)
  File "/Users/david/Library/Python/3.7/lib/python/site-packages/keras/engine/base_layer.py", line 128, in __init__
    raise TypeError('Keyword argument not understood:', kwarg)
TypeError: ('Keyword argument not understood:', 'module')
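The last frames of the traceback show the mechanism: `Lambda.from_config` forwards the remaining config keys to `cls(**config)`, and the base `Layer.__init__` raises for any keyword it does not recognize. The pure-Python sketch below mimics only that check. The classes, the allowed-keyword set and the config dict are hypothetical stand-ins, not the real Keras source. It assumes the saved config carries a `module` key that this version of the standalone `keras` `Lambda` does not consume, which would be consistent with the script building the model with `tf.keras.Sequential` but loading it with `from keras.models import load_model`.

```python
# Hypothetical, simplified stand-ins for the Keras classes named in the traceback.
# They only mimic the keyword check; they are NOT the real Keras implementation.


class Layer:
    # Keywords this simplified base layer accepts (an illustrative subset).
    allowed_kwargs = {"name", "trainable", "input_shape", "batch_input_shape", "dtype"}

    def __init__(self, **kwargs):
        for kwarg in kwargs:
            if kwarg not in self.allowed_kwargs:
                # Same message shape as the error in keras/engine/base_layer.py above.
                raise TypeError("Keyword argument not understood:", kwarg)
        self.name = kwargs.get("name")
        self.trainable = kwargs.get("trainable", True)


class Lambda(Layer):
    def __init__(self, function, **kwargs):
        super().__init__(**kwargs)
        self.function = function

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # Keys this (hypothetical, older) Lambda knows about are popped first...
        function = config.pop("function")
        config.pop("function_type", None)
        # ...and everything left over is forwarded to the constructor unchanged.
        return cls(function, **config)


# A config as a newer serializer might write it, also recording the function's module.
saved_config = {
    "name": "lambda",
    "trainable": False,
    "function": "feature_extractor",
    "function_type": "function",
    "module": "__main__",
}

try:
    Lambda.from_config(saved_config)
except TypeError as err:
    print(err.args)  # ('Keyword argument not understood:', 'module')

# Without the unknown key, the simplified loader builds the layer normally.
compatible = {k: v for k, v in saved_config.items() if k != "module"}
layer = Lambda.from_config(compatible)
print(layer.name, layer.trainable)  # lambda False
```

If a save/load mismatch is indeed the cause, loading the file with `tf.keras.models.load_model` (the same package that saved it) would be the first thing to try; the `Lambda` layer wrapping `hub.Module` may still need extra handling on load.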
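I want to load the model in another script to run detection, but I'm not sure whether something needs to be initialized before loading it. Thanks for taking the time to read my question.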