我的CNN模型有3个文件夹train_data, val_data, test_data.
当我训练我的模型时,我发现准确性可能会有所不同,有时最后一个时期并没有显示最佳准确度。例如,上一个纪元的准确率是71%,但我发现在早期时代更准确。我想保存具有更高准确度的那个纪元的检查点,然后使用该检查点在test_data
上预测我的模型
我在train_data
上训练了我的模型并在val_data
上预测并保存了模型的检查点,如下所示:
print("{} Saving checkpoint of model...". format(datetime.now()))
checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
save_path = saver.save(session, checkpoint_path)
在开始tf.Session()
之前我有这一行:
saver = tf.train.Saver()
我想知道如何保存具有更高准确度的最佳纪元,然后将此检查点用于test_data
?
先谢谢了。
答案 0 :(得分:0)
tf.train.Saver()
文档描述了以下内容:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
请注意,如果将global_step
传递给保护程序,则会生成包含全局步骤编号的检查点文件。我通常每X分钟保存一次检查点,然后返回并查看结果并选择适当步长值的检查点。如果你正在使用张量板,你会发现这很直观,因为你的所有图表都可以通过全局步骤显示。
答案 1 :(得分:0)
您可以使用CheckpointSaverListener。
from __future__ import print_function
import tensorflow as tf
import os
from sacred import Experiment
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
ex = Experiment('test-07-05-2018')
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_path = "/tmp/checkpoints/"
class ExampleCheckpointSaverListener(CheckpointSaverListener):
def begin(self):
print('Starting the session.')
self.prev_accuracy = 0
self.acc = 0
def after_save(self, session, global_step_value):
print('Only keep this checkpoint if it is better than the previous one')
self.acc = acc
if self.acc < self.prev_accuracy :
os.remove(tf.train.latest_checkpoint())
else:
self.prev_accuracy = self.acc
def end(self, session, global_step_value):
print('Done with the session.')
@ex.config
def my_config():
pass
@ex.automain
def main():
#build the graph of vanilla multiclass logistic regression
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y_pred = tf.nn.softmax(tf.matmul(x, W) + b) #
loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred), reduction_indices=1))
optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
init = tf.global_variables_initializer()
y_pred_cls = tf.argmax(y_pred, dimension=1)
y_true_cls = tf.argmax(y, dimension=1)
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver = tf.train.Saver()
listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
sess.run(init)
for epoch in range(25):
avg_loss = 0.
total_batch = int(mnist.train.num_examples/100)
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(100)
_, l, acc = sess.run([optimizer, loss, accuracy], feed_dict={x: batch_xs, y: batch_ys})
avg_loss += l / total_batch
saver.save(sess, checkpoint_path)