如何解决此错误,TypeError:take():参数'index'(位置1)必须为Tensor,而不是numpy.ndarray?

时间:2019-07-09 10:56:29

标签: image machine-learning scikit-learn svm train-test-split

好吧,我正在尝试使用SVM制作图像分类器。在此之前,我使用的是CNN,这就是使用PyTorch的原因。但是现在,由于我要使用SVM,因此必须使用Sci-Kit学习。 因此,首先,我需要将我的数据集拆分为训练部分和测试部分。为此,我们将使用train_test_split。 我正在Google Colab中进行编码工作,并将图像存储在Google云端硬盘中的某些文件夹中。 这是我的数据加载器部分-

# choose the training and test datasets
train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
test_data = datasets.ImageFolder(data+"/val", transform = transform_test)
#n_classes = test_data.shape[1]
n_classes = len(test_data.classes)
print(n_classes)

batch_size = 32

dataloader_train = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True, num_workers=2)
dataloader_test = torch.utils.data.DataLoader(test_data, batch_size, num_workers=2)

接下来,我将显示它并以此分割数据集-

images, labels = next(iter(dataloader_train))
imshow_numpy(images[0].numpy())
print(images.shape)
X_train, X_test, y_train, y_test = train_test_split(images, labels)

但这行----> 5 X_train, X_test, y_train, y_test = train_test_split(images, labels) 给我错误。我不知道如何解决它。有人知道吗?

0 个答案:

没有答案