我想使用tf.metrics.accuracy
来跟踪预测的准确性,但我不确定如何使用函数返回的update_op(acc_update_op
):
accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)
我在考虑将其添加到tf.GraphKeys.UPDATE_OPS
会有意义,但我不知道该怎么做。
答案 0 :(得分:7)
tf.metrics.accuracy
是许多流量度量TensorFlow操作之一(另一个是tf.metrics.recall
)。创建后,会创建两个变量(count
和total
),以便累积一个最终结果的所有传入结果。第一个返回值是计算count / total
的张量。返回的第二个op是一个更新这些变量的有状态函数。在评估多个批次数据的分类器性能时,流式度量函数非常有用。一个简单的使用示例:
# building phase
with tf.name_scope("streaming"):
accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)
test_fetches = {
'accuracy': accuracy,
'acc_op': acc_update_op
}
# when testing the classifier
with tf.name_scope("streaming"):
# clear counters for a fresh evaluation
sess.run(tf.local_variables_initializer())
for _i in range(n_batches_in_test):
fd = get_test_batch()
outputs = sess.run(test_fetches, feed_dict=fd)
print("Accuracy:", outputs['accuracy'])
我在考虑将其添加到
tf.GraphKeys.UPDATE_OPS
会有意义,但我不知道该怎么做。
除非您仅使用UPDATE_OPS集合进行测试,否则这不是一个好主意。通常,该集合已经具有针对训练阶段的某些控制操作(例如移动批量标准化参数),这些操作并不意味着与验证阶段一起运行。最好将它们保存在新集合中,或者手动将这些操作添加到获取字典中。