TensorFlow - 使用批次数据验证准确性

时间:2018-05-01 04:43:13

标签: python tensorflow tensorflow-datasets

正如教程所说,在每一个步骤之后,我都需要使用“验证”。数据集现在验证模型的准确性,并使用' test'数据集最终测试精度。

示例代码:

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

validate_acc = sess.run(accuracy, feed_dict=validate_feed)

但我认为它对我的设备来说太大了,可能会发生OOM。

如何提供方法'准​​确性'使用一批validate_feed并获取总计' validate_acc'?

(如果我从数据集中创建一个迭代器,我如何将next_batch提供给'准​​确度'方法?)

谢谢大家的帮助!

2 个答案:

答案 0 :(得分:3)

通常,您使用类似于以下内容来测量准确度:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

logits是您通常传递到softmax - cross entropy层的最终特征。以上计算给定批次的准确性,但不计算整个数据集的准确性。您可以改为:

correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

执行" total_correct"对于测试集中的每个批处理并累积它们:

  correct_sum = 0
  for batch in data_set:
       batch_correct_count = sess.run(total_correct, feed_dict=validate_feed)
       correct_sum += batch_correct_count

  total_accuracy = correct_sum / data_set.size()

使用上面的公式,您可以通过批量处理数据来正确计算总体准确度。当然,这是假设for循环在数据集的互斥批次上运行。您应该避免使用数据集中的替换来禁用iid采样或采样,这通常用于随机培训。

答案 1 :(得分:3)

使用tf.metrics.acccuracy。它进行精确度的流式计算,这意味着它会为您累积所有必要的信息,并在需要时返回当前的精度估计值。

有关如何使用它的示例,请参阅this answer