仅选择特定类别的CIFAR-10

时间:2018-01-11 18:27:37

标签: keras

我想使用CIFAR-10数据集,但我只想要青蛙,狗,猫,马和鸟类,到目前为止我使用了以下代码:

  # Plot ad hoc CIFAR10 instances
  from keras.datasets import cifar10
  from matplotlib import pyplot
  from scipy.misc import toimage
  # load data
  (X_train, y_train), (X_test, y_test) = cifar10.load_data()
  # create a grid of 3x3 images
  for i in range(0, 9):
      pyplot.subplot(330 + 1 + i)
      pyplot.imshow(toimage(X_train[i]))
  # show the plot
  pyplot.show()

cifar10.load_data()函数加载整个数据,我只能获得所需的类吗?

2 个答案:

答案 0 :(得分:1)

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

包含10个类的所有示例

选择有关类别的索引

index = np.where(y_train == 0)
X_train = X_train[indices]
y_train = y_train[indices]

给出所有第0个索引样本

答案 1 :(得分:0)

  

cifar10.load_data()函数加载整个数据,我只能获得所需的类吗?

使用load_data()提供的keras.datasets.cifar10,您无法做到这一点。此外,检查source code上的其他实用程序似乎只提供了load_data()方法。

但是,如果手动获取并加载数据集可以执行此操作。为此,您可以尝试在CIFAR10数据集上模拟keras does it(以及之前的源代码)的方式。

基于this帮助页面(您也可以从中下载数据集)似乎蛙,狗,猫,马和鸟类对应于索引6,5,3,7和2,分别。这意味着您可以在提取数据元素时使用这些索引,以便您可以选择所需的索引。

修改:另一个可以更好地为您效果的选项是丢弃您不希望通过load_data()来电的元素。根据Keras数据集page,我们看到该方法返回:

  
      
  • 2元组:

         
        
    • x_train,x_test:uint8具有形状的RGB图像数据数组(num_samples,3,32,32)。
    •   
    • y_train,y_test:uint8类别标签数组( 0-9范围内的整数),带形状(num_samples,)。
    •   
  •   

了解这一点,您可以丢弃任何没有6,5,3,7,2标签的元素,这些元素对应于您想要的类。