Tensorflow2.0a:为什么keras.models.load_model()比model.from_config()慢?

时间:2019-05-17 17:05:48

标签: python keras tensorflow2.0

我在Tensorflow2.0a中构建了几个自定义图层和模型,以扩展Keras类。在所有这些方法中,我都实现了get_config()from_config()方法。我还在CustomModel类中实现了model.save()方法(类似于顺序模型)。

def save(self, filepath, overwrite=True, include_optimizer=True, **kwargs):
        from tensorflow.python.keras.models import save_model  # pylint: disable=g-import-not-at-top
        save_model(self, filepath, overwrite, include_optimizer)

那些方法允许我使用:

model.save("model.h5")
new_model = keras.models.load_model('model.h5',
                                    custom_objects={
                'CustomModel': CustomModel,
                'CustomLayer': CustomLayer, #...
})

加载具有8个图层和9000个参数的模型大约需要30秒。

但是我也可以这样加载模型:

model.save_weights("weights.h5")
config = model.get_config()
new_model = CustomModel.from_config(config)
new_model.load_weights("weights.h5")

对于同一模型,此方法需要0.6秒。

第一种方法如此缓慢有没有原因?是由于自定义模型/图层的反序列化吗?

在这种情况下,首选方法是什么?

我知道TF2是相当新的东西,但是任何见识都会受到赞赏!

0 个答案:

没有答案