TensorFlow:导入gzip mnist数据集

时间:2018-07-16 10:15:25

标签: python tensorflow

我目前正在学习ML,并逐步完成了 link

一切都很好,但是mnist-import会发出警告,并指出已弃用了所使用的方法。

我不知道是否应该“更新”这个,这将是我的第一个问题,但我也想稍后再导入另一个数据集(它将再次是train-images.gz等)

因此,我需要的是一种从文件夹读取.gz-Datasets并将其导入的方法。我已经读过tf.data.Dataset了,但是我想我并没有真正理解它,或者这不是我所需要的。

2 个答案:

答案 0 :(得分:0)

根据https://www.tensorflow.org/tutorials上适用于TFv1.9的文档,使用Tensorflow并避免不建议使用的警告来导入MNIST(作为numpy数组)的典型方法是:

 mnist = tf.keras.datasets.mnist
 (x_train, y_train),(x_test, y_test) = mnist.load_data()
 x_train, x_test = x_train / 255.0, x_test / 255.0

因此,从现在开始,您应该避免以下情况:

  • mnist = tf.contrib.learn.datasets.load_dataset("mnist")
  • from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

无论如何,如果tf.data不是一个选项,则可能需要调整以下功能:

def extract_images(f):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth].
Args:
   f: A file object that can be passed into a gzip reader.
Returns:
   data: A 4D uint8 numpy array [index, y, x, depth].
Raises:
   ValueError: If the bytestream does not start with 2051.
"""
print('Extracting', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
   magic = _read32(bytestream)
   num_images = _read32(bytestream)
   rows = _read32(bytestream)
   cols = _read32(bytestream)
   buf = bytestream.read(rows * cols * num_images)
   data = numpy.frombuffer(buf, dtype=numpy.uint8)
   data = data.reshape(num_images, rows, cols, 1)
return data

答案 1 :(得分:0)

我明白你的意思了,我实际上是在今天早上尝试过的(但是失败了),但是我又尝试了:这似乎行得通(这真是丑陋,需要解决方法;但是行得通)

def _read32(bytestream):
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(bytestream.read(4), dtype=dt)[0]

def extract_images(f):
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                           (magic, f.name))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = np.frombuffer(buf, dtype=np.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        assert data.shape[3] == 1
        data = data.reshape(data.shape[0],data.shape[1] * data.shape[2])
        data = data.astype(np.float32)
        data = np.multiply(data, 1.0 / 255.0)
        return data
def extract_labels(f):
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                           (magic, f.name))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = np.frombuffer(buf, dtype=np.uint8)
        return labels

with gfile.Open("MNIST_data/train-images-idx3-ubyte.gz", "rb") as f:
    train_images = extract_images(f)
with gfile.Open("MNIST_data/train-labels-idx1-ubyte.gz", "rb") as f:
    train_labels = extract_labels(f)
with gfile.Open("MNIST_data/t10k-images-idx3-ubyte.gz", "rb") as f:
    test_images = extract_images(f)
with gfile.Open("MNIST_data/t10k-labels-idx1-ubyte.gz", "rb") as f:   
    test_labels = extract_labels(f)