在Keras,保存模型检查点并重新加载它以便在不同数据集上进行培训的最佳方法是什么?

时间:2017-02-21 20:16:05

标签: tensorflow keras

所以,让我说我有一个训练有素的Keras神经网络称为模型。我想在进一步训练之前保存模型,这样我就可以回到那个检查点,因为我在不同的数据集上进行训练。

model.save('checkpoint_1.h5')

model.fit(data_1, labels_1)

model.save('checkpoint_2.h5')

现在我的问题就出现了。我想释放GPU内存,以便在进一步培训模型之前重新加载checkpoint_1。我目前正在做的是结束当前的张量流会话并开始新的会议。

from keras import backend as K
#End the current and start a new tensorflow session to free up gpu memory 
#to allow the next nn2 to be trained.
K.get_session().close()
sess = tf.Session()
K.set_session(sess)

然后我加载checkpoint_1并继续训练。

model = load_model('checkpoint_1.h5')

model.fit(data_2, labels_2)

有更好的方法吗?停止和启动tensorflow会话需要花费很多时间。

0 个答案:

没有答案