我在Tensorflow2.0a中构建了几个自定义图层和模型,以扩展Keras类。在所有这些方法中,我都实现了get_config()
和from_config()
方法。我还在CustomModel类中实现了model.save()
方法(类似于顺序模型)。
def save(self, filepath, overwrite=True, include_optimizer=True, **kwargs):
from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top
save_model(self, filepath, overwrite, include_optimizer)
那些方法允许我使用:
model.save("model.h5")
new_model = keras.models.load_model('model.h5',
custom_objects={
'CustomModel': CustomModel,
'CustomLayer': CustomLayer, #...
})
加载具有8个图层和9000个参数的模型大约需要30秒。
但是我也可以这样加载模型:
model.save_weights("weights.h5")
config = model.get_config()
new_model = CustomModel.from_config(config)
new_model.load_weights("weights.h5")
对于同一模型,此方法需要0.6秒。
第一种方法如此缓慢有没有原因?是由于自定义模型/图层的反序列化吗?
在这种情况下,首选方法是什么?
我知道TF2是相当新的东西,但是任何见识都会受到赞赏!