我目前正在研究tensorflow的Cifar-10教程。我想改变评估,以便我可以看到每个图像我的模型的预测是什么,以及它是真是假。我与第一部分斗争:如果我打印预测(sess.run([top_k_op]))我得到真/假值,我假设是预测是否正确。但是,如果我尝试打印实际预测(我到目前为止尝试打印logits,并打印top_k_op张量),我会得到一些数字或值,但看起来不像标签。我有什么需要改变我的代码才能真正看到模型预测的标签?
答案 0 :(得分:0)
您想先评估logits
。这是您的网络中的类的概率分布。具有较高值的张量索引将为您提供最可能的标签类。
您可以使用tf.argmax获取索引,然后使用标签中的索引将其打印出来
print labels[index]
答案 1 :(得分:0)
您可以通过查看here
找出答案在svhn.py中,在第116行打印预测标签:print (step, int(test_labels[0]))
我使用以下方式清楚地完成了这项工作:
classification = sess.run(top_k_predict_op)
print (step, int(test_labels[0]))
print "network predicted:", classification[0], "for real label:", test_labels
如果您使用原始版本的TensorFlow CIFAR-10型号训练模型,请确保预测24 * 24张图像。