如何在张量流中读取Keras检查点?

时间:2019-06-02 20:54:51

标签: python tensorflow keras

我在ch_callback = ModelCheckpoint('./foo.bar')中使用了model.fit()。我得到的正是我所要求的,即'./foo.bar'文件。

  • 它实际上有什么格式?
  • 如何使用此检查点,我可以在哪里加载它?
  • 最重要的是,我可以将其转换为原生tensorflow检查点格式吗?

2 个答案:

答案 0 :(得分:4)

  1. Keras检查点为.hdf5或.h5格式。
  2. 您可以使用tf.keras.models.load_model("model.h5")来加载keras检查点。
  3. 如果要将Keras检查点转换为TF检查点,可以加载Keras模型(带有Keras后端),然后导出加载Keras模型时创建的TF图的TF检查点。
model = keras.models.load_model("model.h5")
sess = keras.backend.get_session()
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")

答案 1 :(得分:0)

tensorflow 2.x

此代码对我有用

import tensorflow as tf
from keras.models import load_model

saver = tf.train.Checkpoint()
model = load_model('model.hdf5', compile=False)
sess = tf.compat.v1.keras.backend.get_session()

save_path = saver.save("model.ckpt")