以Sklearn友好格式加载MNIST数据集

时间:2017-06-27 14:47:38

标签: numpy scikit-learn mnist

我使用以下命令加载了MNIST数据集:

from dataget import data

dataset = data("mnist").get()

如何将其转换为Sklearn友好格式,即features_train,labels_train,features_test,labels_test?

我试过" np.loadtxt"但得到了这个错误:

ValueError: could not convert string to float: data

我还尝试了以下几行代码:

df = next(dataset.training_set.random_batch_dataframe_generator(10))

df

它已经返回了这个错误:

AttributeError: training_set

请,有人可以帮助我,我一直在谷歌搜索替代方法,但我仍然收到错误。谢谢!

P.S。这是我用来获取MNIST数据集的另一种方式:

dataset = fetch_mldata('MNIST original')

1 个答案:

答案 0 :(得分:0)

@ E.Z。帮我解决了问题!

features, labels = dataset.data, dataset.target

然后我使用以下代码行将它们分成训练和测试集:

msk = np.random.rand(len(features)) < 0.8
mrk = np.random.rand(len(labels)) < 0.8

features_train = features[msk]
features_test = features[~msk]
labels_train = labels[mrk]
labels_test = labels[~mrk]