我正在使用CIFAR-10数据集进行深度学习,但我想仅为水果类指定我的数据集。我们知道我们使用过:
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
加载所有CIFAR-10数据集。如何仅为fruit类而不是所有数据加载数据?
答案 0 :(得分:1)
如果您不介意加载其他数据,最简单的方法是找出女性是水果标签,并执行以下操作:X_train, y_train = X_train[y_train == fruit_label], y_train[y_train == fruit_label]
,前提是您的数据存储在np.arrays中。相当于您的测试集。
如果没有,则必须修改hdf5文件或存储数据的任何位置。