我使用tensorflow运行了一个培训工作,并获得了验证集丢失的以下曲线。在第6000次迭代后,网络开始过度拟合。所以我想在过度拟合之前获得模型。
我的训练代码如下:
train_step = ......
summary = tf.scalar_summary(l1_loss.op.name, l1_loss)
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)
saver = tf.train.Saver()
for i in xrange(20000):
batch = get_next_batch(batch_size)
sess.run(train_step, feed_dict = {x: batch.x, y:batch.y})
if (i+1) % 100 == 0:
saver.save(sess, "checkpoint/net", global_step = i+1)
summary_str = sess.run(summary, feed_dict=validation_feed_dict)
summary_writer.add_summary(summary_str, i+1)
summary_writer.flush()
训练结束后,只保存了五个检查点(19600,19700,19800,19900,20000)。有没有办法让tensorflow根据验证错误保存检查点?
P.S。我知道tf.train.Saver
有一个max_to_keep
参数,原则上可以保存所有检查点。但那不是我想要的(除非它是唯一的选择)。我希望保护程序保持检查点到目前为止最小的验证损失。这可能吗?
答案 0 :(得分:7)
您需要在验证集上计算分类准确度,并跟踪到目前为止看到的最佳分类精度,并且只有在验证准确性找到改进后才写入检查点。
如果数据集和/或模型很大,那么您可能必须将验证集拆分为批处理以适应内存中的计算。
本教程将准确说明如何执行您想要的操作:
https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb
它也可以作为简短视频:
答案 1 :(得分:1)
这可以通过检查点来完成。在张量流1:
# you should import other functions/libs as needed to build the model
from keras.callbacks.callbacks import ModelCheckpoint
# add checkpoint to save model with lowest val loss
filepath = 'tf1_mnist_cnn.hdf5'
save_checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, \
save_best_only=True, save_weights_only=False, \
mode='auto', period=1)
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test),
callbacks=[save_checkpoint])
Tensorflow 2:
# import other libs as needed for building model
from tensorflow.keras.callbacks import ModelCheckpoint
# add a checkpoint to save the lowest validation loss
filepath = 'tf2_mnist_model.hdf5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, \
save_best_only=True, save_weights_only=False, \
mode='auto', save_frequency=1)
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test),
callbacks=[checkpoint])
完整的演示文件在这里:https://github.com/nateGeorge/slurm_gpu_ubuntu/tree/master/demo_files。
答案 2 :(得分:0)
在你的session.run中,你需要明确地要求赔偿。然后创建一个包含上次eval-loss的列表,并且只有当前的eval损失小于最后两次保存的损失时才会创建检查点。