Keras + Tensorflow中的混淆矩阵

时间:2018-04-25 16:36:31

标签: python-3.x keras confusion-matrix

Q1

我训练了一个CNN模型并将其保存为model.h5。我试图检测3个物体。说," cat"," dog"和"其他"。我的测试集有300张图片,每个类别100张。前100是"猫",第2 100是"狗"而第3 100则是#34;其他"。我正在使用Keras课程ImageDataGeneratorflow_from_directory。以下是示例代码:

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='sparse',
        shuffle=False)

现在使用

from sklearn.metrics import confusion_matrix

cnf_matrix = confusion_matrix(y_test, y_pred)

我需要y_testy_pred。我可以使用以下代码获取y_pred

probabilities = model.predict_generator(test_generator)
y_pred = np.argmax(probabilities, axis=1)
print (y_pred)

[0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 1 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 1 0 0 0 0 1 2 0 2 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1
 0 2 0 0 0 0 1 0 0 0 0 0 0 1 0 2 0 1 0 0 1 0 0 1 0 0 1 1 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1
 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 2 2 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2
 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2]

这基本上是将对象预测为0,1和2.现在我知道前100个对象(cat)为0,第二个100对象(dog)为1,第三个100对象(其他)为2.我是否创建使用numpy手动列表,其中前100个点为0,第2个100点为1,第3个100点为2以获得y_test?有没有可以做到的Keras类(创建y_test)?

Q2

如何查看错误检测到的对象。如果你查看print(y_pred),第3点是1,这是错误的预测。怎么能看到那个图像而不进入我的" test_dir"文件夹手动?

1 个答案:

答案 0 :(得分:0)

由于您未使用任何扩充和shuffle=False,因此您只需从生成器中获取图像:

imgBatch = next(test_generator)
    #it may be interesting to create the generator again if 
    #you're not sure it has output exactly all images before

使用绘图库(如Pillow(PIL)或MatplotLib)在imgBatch中绘制每个图像。

要仅绘制所需图像,请将y_testy_pred进行比较:

compare = y_test == y_pred

position = 0
while position < len(y_test):
    imgBatch = next(test_generator)
    batch = imgBatch.shape[0]

    for i in range(position,position+batch):
        if compare[i] == False:
            plot(imgBatch[i-position])

    position += batch