在TensorFlow的一般培训渠道中,如何打印(可能通过TensorBoard)错误分类的示例/图像?
答案 0 :(得分:0)
假设您使用的是softmax分类器,它在N个类之间进行选择,作为网络的最后一层。伪代码可能如下所示,其中最后一层的批量大小为其第一维:
# computation graph
predictions = argmax(softmax(final_layer))
matches = predictions == argmax(labels) # if one-hot encoded
# later
batch_matches = sess.run(matches, feed_dict={...})
for image, does_match in zip(batch_images, batch_matches):
if not does_match:
cv2.imwrite('mismatched.png', image)