如何在python中加载图像数据集

时间:2020-05-21 07:12:07

标签: python image-processing deep-learning computer-vision pytorch

我在Windows桌面上有一个文件夹,其中包含要用于构建深度学习分类器的图像。我也有一个.csv文件,该文件具有图像编号(例如img_1035)和相应的类标签。 如何将带有标签的数据集加载到python / jupyter笔记本中? 这是指向kaggle(https://www.kaggle.com/debdoot/bdrw)上的数据集的链接。

我最好使用 PyTorch 来执行此操作,但是任何其他方式也将受到高度赞赏。

1 个答案:

答案 0 :(得分:2)

幸运的是,PyTorch具有方便的"ImageFolder" class,您可以扩展它来创建自己的数据集。

以下是使用ImageFolder的数据集的示例:

class MyDataset(torchvision.datasets.ImageFolder):

def __init__(self, train_folder_path='.', transform=None, target_transform=None):
    super().__init__(train_folder_path, transform, target_transform)

# [ Some functions omitted ]

然后使用PyTorch的“ DataLoader”加载您的集合。 这是一个训练集的示例:

training_set = MyDataset(root_path, transform)
train_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)

使用火车装载机,您可以从数据集中获取批次。然后,您可以使用这些批次来训练/验证,等等:

batch = next(iter(train_loader))
images, labels = batch

训练是一个相当复杂的过程,因此我不确定您要在这里潜水多深。我希望这是朝着正确的方向前进。