我想使用CIFAR-10数据集,但我只想要青蛙,狗,猫,马和鸟类,到目前为止我使用了以下代码:
# Plot ad hoc CIFAR10 instances
from keras.datasets import cifar10
from matplotlib import pyplot
from scipy.misc import toimage
# load data
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
# create a grid of 3x3 images
for i in range(0, 9):
pyplot.subplot(330 + 1 + i)
pyplot.imshow(toimage(X_train[i]))
# show the plot
pyplot.show()
cifar10.load_data()函数加载整个数据,我只能获得所需的类吗?
答案 0 :(得分:1)
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
包含10个类的所有示例
选择有关类别的索引
index = np.where(y_train == 0)
X_train = X_train[indices]
y_train = y_train[indices]
给出所有第0个索引样本
答案 1 :(得分:0)
cifar10.load_data()函数加载整个数据,我只能获得所需的类吗?
使用load_data()
提供的keras.datasets.cifar10
,您无法做到这一点。此外,检查source code上的其他实用程序似乎只提供了load_data()
方法。
但是,如果手动获取并加载数据集,可以执行此操作。为此,您可以尝试在CIFAR10数据集上模拟keras does it(以及之前的源代码)的方式。
基于this帮助页面(您也可以从中下载数据集)似乎蛙,狗,猫,马和鸟类对应于索引6,5,3,7和2,分别。这意味着您可以在提取数据元素时使用这些索引,以便您可以选择所需的索引。
修改:另一个可以更好地为您效果的选项是丢弃您不希望通过load_data()
来电的元素。根据Keras数据集page,我们看到该方法返回:
2元组:
- x_train,x_test:uint8具有形状的RGB图像数据数组(num_samples,3,32,32)。
- y_train,y_test:uint8类别标签数组( 0-9范围内的整数),带形状(num_samples,)。
了解这一点,您可以丢弃任何没有6,5,3,7,2标签的元素,这些元素对应于您想要的类。