您如何看待 TensorFlow 图像分类模型对每个类的准确度?

时间:2021-01-11 12:24:09

标签: python tensorflow image-classification

我正在学习张量流网站上的图像分类教程:https://www.tensorflow.org/tutorials/images/classification

该模型将鲜花分为 5 类之一:雏菊、蒲公英、玫瑰、向日葵和郁金香。

我可以看到总体准确度是多少,但有什么方法可以知道每个类的准确度如何?

例如,我的模型可能非常擅长预测雏菊、蒲公英、玫瑰和向日葵(接近 100% 的准确率),而在预测郁金香(接近 0%)方面表现不佳,我想我仍然会看到 80% 的总体准确率(假设类是平衡的)。我需要知道各个类别的准确率,以便将该性能与以大约 80% 的准确率预测所有类别的模型区分开来。

2 个答案:

答案 0 :(得分:1)

您只需在 sklearn 中使用分类报告即可做到这一点。

参考Documentation

答案 1 :(得分:0)

当我问这个问题时,我没有足够的 Python(或 scikit-learn)知识来回答。分类报告(由 prashant0598 建议)接近我需要的,尽管它实际上并不具有准确性。以下是分类报告的使用方法:

from sklearn.metrics import classification_report
import pandas as pd

y_pred = model.predict(val_ds)
y_pred = np.argmax(y_pred, axis=1)

y_true = np.concatenate([y for x, y in val_ds], axis=0)

cr = classification_report(y_true, y_pred, output_dict=True, target_names=class_names)
pd.DataFrame.from_dict(cr)

分类报告输出(除其他外)精度和召回率,这有帮助。

为了获得类的准确性,我们必须手动执行更多操作。这是一种方法:

from sklearn.metrics import accuracy_score

def class_accuracy(class_no):
  pred_filter = y_true==class_no
  acc = accuracy_score(y_true[pred_filter], y_pred[pred_filter])
  return acc

{class_name: class_accuracy(i) for i, class_name in enumerate(class_names)}
<块引用>

{'雏菊':0.6589147286821705,
'蒲公英':0.75,
'玫瑰':0.6,
'向日葵':0.868421052631579,
'郁金香':0.6942675159235668}

所以现在我知道了,向日葵是最容易预测的,而玫瑰则特别棘手!