Tensorflow:以最小的验证错误保存模型

时间:2016-08-31 14:52:58

标签: machine-learning tensorflow

我使用tensorflow运行了一个培训工作,并获得了验证集丢失的以下曲线。在第6000次迭代后,网络开始过度拟合。所以我想在过度拟合之前获得模型。

loss

我的训练代码如下:

train_step = ......
summary = tf.scalar_summary(l1_loss.op.name, l1_loss)
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)
saver = tf.train.Saver()
for i in xrange(20000):
    batch = get_next_batch(batch_size)
    sess.run(train_step, feed_dict = {x: batch.x, y:batch.y})
    if (i+1) % 100 == 0:
        saver.save(sess, "checkpoint/net", global_step = i+1)
        summary_str = sess.run(summary, feed_dict=validation_feed_dict)
        summary_writer.add_summary(summary_str, i+1)
        summary_writer.flush()

训练结束后,只保存了五个检查点(19600,19700,19800,19900,20000)。有没有办法让tensorflow根据验证错误保存检查点?

P.S。我知道tf.train.Saver有一个max_to_keep参数,原则上可以保存所有检查点。但那不是我想要的(除非它是唯一的选择)。我希望保护程序保持检查点到目前为止最小的验证损失。这可能吗?

3 个答案:

答案 0 :(得分:7)

您需要在验证集上计算分类准确度,并跟踪到目前为止看到的最佳分类精度,并且只有在验证准确性找到改进后才写入检查点。

如果数据集和/或模型很大,那么您可能必须将验证集拆分为批处理以适应内存中的计算。

本教程将准确说明如何执行您想要的操作:

https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb

它也可以作为简短视频:

https://www.youtube.com/watch?v=Lx8JUJROkh0

答案 1 :(得分:1)

这可以通过检查点来完成。在张量流1:

# you should import other functions/libs as needed to build the model

from keras.callbacks.callbacks import ModelCheckpoint

# add checkpoint to save model with lowest val loss
filepath = 'tf1_mnist_cnn.hdf5'
save_checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, \
                             save_best_only=True, save_weights_only=False, \
                             mode='auto', period=1)

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test),
          callbacks=[save_checkpoint])

Tensorflow 2:

# import other libs as needed for building model
from tensorflow.keras.callbacks import ModelCheckpoint

# add a checkpoint to save the lowest validation loss
filepath = 'tf2_mnist_model.hdf5'

checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, \
                             save_best_only=True, save_weights_only=False, \
                             mode='auto', save_frequency=1)

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test),
          callbacks=[checkpoint])

完整的演示文件在这里:https://github.com/nateGeorge/slurm_gpu_ubuntu/tree/master/demo_files

答案 2 :(得分:0)

在你的session.run中,你需要明确地要求赔偿。然后创建一个包含上次eval-loss的列表,并且只有当前的eval损失小于最后两次保存的损失时才会创建检查点。