
时间:2019-12-07 07:46:00

标签: python pytorch mnist


# Build the dataset wrapper with the same arguments as the original script.
dataset = MNIST(path=data_path, download=True, shuffle=True)

# Select exactly one split. Without the `else`, the test split would always
# overwrite whatever the `train` branch had just loaded.
if train:
    images, labels = dataset.get_train()
else:
    images, labels = dataset.get_test()

# Keep only the first n_examples samples of the chosen split.
images, labels = images[:n_examples], labels[:n_examples]
# Reshape the images to rows of 784 values, divide by 255, and expose both
# sequences as iterators.
images, labels = iter(images.view(-1, 784) / 255), iter(labels)


Traceback (most recent call last):
File "C:\Users\Ati\Downloads\Compressed\bindsnet_experiments- 
master\experiments\mnist\two_layer_backprop.py", line 135, in <module>
images, labels = dataset.get_train()
AttributeError: 'TorchvisionDatasetWrapper' object has no attribute 'get_train'

我认为这是因为 get_train() 已经过时，torchvision 的数据集封装不再支持它。不过我也尝试过其他几种把 MNIST 数据转换为 images 和 labels 变量的方法。既然 get_train() 已经失效，有谁知道我应该如何修改这段代码？如果有人能提供帮助，我将不胜感激。

2 个答案:

答案 0 :(得分:1)



import array
import functools
import gzip
import operator
import os
import struct
import sys
import tempfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve  # py2
try:
    from urllib.parse import urljoin
except ImportError:
    from urlparse import urljoin

import numpy

__version__ = "0.2.2"

# Both settings below are plain module attributes, so users may rebind them:
# >>> mnist.datasets_url = 'http://my.mnist.url'
# >>> mnist.temporary_dir = lambda: '/tmp/mnist'
# `datasets_url` is the base URL that file names are joined onto, and
# `temporary_dir` is a zero-argument callable returning the default
# download directory.
datasets_url = "http://yann.lecun.com/exdb/mnist/"
temporary_dir = tempfile.gettempdir

class IdxDecodeError(ValueError):
    """Error signalling that an IDX file could not be decoded."""

    pass

def download_file(fname, target_dir=None, force=False):
    """Download fname from the datasets_url, and save it to target_dir,
    unless the file already exists, and force is False.

    Parameters
    ----------
    fname : str
        Name of the file to download

    target_dir : str
        Directory where to store the file. When empty or None, the
        directory returned by ``temporary_dir()`` is used.

    force : bool
        Force downloading the file, if it already exists

    Returns
    -------
    fname : str
        Full path of the downloaded file
    """
    target_dir = target_dir or temporary_dir()
    target_fname = os.path.join(target_dir, fname)

    # Reuse an existing copy unless the caller explicitly asks to refresh it.
    if force or not os.path.isfile(target_fname):
        url = urljoin(datasets_url, fname)
        urlretrieve(url, target_fname)

    return target_fname

