我使用tf指标来计算精度和召回率,但总是得到0.0,但是当我自己计算时,我的准确度很高,是张量流量的错误还是我做错了什么。
with tf.name_scope("pointwise_accuracy"):
correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
self.classification_accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
self.precision = tf.metrics.precision(self.input_y, self.logits, name="precison")[0]
输出 精度 - 0.0
答案 0 :(得分:4)
tf.metrics.precision
仅用于二进制分类问题,其参数必须全部为0
或1
,因为正如文档所述,它们将转换为{{ 1}}。如果您确实在处理二进制分类问题但想要使用logits作为参数,则可以查看tf.metrics.precision_at_thresholds
,它允许您指定预测被视为真的阈值。
但是,由于您的手动计算使用bool
,因此它看起来更像是多类别分类问题,在这种情况下,您通常不会谈论精确/召回,而只是准确性,所以您可以查看tf.metrics.accuracy
,并将tf.argmax
和tf.argmax(self.input_y, 1)
作为参数传递。