tf.metrics.accuracy与实际精度不符

时间:2019-09-19 09:31:57

标签: python tensorflow

我正在尝试使用TensorFlow(而不是Keras)重现Coursera ML课程的NN练习。

我发现使用tf.metrics.accuracy计算准确性会导致结果低于计算时的准确性。

相关代码为:

accuracy, update_op = tf.metrics.accuracy(labels=y, predictions=tf.argmax(tf.sigmoid(output), axis=1))
...
# in session:
acc = sess.run(accuracy, feed_dict={tf_x: X, tf_y: y})
sess.run(update_op, feed_dict={tf_x: X, tf_y: y})
print(f'step {step} - accuracy: {acc}')
...
# real accuracy
predictions = sess.run(tf.argmax(tf.sigmoid(output), axis=1), feed_dict={tf_x: X})
pred_y = predictions == y
print(f'Training Set Accuracy after training: {np.mean(pred_y) * 100}%')

差异甚至可以达到30%(即acc为0.5,实际精度为0.8)

我做错什么了吗?

请注意,如果我这样做:

equal = tf.equal(tf.cast(tf.argmax(tf.sigmoid(output), 1), tf.int32), y)
acc_op = tf.reduce_mean(tf.cast(equal, tf.float32))
acc = sess.run(acc_op, feed_dict={tf_x: X, tf_y: y})

我得到相同的结果... tf.metrics.accuracy是否以其他方式计算?

1 个答案:

答案 0 :(得分:1)

解决方案:首先致电sess.run(update_op, feed_dict),然后致电sess.run(accuracy)。如果要补料一个新批次,并且希望该批次的精度 ,则必须首先重置一些隐藏的变量-工作流程如下:

accuracy, update_op = tf.metrics.accuracy(tf_labels, tf_predictions, scope="my_metrics")
running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metrics")
running_vars_initializer = tf.variables_initializer(var_list=running_vars)
for i in range(num_batches):
    # explicitly initialize/reset 'total' and 'count' to 0
    sess.run(running_vars_initializer) 

    # feed labels and predictions at i-th batch to update_ops
    feed_dict={tf_labels: y[i], tf_predictions: tf.argmax(tf.sigmoid(output[i]), axis=1)}
    session.run(update_op, feed_dict=feed_dict)

    # compute and print accuracy from current 'total' and 'count'
    print('Batch {} accuracy: {}'.format(i, session.run(accuracy)))


详细信息tf.metrics.accuracy利用两个运行时变量total(正确预测的数量)和count(馈送的标签数量)在后台进行本地初始化。一旦accuracy被调用,update_op仅被更新-步骤:

  • totalcount初始化为零
  • sess.run(update_op, feed_dict)-> totalcountfeed_dict更新一次
  • sess.run(accuracy)-> accuracy使用 current totalcount来计算指标
  • sess.run(accuracy, feed_dict)-> accuracy使用 current totalcount来计算指标

最后两句话是,feed_dict实际上并没有改变accuracyaccuracytotalcount上运行,它们仅通过update_op更新。最后,

  • sess.run(accuracy, ...)不会 totalcount重置为0

这很大程度上就是为什么完全使用totalcount的原因-为了实现可扩展性;通过保持运行历史记录,它可以一次性计算太大而无法放入内存的数据指标。

最后,您的占位符逻辑看起来很简单-您将数据馈送到tf_xtf_y中,但是在tf.metrics.accuracy(...)中的任何地方都找不到它们-但这很容易解决。


参考文献/进一步阅读StackOverflow,很好的blog entry