仅下载某些MNIST数字

时间:2018-12-30 02:32:01

标签: python pickle mnist

我正在尝试仅下载项目的手写数字的MNIST数据库的一部分。具体来说,我只希望将数字0、1、2和3发送到神经网络。

我当前正在加载这样的数据(基于"Neural Networks and Deep Learning" by Michal Daniel Dobrzanski

import cPickle
import gzip
import numpy as np

def load_data():
    f = gzip.open('src/mnist.pkl.gz', 'rb')
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()
    return (training_data, validation_data, test_data)

def load_data_wrapper():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)

我尝试构造一个从load_data()创建新数据集的函数,然后发送给load_data_wrapper()(通过将tr_d, va_d, te_d = load_data()中的tr_d, va_d, te_d = digitTest()更改为load_data_wrapper()),没有运气,见下文:

def digitTest():
    tr_d, va_d, te_d = load_data()
    tr_d = list(tr_d)
    va_d = list(va_d)
    te_d = list(te_d)

    newTrD = []
    newTrD.append([])
    newTrD.append([])

    newVaD = []
    newVaD.append([])
    newVaD.append([])

    newTeD = []
    newTeD.append([])
    newTeD.append([])

    for index,label in enumerate(tr_d[1]):
        if tr_d[1][index] < 4:
            newTrD[0].append(tr_d[0][index])
            newTrD[1].append(tr_d[1][index])

    for index,label in enumerate(va_d[1]):
        if va_d[1][index] < 4:
            newVaD[0].append(va_d[0][index])
            newVaD[1].append(va_d[1][index])

    for index,label in enumerate(te_d[1]):
        if te_d[1][index] < 4:
            newTeD[0].append(te_d[0][index])
            newTeD[1].append(te_d[1][index])

    return (newTrD, newVaD, newTeD)

是否有可能实现我想要的目标?我怎样才能做到这一点?请注意,从load_data函数解析后,数据存储在元组中。

1 个答案:

答案 0 :(得分:1)

我从未使用过cPickle加载mnist数据集,我也不知道它返回什么。 阅读您的代码似乎看起来您做对了,但是如果您说它不起作用,我认为cPickle返回数据的方式或方式有些问题。

我没有python 2,所以我无法调试您的代码,但是:

我倾向于自己做这些事情:

df = pd.read_json(r'C:\path\data.json')
df.friends.apply(pd.Series)

    name            timestamp
0   Joe Jimmy       1541547573
1   Steven Peterson 1541274647

此功能将从文件中加载一组mnist标签和值。您可以在http://yann.lecun.com/exdb/mnist/获取数据集,并且必须解压缩文件。 标签为“ train-labels.idx1-ubyte”。只需将训练标签和图像或测试标签和图像的路径传递到函数中,它将加载这些值。

返回值是两个列表的元组:

def loadSet(values_path, labels_path):
    labels = []
    # labels:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0008     32 bit integer  28               number of labels
    # 0009     unsigned byte   ??               label
    # 0010     unsigned byte   ??               label
    # ....     unsigned byte   ??               label

    with open(labels_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4,), 'big')
        num_labels = int.from_bytes(f.read(4), 'big')
        for i in range(num_labels):
            labels.append(int.from_bytes(f.read(1), 'big'))

    images = []
    # images:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0004     32 bit integer  60000            number of images
    # 0008     32 bit integer  28               number of rows
    # 0012     32 bit integer  28               number of columns
    # 0016     unsigned byte   ??               pixel
    # 0020     unsigned byte   ??               pixel
    # ....     unsigned byte   ??               pixel

    with open(values_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4), 'big')
        num_images = int.from_bytes(f.read(4), 'big')
        num_rows = int.from_bytes(f.read(4), 'big')
        num_cols = int.from_bytes(f.read(4), 'big')
        for i in range(num_images):
            image = []
            for x in range(num_rows * num_cols):
                image.append(int.from_bytes(f.read(1), 'big'))
            images.append(image)

其中像素是列表本身。

如果文件不存在或者(如果可能)文件格式不正确,那么除了引发异常之外,它不会进行任何错误检查。

我也不习惯numpy,我通常在c ++和java中工作,但是您确定可以很容易地将这些值转换为numpy数组-请仔细阅读该主题。

过滤这些内容现在非常容易,您应该现在就可以将其用于digitTest。

您可能会看到,如果您使用原始的mnist数据集,则只会得到训练图像和测试图像。这里发生的是,您拿了其中一组的一部分并将其用作-我不太确定您在这里的用词-测试数据以评估培训进度。培训结束后,您可以使用“ t10k”文件来验证网络培训的水平。这里重要的是,如果您从这些t10k图像中分割测试数据,则不再使用这些数据,仅剩下的部分,因为其想法是验证尚未在网络上看到的数据的训练。