来自预训练张量流模型的混淆矩阵

时间:2021-03-31 21:36:42

标签: python tensorflow

我训练了tensorflow的mobilenet v1模型来检测我自己的类,但是我需要从生成的文件中获取模型结果的混淆矩阵,但是我找不到任何对我有用的帖子创建张量流模型的矩阵。

1 个答案:

答案 0 :(得分:3)

试试这个

import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import confusion_matrix, classification_report
y_pred=[]
preds= model.predict etc.....
for i, p in enumerate (preds):
    p=pred[i]
    predicted_class_index=np.argmax(p)  # this the the predicted column with highest probability assuming you class_mode='categorical'
    y_pred.append(predicted_class_index) # this is a list of predictions
# now you need to create a list of the corresponding true labels
# how you get this depends on how you supplied the data to model.fit
# but somewhere you should have a list of labels - these are y_true
# then do
cm = confusion_matrix(y_true, y_pred )

如果您使用 train_gen=ImageDataGenerator.flow_from_directory 并设置 shuffle=False,则 y_true=train_gen.labels