使用TensorFlow对手写数字进行分类而不使用验证标签

时间:2015-11-30 07:40:01

标签: python tensorflow

我一直在阅读TensorFlow教程并总体阅读机器学习。

我的理解是,使用神经网络的主要好处之一是它们能够在训练后快速对所呈现的输入进行分类。

首先,我开始逐步完成示例代码并查看训练数据的结构,并且我能够成功使用基本示例(91%准确度)来识别我使用的创建的图像(仅限数字)以下代码段:

# Training is already done using the code from the tutorial 
# Do the same for five
...
test_five_image = np.zeros((28, 28), dtype=np.uint8, order='C')
for five_coords in npg.five_coordinates:
    i = int(five_coords[0] / 28)
    j = int((five_coords[1] / 28) + 3) # By Eye Centering
    test_five_image[i][j] = 0xFF
test_five_image = np.rot90(test_five_image, 1)
Image.fromarray(np.uint8(test_five_image)).save(str(5) + '.bmp') 
...
# Images are Four, Five, 0 and 6
test_labels = input_data.dense_to_one_hot(np.array([4, 5, 0, 6], np.int32))
dataset = input_data.DataSet(test_images, test_labels)
print sess.run(accuracy, feed_dict={x: dataset.images, y_: dataset.labels})

根据以上代码生成的图像示例: Bitmap extracted from test data used.

注意: 此图像是根据点列表构建的,然后按比例缩小以适合28 * 28阵列。由于图像直接从numpy数组转换为位图,因此颜色会反转。根据MNIST文件格式,列表中的每个点都设置为0xFF,其中0为白色,255为黑色

上面的代码段输出1.0(有时0.75,具体取决于培训的准确性),并正确地将输入分类为标签。

所以我的问题是使用TensorFlow构建的神经网络简单地对输入进行分类所涉及的必要步骤是什么,例如,如果输入是'7',输出将是例如:

>>> [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]

我查看了TensorFlow文档,但我无法提出解决方案。我怀疑在教程中可能会遗漏一些东西。

由于

1 个答案:

答案 0 :(得分:2)

假设您已按照MNIST for ML Beginners Tutorial进行简单预测,请添加argmax节点,如下所示:

prediction = tf.argmax(y, 1)

然后运行它,输入您想要分类的数据:

prediction_val = sess.run(prediction, feed_dict={x: dataset.images})

prediction_val的形状为(batch_size,),并且包含批次中每张图片的最可能标签。

请注意,feed_dict仅包含x而非y_,因为prediction不依赖于y_