如何根据特定的类名加载CIFAR-10数据集?

时间:2016-11-18 01:29:24

标签: python machine-learning tensorflow deep-learning keras

我正在使用CIFAR-10数据集进行深度学习,但我想仅为水果类指定我的数据集。我们知道我们使用过:

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

加载所有CIFAR-10数据集。如何仅为fruit类而不是所有数据加载数据?

1 个答案:

答案 0 :(得分:1)

如果您不介意加载其他数据,最简单的方法是找出女性是水果标签,并执行以下操作:X_train, y_train = X_train[y_train == fruit_label], y_train[y_train == fruit_label],前提是您的数据存储在np.arrays中。相当于您的测试集。

如果没有,则必须修改hdf5文件或存储数据的任何位置。