无法打印模型的混淆矩阵

时间:2018-04-08 00:05:41

标签: python tensorflow machine-learning mnist confusion-matrix

我实施了MLP,效果很好。但是,我在尝试打印混淆矩阵时遇到了问题。

我的模型被定义为......

logits = layers(X, weights, biases)

WHERE ...

def layers(x, weights, biases):
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']

    return out_layer

我在mnist数据集上训练模型。经过培训,我能够成功打印出模型的准确性......

pred = tf.nn.softmax(logits)

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy: ", accuracy.eval({X:mnist.test.images, y:mnist.test.labels}))

准确度给了我90%。现在我想打印出结果的混淆矩阵。我试过以下......

confusion = tf.confusion_matrix(
         labels=mnist.test.labels, predictions=correct_prediction)

但这给了我错误......

  

ValueError:无法挤压暗淡[1],预期维度为1,得到10为' confusion_matrix / remove_squeezable_dimensions / Squeeze' (op:' Squeeze')输入形状:[10000,10]。

打印混淆矩阵的正确方法是什么?我已经挣扎了一段时间。

3 个答案:

答案 0 :(得分:1)

看起来tf.confusion_matrix的一个参数有一个10作为第二个暗淡。问题是mnist.test.labelscorrect_prediction是否为热门编码?这可以解释它。你需要那里的标签作为一个暗淡的张量。你能打印这两个张量的形状吗?

看起来correct_prediction是一个布尔张量,用于标记您的预测是否准确。对于混淆矩阵,您需要预测标签,而不是tf.argmax( pred, 1 )。同样,如果您的标签是单热编码的,您需要为混淆矩阵解码它们。请尝试confusion

这一行
confusion = tf.confusion_matrix(
     labels = tf.argmax( mnist.test.labels, 1 ),
     predictions = tf.argmax( pred, 1 ) )

为了打印混淆矩阵本身,有必要使用eval来得到最终结果:

print(confusion.eval({x:mnist.test.images, y:mnist.test.labels}))

答案 1 :(得分:1)

这对我有用:

confusion = tf.confusion_matrix(
       labels = tf.argmax( mnist.test.labels, 1 ),
       predictions = tf.argmax( y, 1 ) )
   print(confusion.eval({x:mnist.test.images, y_:mnist.test.labels})) 

[[ 960    0    2    2    1    5    7    2    1    0]
 [   0 1113    3    2    0    1    4    2   10    0]
 [   6    7  941   15   12    2   10    8   27    4]
 [   2    1   27  926    1   12    1    8   24    8]
 [   1    2    6    1  928    0    9    2    9   24]
 [   9    2    8   51   12  729   15    9   50    7]
 [  13    3   10    2    9    9  905    2    5    0]
 [   1    9   28    8   11    1    0  938    3   29]
 [   6   10    7   19    9   13    8    5  891    6]
 [   9    7    2    9   43    5    0   14   12  908]]

答案 2 :(得分:0)

对于NLTK混淆矩阵,您需要一个列表

classifier = NaiveBayesClassifier.train(trainfeats)
refsets = collections.defaultdict(set)
testsets = collections.defaultdict(set)

lsum = []
tsum = []

for i, (feats, label) in enumerate(testfeats):
  refsets[label].add(i)
  observed = classifier.classify(feats)
  testsets[observed].add(i)
  lsum.append(label)
  tsum.append(observed

print (nltk.ConfusionMatrix(lsum,tsum))