正如教程所说,在每一个步骤之后,我都需要使用“验证”。数据集现在验证模型的准确性,并使用' test'数据集最终测试精度。
示例代码:
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
但我认为它对我的设备来说太大了,可能会发生OOM。
如何提供方法'准确性'使用一批validate_feed并获取总计' validate_acc'?
(如果我从数据集中创建一个迭代器,我如何将next_batch提供给'准确度'方法?)
谢谢大家的帮助!
答案 0 :(得分:3)
通常,您使用类似于以下内容来测量准确度:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
logits是您通常传递到softmax - cross entropy层的最终特征。以上计算给定批次的准确性,但不计算整个数据集的准确性。您可以改为:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_, 1))
total_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
执行" total_correct"对于测试集中的每个批处理并累积它们:
correct_sum = 0
for batch in data_set:
batch_correct_count = sess.run(total_correct, feed_dict=validate_feed)
correct_sum += batch_correct_count
total_accuracy = correct_sum / data_set.size()
使用上面的公式,您可以通过批量处理数据来正确计算总体准确度。当然,这是假设for循环在数据集的互斥批次上运行。您应该避免使用数据集中的替换来禁用iid采样或采样,这通常用于随机培训。
答案 1 :(得分:3)
使用tf.metrics.acccuracy
。它进行精确度的流式计算,这意味着它会为您累积所有必要的信息,并在需要时返回当前的精度估计值。
有关如何使用它的示例,请参阅this answer。