如何使用tf.metrics.accuracy?

时间:2017-10-17 09:43:48

标签: machine-learning tensorflow deep-learning

我想使用tf.metrics.accuracy来跟踪预测的准确性,但我不确定如何使用函数返回的update_op(acc_update_op):

accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)

我在考虑将其添加到tf.GraphKeys.UPDATE_OPS会有意义,但我不知道该怎么做。

1 个答案:

答案 0 :(得分:7)

tf.metrics.accuracy是许多流量度量TensorFlow操作之一(另一个是tf.metrics.recall)。创建后,会创建两个变量(counttotal),以便累积一个最终结果的所有传入结果。第一个返回值是计算count / total的张量。返回的第二个op是一个更新这些变量的有状态函数。在评估多个批次数据的分类器性能时,流式度量函数非常有用。一个简单的使用示例:

# building phase
with tf.name_scope("streaming"):
    accuracy, acc_update_op = tf.metrics.accuracy(labels, predictions)

test_fetches = {
    'accuracy': accuracy,
    'acc_op': acc_update_op
}

# when testing the classifier    
with tf.name_scope("streaming"):
    # clear counters for a fresh evaluation
    sess.run(tf.local_variables_initializer())

for _i in range(n_batches_in_test):
    fd = get_test_batch()
    outputs = sess.run(test_fetches, feed_dict=fd)

print("Accuracy:", outputs['accuracy'])
  

我在考虑将其添加到tf.GraphKeys.UPDATE_OPS会有意义,但我不知道该怎么做。

除非您仅使用UPDATE_OPS集合进行测试,否则这不是一个好主意。通常,该集合已经具有针对训练阶段的某些控制操作(例如移动批量标准化参数),这些操作并不意味着与验证阶段一起运行。最好将它们保存在新集合中,或者手动将这些操作添加到获取字典中。