def parse_idx(fd):
    """Parse an IDX file, and return it as a numpy array.

    Parameters
    ----------
    fd : file
        Binary file object of the IDX file to parse, positioned at its start

    Returns
    -------
    data : numpy.ndarray
        Numpy array with the dimensions and the data in the IDX file

    Raises
    ------
    IdxDecodeError
        If the 4-byte header is incomplete, does not start with two zero
        bytes, declares an unknown data type, or the number of items read
        does not match the declared dimensions.

    Notes
    -----
    The data type codes are mapped to format characters shared by the
    ``struct`` and ``array`` modules, see
    https://docs.python.org/3/library/struct.html
    """
    DATA_TYPES = {0x08: 'B',  # unsigned byte
                  0x09: 'b',  # signed byte
                  0x0b: 'h',  # short (2 bytes)
                  0x0c: 'i',  # int (4 bytes)
                  0x0d: 'f',  # float (4 bytes)
                  0x0e: 'd'}  # double (8 bytes)

    header = fd.read(4)
    if len(header) != 4:
        raise IdxDecodeError('Invalid IDX file, '
                             'file empty or does not contain a full header.')

    # Header layout: two zero bytes, one type-code byte, one byte holding the
    # number of dimensions.
    zeros, data_type, num_dimensions = struct.unpack('>HBB', header)

    if zeros != 0:
        raise IdxDecodeError('Invalid IDX file, '
                             'file must start with two zero bytes. '
                             'Found 0x%02x' % zeros)

    try:
        data_type = DATA_TYPES[data_type]
    except KeyError:
        raise IdxDecodeError('Unknown data type '
                             '0x%02x in IDX file' % data_type)

    # One big-endian unsigned 32-bit size per dimension.
    dimension_sizes = struct.unpack('>' + 'I' * num_dimensions,
                                    fd.read(4 * num_dimensions))

    data = array.array(data_type, fd.read())
    # array.array interprets the bytes in the host's native order, while the
    # payload is big-endian like the header; swap only on little-endian hosts
    # so big-endian hosts do not corrupt multi-byte items.
    if sys.byteorder == 'little':
        data.byteswap()

    # The initial value 1 keeps the product defined when there are no
    # dimensions (a single scalar item).
    expected_items = functools.reduce(operator.mul, dimension_sizes, 1)
    if len(data) != expected_items:
        raise IdxDecodeError('IDX file has wrong number of items. '
                             'Expected: %d. Found: %d' % (expected_items,
                                                          len(data)))

    return numpy.array(data).reshape(dimension_sizes)

def download_and_parse_mnist_file(fname, target_dir=None, force=False):
    """Download the IDX file named fname from the URL specified in dataset_url
    and return it as a numpy array.

    Parameters
    ----------
    fname : str
        File name to download and parse

    target_dir : str
        Directory where to store the file

    force : bool
        Force downloading the file, if it already exists

    Returns
    -------
    data : numpy.ndarray
        Numpy array with the dimensions and the data in the IDX file
    """
    fname = download_file(fname, target_dir=target_dir, force=force)
    # Files with a '.gz' extension are decompressed on the fly; anything else
    # is read as a plain binary file.
    fopen = gzip.open if os.path.splitext(fname)[1] == '.gz' else open
    with fopen(fname, 'rb') as fd:
        return parse_idx(fd)

def train_images():
    """Return train images from Yann LeCun MNIST database as a numpy array.
    Download the file, if not already found in the temporary directory of
    the system.

    Returns
    -------
    train_images : numpy.ndarray
        Numpy array with the images in the train MNIST database. The first
        dimension indexes each sample, while the other two index rows and
        columns of the image
    """
    return download_and_parse_mnist_file('train-images-idx3-ubyte.gz')

def test_images():
    """Return test images from Yann LeCun MNIST database as a numpy array.
    Download the file, if not already found in the temporary directory of
    the system.

    Returns
    -------
    test_images : numpy.ndarray
        Numpy array with the images in the test MNIST database. The first
        dimension indexes each sample, while the other two index rows and
        columns of the image
    """
    return download_and_parse_mnist_file('t10k-images-idx3-ubyte.gz')

def train_labels():
    """Return train labels from Yann LeCun MNIST database as a numpy array.
    Download the file, if not already found in the temporary directory of
    the system.

    Returns
    -------
    train_labels : numpy.ndarray
        Numpy array with the labels 0 to 9 in the train MNIST database.
    """
    return download_and_parse_mnist_file('train-labels-idx1-ubyte.gz')

def test_labels():
    """Return test labels from Yann LeCun MNIST database as a numpy array.
    Download the file, if not already found in the temporary directory of
    the system.

    Returns
    -------
    test_labels : numpy.ndarray
        Numpy array with the labels 0 to 9 in the test MNIST database.
    """
    return download_and_parse_mnist_file('t10k-labels-idx1-ubyte.gz')



答案 1 :(得分:0)


from mnist import MNIST

# Point the loader at the directory that holds the MNIST data files.
dataset = MNIST("./dir_with_mnist_data_files")
# load_training() is unpacked into the training images and their labels.
images, labels = dataset.load_training()

