来自keras.datasets的替代方法导入mnist

时间:2016-11-19 07:10:38

标签: keras

我一直在尝试一个需要导入MNIST数据的Keras示例

@if(previous-route == 'contact')
  some text
@else
  other text
@endif

它会生成错误消息,例如from keras.datasets import mnist import numpy as np (x_train, _), (x_test, _) = mnist.load_data()

它应该与我正在使用的网络环境有关。我的问题是,是否有任何函数或代码可以让我直接导入已手动下载的mnist数据集。感谢。

这是修改后的方法

Exception: URL fetch failure on https://s3.amazonaws.com/img-datasets/mnist.pkl.gz: None -- [Errno 110] Connection timed out

然后我收到以下错误消息

import sys
import pickle
import gzip
f = gzip.open('/data/mnist.pkl.gz', 'rb')
  if sys.version_info < (3,):
    data = pickle.load(f)
else:
    data = pickle.load(f, encoding='bytes')
f.close()
import numpy as np
(x_train, _), (x_test, _) = data

5 个答案:

答案 0 :(得分:4)

好吧,keras.datasets.mnist文件is really short。您可以手动模拟相同的操作,即:

  1. https://s3.amazonaws.com/img-datasets/mnist.pkl.gz
  2. 下载数据集
  3. import gzip
    f = gzip.open('mnist.pkl.gz', 'rb')
    if sys.version_info < (3,):
        data = cPickle.load(f)
    else:
        data = cPickle.load(f, encoding='bytes')
    f.close()
    (x_train, _), (x_test, _) = data
    

答案 1 :(得分:2)

您不需要其他代码,但可以告诉load_data首先加载本地版本:

  1. 您可以从另一台具有正确(代理)访问权限的计算机上下载文件https://s3.amazonaws.com/img-datasets/mnist.npz(来自https://github.com/keras-team/keras/blob/master/keras/datasets/mnist.py
  2. 将其复制到目录~/.keras/datasets/(在Linux和macOS上)
  3. 并使用正确的文件名运行load_data(path='mnist.npz')

答案 2 :(得分:1)

Keras文件位于Google Cloud Storage中的新路径中(在AWS S3中之前):

https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

使用时:

tf.keras.datasets.mnist.load_data()

您可以传递一个path参数。

load_data()将调用以参数get_file()为参数的fname,如果path是完整路径并且文件存在,则将不会下载该文件。

示例:

# gsutil cp gs://tensorflow/tf-keras-datasets/mnist.npz /tmp/data/mnist.npz
# python3
>>> import tensorflow as tf
>>> path = '/tmp/data/mnist.npz'
>>> (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(path)
>>> len(train_images)
>>> 60000

答案 3 :(得分:0)

  1. 下载文件 https://s3.amazonaws.com/img-datasets/mnist.npz
  2. mnist.npz移至.keras/datasets/目录
  3. 加载数据

    import keras
    from keras.datasets import mnist
    
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    

答案 4 :(得分:0)

即使指定了本地文件路径,

keras.datasets.mnist.load_data() 也会尝试从远程存储库中获取。但是,加载下载文件的最简单解决方法是使用 numpy.load(), just like they do:

path = '/tmp/data/mnist.npz'

import numpy as np

with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']