Mnist:获取混淆矩阵

时间:2018-11-17 14:04:47

标签: python tensorflow keras confusion-matrix

我尝试获取mnist数据集的混淆矩阵。

这是我的代码:

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.tanh),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


model.fit(x_train, y_train, epochs=1, callbacks=[history])

test_predictions = model.predict(x_test)


# Compute confusion matrix
confusion = tf.confusion_matrix(y_test, test_predictions)

问题是test_prediction是10000 x 10矩阵,而y_test是10000 x 1矩阵。实际上,神经网络并没有为每个测试样本提供输出。在这种情况下如何计算混淆矩阵?

然后如何显示混淆矩阵?我可以为此目的使用sci-kit库吗?

2 个答案:

答案 0 :(得分:2)

这可能是因为您的预测包括所有可能类别的概率。您需要选择最高概率的类,这将导致与y_test相同的维度。您可以使用numpy中的argmax()方法。它的工作原理如下:

import numpy as np
a = np.array([[0.9,0.1,0],[0.2,0.3,0.5],[0.4,0.6,0]])
np.argmax(a, axis=0)
array([0, 2, 1])

您可以使用sklearn生成混淆矩阵。您的代码将变成这样

from sklearn.metrics import confusion_matrix
import numpy as np

confusion = confusion_matrix(y_test, np.argmax(test_predictions,axis=1))

答案 1 :(得分:1)

如果您使用.predict_classes方法而不是仅使用预测方法,则会获得概率最高的类向量。

然后,您可以使用sklearn中的confusion_matrix。

SELECT REGEXP_REPLACE('SW8 4EX', ' .*', '');

此处test_predictions的形状为(10000,)。

打印结果如下:

test_predictions = model.predict_classes(x_test)

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_true = y_test, y_pred = test_predictions)
print(cm)