Importing MNIST with TensorFlow's input_data gives a NoneType object (AttributeError)

Date: 2016-02-09 06:06:32

Tags: tensorflow attributeerror mnist

I can't import the MNIST dataset correctly. Can you help me figure out what's wrong? input_data.py is in the right place and is being called.

>>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

>>> trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'NoneType' object has no attribute 'train'

>>> print(mnist)
None

1 Answer

Answer (score: 0)

The read_data_sets method first downloads the data from http://yann.lecun.com/exdb/mnist/ and then extracts it. For the training images, for example:

local_file = maybe_download(TRAIN_IMAGES,train_dir)
train_images = extract_images(local_file)
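To make the extraction step less opaque, here is a minimal stdlib-only sketch of what decoding a label file involves. It assumes the IDX layout described on the MNIST page: a big-endian header holding the magic number 2049 and an item count, followed by one unsigned byte per label. The helpers `parse_idx_labels` and `to_one_hot` are hypothetical illustrations, not functions from input_data.py, and they use plain lists where the real code uses NumPy arrays.

```python
import gzip
import io
import struct

LABEL_MAGIC = 2049  # 0x00000801, magic number of an IDX label file


def parse_idx_labels(gz_bytes):
    """Decode a gzipped IDX label file into a list of ints (hypothetical helper)."""
    with gzip.GzipFile(fileobj=io.BytesIO(gz_bytes)) as f:
        magic, count = struct.unpack(">II", f.read(8))  # big-endian header
        if magic != LABEL_MAGIC:
            raise ValueError("Invalid magic number %d in label file" % magic)
        labels = list(f.read(count))  # one unsigned byte per label
    if len(labels) != count:
        raise ValueError("Label file is truncated")
    return labels


def to_one_hot(labels, num_classes=10):
    """Turn 3 into [0, 0, 0, 1, 0, ...]; same idea as one_hot=True."""
    return [[1 if i == label else 0 for i in range(num_classes)] for label in labels]


# Build a tiny fake label file in memory instead of downloading the real one.
fake = gzip.compress(struct.pack(">II", LABEL_MAGIC, 3) + bytes([5, 0, 4]))
labels = parse_idx_labels(fake)
print(labels)                 # [5, 0, 4]
print(to_one_hot(labels)[0])  # [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
```

Image files follow the same pattern with magic number 2051 and two extra header fields (rows and columns) before the pixel bytes.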

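The download and extraction steps work for you, as the "Extracting ..." lines in your output show. The problem comes afterwards: the object that the method returns, which should hold the collection of DataSet objects, is None. Since the code works fine for me, I can't reproduce the error, but you can run the calls inside the method yourself, one at a time, and report where it fails. Like this: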

>>> local_file = input_data.maybe_download('train-labels-idx1-ubyte.gz', 'MNIST_data/')
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
>>> train_labels = input_data.extract_labels(local_file, one_hot=True)
Extracting MNIST_data/train-labels-idx1-ubyte.gz
>>> local_file = input_data.maybe_download('train-images-idx3-ubyte.gz', 'MNIST_data/')
>>> train_images = input_data.extract_images(local_file)
Extracting MNIST_data/train-images-idx3-ubyte.gz
>>> local_file = input_data.maybe_download('t10k-images-idx3-ubyte.gz', 'MNIST_data/')
>>> test_images = input_data.extract_images(local_file)
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
>>> local_file = input_data.maybe_download('t10k-labels-idx1-ubyte.gz', 'MNIST_data/')
>>> test_labels = input_data.extract_labels(local_file,one_hot=True)
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
>>> VALIDATION_SIZE = 5000
>>> validation_images = train_images[:VALIDATION_SIZE]
>>> validation_labels = train_labels[:VALIDATION_SIZE]
>>> train_images = train_images[VALIDATION_SIZE:]
>>> train_labels = train_labels[VALIDATION_SIZE:]

>>> dtype = 'float32'
>>> data_set_train = input_data.DataSet(train_images, train_labels, dtype=dtype)
>>> data_set_validation = input_data.DataSet(validation_images, validation_labels, dtype=dtype)
>>> data_set_test = input_data.DataSet(test_images, test_labels, dtype=dtype)  
>>> trX = data_set_train.images
>>> print(data_set_train)
<tensorflow.examples.tutorials.mnist.input_data.DataSet object at 0x10508ff98>
>>> print(trX)
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
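The last thing read_data_sets has to do is bundle the three DataSet objects into one object whose train, validation and test attributes the question's code reads. The sketch below is plain Python with a hypothetical `DataSet` class and `Datasets` namedtuple standing in for the real input_data.py types, and lists of dummy values instead of NumPy arrays. It also shows one possible way to end up with None: if the final return statement is missing (for example after an accidental edit or broken indentation in a copied input_data.py), Python returns None implicitly and accessing .train raises exactly the AttributeError from the question. That is an assumption worth checking, not a confirmed diagnosis.

```python
import collections

VALIDATION_SIZE = 5000

# Hypothetical stand-ins for input_data.DataSet and the returned bundle.
Datasets = collections.namedtuple("Datasets", ["train", "validation", "test"])


class DataSet:
    def __init__(self, images, labels):
        assert len(images) == len(labels), "images and labels must line up"
        self.images = images
        self.labels = labels


def build_data_sets(train_images, train_labels, test_images, test_labels):
    # Hold out the first VALIDATION_SIZE training examples for validation.
    validation = DataSet(train_images[:VALIDATION_SIZE], train_labels[:VALIDATION_SIZE])
    train = DataSet(train_images[VALIDATION_SIZE:], train_labels[VALIDATION_SIZE:])
    test = DataSet(test_images, test_labels)
    return Datasets(train=train, validation=validation, test=test)


def broken_build_data_sets(train_images, train_labels, test_images, test_labels):
    # Same kind of work, but the return statement is missing, so the call yields None.
    train = DataSet(train_images[VALIDATION_SIZE:], train_labels[VALIDATION_SIZE:])


# Dummy data with MNIST's sizes: 60000 training and 10000 test examples.
tr_x, tr_y = [0] * 60000, [0] * 60000
te_x, te_y = [0] * 10000, [0] * 10000

mnist = build_data_sets(tr_x, tr_y, te_x, te_y)
print(len(mnist.train.images), len(mnist.validation.images), len(mnist.test.images))
# 55000 5000 10000

mnist = broken_build_data_sets(tr_x, tr_y, te_x, te_y)
print(mnist)  # None
try:
    mnist.train
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'train'
```

If the step-by-step session above works but read_data_sets still returns None, comparing the end of read_data_sets in your copy of input_data.py against this pattern is a quick check.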