如何在指定one_hot = True后从MNIST示例中获取整数标签?

时间:2016-11-27 03:12:23

标签: python tensorflow mnist one-hot-encoding

我一直在尝试this tutorial on Youtube1m31s处的.cls和.labels的探索),这只是一个简单的MNIST分类器模型。但由于Tensorflow中显然缺少功能,我无法完成它。

>>>from tensorflow.examples.tutorials.mnist import input_data
>>>data = input_data.read_data_sets("data/MNIST", one_hot=True)

>>>one_hot_labels = data.test.labels #mat shape=(num_images X num_classes)
>>>cls_labels = data.test.cls #mat shape=(num_images X 1)
Traceback (most recent call last):
  File "/home/file.py", line 5, in <module>
    cls_labels = data.test.cls
AttributeError: 'DataSet' object has no attribute 'cls'

在Google上搜索&#34; .cls&#34;在TF中引用,我无法找到任何与之相关的信息。

使事情有效的一个肮脏的例子:

>>>data = input_data.read_data_sets("data/MNIST", one_hot=True)
>>>data2 = input_data.read_data_sets("data/MNIST")

>>>one_hot_labels = data.test.labels #mat shape=(num_images X num_classes)
>>>cls_labels = data2.test.labels #mat shape=(num_images X 1)

我在Linux上使用Tensorflow 0.10.0,我想知道.cls选项是否已被删除?

如果是这样,是否有一种替代方法可以从one_hot向量数组中编码分类器名称数组?

由于

2 个答案:

答案 0 :(得分:1)

您的标签属于这种类型的阵列(一个很热),例如:

array([[ 0.,  0.,  0., ...,  1.,  0.,  0.],
       [ 0.,  0.,  1., ...,  0.,  0.,  0.],
       [ 0.,  1.,  0., ...,  0.,  0.,  0.],

数字1位于数组的位置,用于指示哪个标签。

要从此数据中获取整数标签,您必须使用以下内容获取索引:

data.test.cls = np.argmax(data.test.labels, axis=1)

答案 1 :(得分:0)

目前,我们对图像数据使用属性images,为类(标签)使用labels。例如,

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("data/MNIST", one_hot=True)
# data
images = mnist.test.images
# label
labels = mnist.test.labels

# without one-hot
mnist = input_data.read_data_sets("data/MNIST", one_hot=False)
# original data
images = mnist.test.images.reshape([-1, 28, 28])
print(images.shape)
# label
labels = mnist.test.labels
print(labels)