Saving checkpoints with TensorFlow

Date: 2018-05-04 15:55:57

Tags: python python-3.x tensorflow tensorflow-estimator

My CNN model has 3 data folders: train_data, val_data, test_data.

When I train my model, I notice that the accuracy fluctuates, and the last epoch does not always give the best accuracy. For example, the final epoch's accuracy was 71%, but an earlier epoch was more accurate. I want to save a checkpoint at the epoch with the higher accuracy and then use that checkpoint to make predictions with my model on test_data.

I train my model on train_data, predict on val_data, and save a checkpoint of the model as follows:

    print("{} Saving checkpoint of model...". format(datetime.now()))
    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
    save_path = saver.save(session, checkpoint_path)

Before starting the tf.Session() I have this line:

saver = tf.train.Saver()

I would like to know how to save the checkpoint of the epoch with the highest accuracy and then use that checkpoint on test_data.

Thanks in advance.

2 answers:

Answer 0 (score: 0)

The tf.train.Saver() documentation describes the following:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

Note that if you pass global_step to the saver, the checkpoint filenames will include the global step number. I usually save a checkpoint every X minutes, then go back over the results and pick the checkpoint at the step value that performed best. If you are using TensorBoard you will find this quite intuitive, since all of your graphs can be displayed by global step.

https://www.tensorflow.org/api_docs/python/tf/train/Saver
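
For completeness, once you have picked the step whose accuracy you like, you can restore that exact file and evaluate on test data. A minimal TF1-style sketch, assuming the same graph has been rebuilt beforehand and that `accuracy`, `x`, and `y` are your accuracy op and input placeholders (the test arrays likewise stand in for your own data):

    import tensorflow as tf

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # restore the checkpoint written at the step you selected, e.g. 1000
        saver.restore(sess, 'my-model-1000')
        # run the accuracy op on the held-out test set
        test_acc = sess.run(accuracy, feed_dict={x: test_images, y: test_labels})
        print('test accuracy: {:.4f}'.format(test_acc))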

Answer 1 (score: 0)

You can use a CheckpointSaverListener, which lets you run code right after each checkpoint is written, for example to discard a new checkpoint that is worse than the previous one:

from __future__ import print_function
import glob
import os

import tensorflow as tf
from sacred import Experiment

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

ex = Experiment('test-07-05-2018')

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_path = "/tmp/checkpoints/"

class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def begin(self):
        print('Starting the session.')
        self.prev_accuracy = 0.0
        self.acc = 0.0  # the training loop writes the latest accuracy here

    def after_save(self, session, global_step_value):
        print('Only keep this checkpoint if it is better than the previous one')
        if self.acc < self.prev_accuracy:
            # latest_checkpoint returns a path prefix, not a single file,
            # so remove every file written under that prefix
            latest = tf.train.latest_checkpoint(checkpoint_path)
            for f in glob.glob(latest + '*'):
                os.remove(f)
        else:
            self.prev_accuracy = self.acc

    def end(self, session, global_step_value):
        print('Done with the session.')

@ex.config
def my_config():
    pass

@ex.automain
def main():
    # Build the graph of a vanilla multiclass logistic regression
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y_pred = tf.nn.softmax(tf.matmul(x, W) + b)
    loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), axis=1))
    # CheckpointSaverHook requires a global step in the graph
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)
    y_pred_cls = tf.argmax(y_pred, axis=1)
    y_true_cls = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    listener = ExampleCheckpointSaverListener()
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_path, save_steps=100, listeners=[listener])
    # MonitoredTrainingSession initializes the variables itself and fires
    # the hook (and hence the listener) on every save, so no manual
    # tf.train.Saver() calls are needed inside the loop
    with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
        for epoch in range(25):
            avg_loss = 0.
            total_batch = int(mnist.train.num_examples / 100)
            # Loop over all batches
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                _, l, acc = sess.run([optimizer, loss, accuracy],
                                     feed_dict={x: batch_xs, y: batch_ys})
                avg_loss += l / total_batch
                listener.acc = acc  # hand the latest accuracy to the listener
            print('epoch {}: avg loss {:.4f}'.format(epoch, avg_loss))
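
Because a worse checkpoint is deleted as soon as it is written, whatever survives in the directory afterwards is the best one seen so far, and that is the one to restore for test_data. A short sketch under the same assumptions as above (the graph must be rebuilt identically before creating the Saver; accuracy, x, and y are the ops defined in main):

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
        print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                            y: mnist.test.labels}))

One caveat: deleting checkpoint files by hand leaves the "checkpoint" state file stale, so tf.train.latest_checkpoint can still return the removed prefix. In practice it is safer to also record the path of the best checkpoint yourself, e.g. in an attribute on the listener, and restore from that.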