如何显示错误分类的图像?

时间:2016-12-14 05:15:02

标签: tensorflow

在TensorFlow的一般培训渠道中,如何打印(可能通过TensorBoard)错误分类的示例/图像?

1 个答案:

答案 0 :(得分:0)

假设您使用的是softmax分类器,它在N个类之间进行选择,作为网络的最后一层。伪代码可能如下所示,其中最后一层的批量大小为其第一维:

# computation graph
predictions = argmax(softmax(final_layer))
matches = predictions == argmax(labels) # if one-hot encoded

# later
batch_matches = sess.run(matches, feed_dict={...})

for image, does_match in zip(batch_images, batch_matches):
  if not does_match:
    cv2.imwrite('mismatched.png', image)