好吧,我正在尝试使用SVM制作图像分类器。在此之前,我使用的是CNN,这就是使用PyTorch的原因。但是现在,由于我要使用SVM,因此必须使用Sci-Kit学习。
因此,首先,我需要将我的数据集拆分为训练部分和测试部分。为此,我们将使用train_test_split
。
我正在Google Colab中进行编码工作,并将图像存储在Google云端硬盘中的某些文件夹中。
这是我的数据加载器部分-
# choose the training and test datasets
train_data = datasets.ImageFolder(data+"/train", transform=transform_train)
test_data = datasets.ImageFolder(data+"/val", transform = transform_test)
#n_classes = test_data.shape[1]
n_classes = len(test_data.classes)
print(n_classes)
batch_size = 32
dataloader_train = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True, num_workers=2)
dataloader_test = torch.utils.data.DataLoader(test_data, batch_size, num_workers=2)
接下来,我将显示它并以此分割数据集-
images, labels = next(iter(dataloader_train))
imshow_numpy(images[0].numpy())
print(images.shape)
X_train, X_test, y_train, y_test = train_test_split(images, labels)
但这行----> 5 X_train, X_test, y_train, y_test = train_test_split(images, labels)
给我错误。我不知道如何解决它。有人知道吗?