如何仅从Keras提供的MNIST数据集中选择特定数字?

时间:2018-07-06 02:20:51

标签: python filter keras deep-learning mnist

我目前正在使用Keras在MNIST数据集上训练前馈神经网络。我正在使用以下格式加载数据集

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

但是我只想使用数字0和4训练我的模型,而不是全部。如何只选择2位数字?我是python的新手,可以弄清楚如何过滤mnist数据集...

4 个答案:

答案 0 :(得分:4)

Y_trainY_test为您提供图像的标签,您可以将它们与numpy.where一起使用,以过滤出带有0和4的部分标签。您所有的变量都是numpy数组,因此您可以轻松完成;

import numpy as np

train_filter = np.where((Y_train == 0 ) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

,您可以使用这些过滤器按索引获取数组的子集。

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]

如果您对两个以上的标签感兴趣,则语法在使用where和or时会比较麻烦。因此,您也可以使用numpy.isin创建遮罩。

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

您可以像以前一样使用这些掩码进行布尔索引。

答案 1 :(得分:1)

当数字不连续且以0开头(keras期望连续的标签范围从0开始)时,使用Y_train = Y_train[train_mask]会引发InvalidArgumentError

解决方案(两位数)为:

train_mask = np.isin(Y_train, [2,8])
test_mask = np.isin(Y_test, [2,8])

X_train, Y_train = X_train[train_mask], np.array(Y_train[train_mask] == 8)
X_test, Y_test = X_test[test_mask], np.array(Y_test[test_mask] == 8)

答案 2 :(得分:0)

您拥有标签文件以及培训和测试:

train_images = mnist.train_images()
train_labels = mnist.train_labels()

test_images = mnist.test_images()
test_labels = mnist.test_labels()

您可以将它们与简单的列表理解结合使用来过滤数据集

zero_four_test = [test_images[key] for (key, label) in enumerate(test_labels) if int(label) == 0 or int(label) == 4]

答案 3 :(得分:0)

如果我使用Tensorflow怎么实现呢?

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./", one_hot=True)