从Cifar-10模型获得标签预测

时间:2016-03-05 14:15:17

标签: machine-learning tensorflow

我目前正在研究tensorflow的Cifar-10教程。我想改变评估,以便我可以看到每个图像我的模型的预测是什么,以及它是真是假。我与第一部分斗争:如果我打印预测(sess.run([top_k_op]))我得到真/假值,我假设是预测是否正确。但是,如果我尝试打印实际预测(我到目前为止尝试打印logits,并打印top_k_op张量),我会得到一些数字或值,但看起来不像标签。我有什么需要改变我的代码才能真正看到模型预测的标签?

2 个答案:

答案 0 :(得分:0)

您想先评估logits。这是您的网络中的类的概率分布。具有较高值的​​张量索引将为您提供最可能的标签类。

您可以使用tf.argmax获取索引,然后使用标签中的索引将其打印出来

print labels[index]

答案 1 :(得分:0)

您可以通过查看here

找出答案

在svhn.py中,在第116行打印预测标签:print (step, int(test_labels[0]))

我使用以下方式清楚地完成了这项工作:

classification = sess.run(top_k_predict_op)
print (step, int(test_labels[0]))
print "network predicted:", classification[0], "for real label:", test_labels

如果您使用原始版本的TensorFlow CIFAR-10型号训练模型,请确保预测24 * 24张图像。