根据CIFAR10数据集类创建数据集时出错

时间:2019-10-17 19:35:23

标签: python machine-learning keras neural-network

我需要创建2个数据集,其中一个数据集具有CIFAR10数据集,类别为0到4,其他数据集具有5到9的类别,但是我遇到了这个错误:"boolean index did not match indexed array along dimension 1; dimension is 32 but corresponding boolean dimension is 1"

这是我到目前为止尝试过的

  import keras
  from keras.datasets import cifar10
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  print('x_train shape:', x_train.shape)
  x_train shape: (50000, 32, 32, 3)
  Getting error at this point
  x_train = x_train[y_train < 5]

1 个答案:

答案 0 :(得分:0)

打印出y_train.shape给出(50000,1)。要使用y_train正确索引x_train的第一个维度,必须去除第二个维度。

x_train = x_train[y_train[:, 0] < 5]

[:, 0]表示返回第一个维度上的所有元素,但仅返回第二个维度上的第一个元素。

x_train.shape现在给出(25000,32,32,3)。