TensorFlow:记住以前时代的重量

时间:2017-05-11 18:24:22

标签: python machine-learning tensorflow neural-network

我正在尝试使用TensorFlow。我刚刚发布了question关于它面临的问题。然而,我也有一个更理论上的问题,但有实际的后果。

在训练模型时,我发现准确性可能会有所不同。因此,可能会发生最后一个时期没有显示出最佳准确度。例如,在纪元N上,我可以具有85%的准确度,而在最后时期,准确度为65%。我想预测使用N时代的权重。

我想知道是否有一种方法可以记住时代的权重值以及以后使用的最佳准确度?

第一个也是最简单的方法是:

  1. 运行N纪元
  2. 记住最佳准确度
  3. 重新开始训练,直到我们到达一个显示与步骤2中存储的精确度相同的时期。
  4. 预测使用当前的重量
  5. 有更好的吗?

1 个答案:

答案 0 :(得分:2)

是的!您需要在培训过程中制作saver and save your session periodically。伪代码实现如下:

model = my_model()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(NUM_EPOCHS):
        for batch in range(NUM_BATCHES):

            # ... train your model ...

            if batch % VALIDATION_FREQUENCY == 0:
                # Periodically test against a validation set.
                error = sess.run(model.error, feed_dict=valid_dict)
                if error < min_error:
                    min_error = error  # store your best error so far
                    saver.save(sess, MODEL_PATH)  # save the best-performing network so far

然后,当您想要针对性能最佳的迭代测试模型时:

saver.restore(sess, MODEL_PATH)
test_error = sess.run(model.error, feed_dict=test_dict)

同时查看保存和加载元图的this tutorial。根据您的使用情况,我发现加载步骤有点棘手。